Implement python mojo bindings unions.
R=qsr@chromium.org
Review URL: https://codereview.chromium.org/1218023006 .
diff --git a/mojo/public/interfaces/bindings/tests/test_unions.mojom b/mojo/public/interfaces/bindings/tests/test_unions.mojom
index 6e631b2..e998f47 100644
--- a/mojo/public/interfaces/bindings/tests/test_unions.mojom
+++ b/mojo/public/interfaces/bindings/tests/test_unions.mojom
@@ -69,6 +69,7 @@
struct SmallObjStruct {
ObjectUnion obj_union;
+ int8 f_int8;
};
interface SmallCache {
@@ -84,3 +85,12 @@
DummyStruct? nullable;
DummyStruct non_nullable;
};
+
+union OldUnion {
+ int8 f_int8;
+};
+
+union NewUnion {
+ int8 f_int8;
+ int16 f_int16;
+};
diff --git a/mojo/public/python/mojo_bindings/descriptor.py b/mojo/public/python/mojo_bindings/descriptor.py
index efba73d..c359e44 100644
--- a/mojo/public/python/mojo_bindings/descriptor.py
+++ b/mojo/public/python/mojo_bindings/descriptor.py
@@ -34,6 +34,13 @@
"""
return self.Convert(value)
+ def IsUnion(self):
+ """
+ Returns true if the type is a union. This is necessary to be able to
+ identify a union when descriptor.py cannot be imported.
+ """
+ return False
+
class SerializableType(Type):
"""Describe a type that can be serialized by itself."""
@@ -153,6 +160,47 @@
return float(value)
+class UnionType(SerializableType):
+ """Base Type object for union."""
+
+ def __init__(self, union_type_getter, nullable=False):
+ SerializableType.__init__(self, 'IIQ')
+ self.nullable = nullable
+ self._union_type_getter = union_type_getter
+ self._union_type = None
+
+ def IsUnion(self):
+ return True
+
+ @property
+ def union_type(self):
+ if not self._union_type:
+ self._union_type = self._union_type_getter()
+ return self._union_type
+
+ def Serialize(self, value, data_offset, data, handle_offset):
+ if not value:
+ if not self.nullable:
+ raise serialization.SerializationException(
+ 'Trying to serialize null for non nullable type.')
+ return ((0, 0, 0), [])
+
+ ((size, tag, entry, new_data), new_handles) = (
+ value.SerializeInline(handle_offset))
+ if len(new_data) > 0:
+ data.extend(new_data)
+ entry = data_offset - 8
+
+ return ((size, tag, entry), new_handles)
+
+ def Deserialize(self, value, context):
+ result = self.union_type.Deserialize(context)
+ if not result and not self.nullable:
+ raise serialization.DeserializationException(
+ 'Trying to deserialize null for non nullable type.')
+ return result
+
+
class PointerType(SerializableType):
"""Base Type object for pointers."""
@@ -434,16 +482,19 @@
to_pack.extend(serialization.Flatten(new_data))
returned_handles.extend(new_handles)
position = position + self.sub_type.GetByteSize()
+
serialization.HEADER_STRUCT.pack_into(data, data_end, size, len(value))
- struct.pack_into('%d%s' % (len(value), self.sub_type.GetTypeCode()),
+ # TODO(azani): Refactor so we don't have to create big formatting strings.
+ struct.pack_into(('%s' % self.sub_type.GetTypeCode()) * len(value),
data,
data_end + serialization.HEADER_STRUCT.size,
*to_pack)
return (data_offset, returned_handles)
def DeserializeArray(self, size, nb_elements, context):
+ # TODO(azani): Refactor so the format string isn't so big.
values = struct.unpack_from(
- '%d%s' % (nb_elements, self.sub_type.GetTypeCode()),
+ nb_elements * self.sub_type.GetTypeCode(),
buffer(context.data, serialization.HEADER_STRUCT.size))
values_per_element = len(self.sub_type.GetTypeCode())
assert nb_elements * values_per_element == len(values)
diff --git a/mojo/public/python/mojo_bindings/reflection.py b/mojo/public/python/mojo_bindings/reflection.py
index 6c4767b..cfd7e64 100644
--- a/mojo/public/python/mojo_bindings/reflection.py
+++ b/mojo/public/python/mojo_bindings/reflection.py
@@ -135,6 +135,103 @@
raise AttributeError('can\'t delete attribute')
+class MojoUnionType(type):
+
+ def __new__(mcs, name, bases, dictionary):
+ dictionary['__slots__'] = ('_cur_field', '_data')
+ descriptor = dictionary.pop('DESCRIPTOR', {})
+
+ fields = descriptor.get('fields', [])
+ def _BuildUnionProperty(field):
+
+ # pylint: disable=W0212
+ def Get(self):
+ if self._cur_field != field:
+ raise AttributeError('%s is not currently set' % field.name,
+ field.name, self._cur_field.name)
+ return self._data
+
+ # pylint: disable=W0212
+ def Set(self, value):
+ self._cur_field = field
+ self._data = field.field_type.Convert(value)
+
+ return property(Get, Set)
+
+ for field in fields:
+ dictionary[field.name] = _BuildUnionProperty(field)
+
+ def UnionInit(self, **kwargs):
+ self.SetInternals(None, None)
+ items = kwargs.items()
+ if len(items) == 0:
+ return
+
+ if len(items) > 1:
+ raise TypeError('only 1 member may be set on a union.')
+
+ setattr(self, items[0][0], items[0][1])
+ dictionary['__init__'] = UnionInit
+
+ serializer = serialization.UnionSerializer(fields)
+ def SerializeUnionInline(self, handle_offset=0):
+ return serializer.SerializeInline(self, handle_offset)
+ dictionary['SerializeInline'] = SerializeUnionInline
+
+ def SerializeUnion(self, handle_offset=0):
+ return serializer.Serialize(self, handle_offset)
+ dictionary['Serialize'] = SerializeUnion
+
+ def DeserializeUnion(cls, context):
+ return serializer.Deserialize(context, cls)
+ dictionary['Deserialize'] = classmethod(DeserializeUnion)
+
+ class Tags(object):
+ __metaclass__ = MojoEnumType
+ VALUES = [(field.name, field.index) for field in fields]
+ dictionary['Tags'] = Tags
+
+ def GetTag(self):
+ return self._cur_field.index
+ dictionary['tag'] = property(GetTag, None)
+
+ def GetData(self):
+ return self._data
+ dictionary['data'] = property(GetData, None)
+
+ def IsUnknown(self):
+ return not self._cur_field
+ dictionary['IsUnknown'] = IsUnknown
+
+ def UnionEq(self, other):
+ return (
+ (type(self) is type(other))
+ and (self.tag == other.tag)
+ and (self.data == other.data))
+ dictionary['__eq__'] = UnionEq
+
+ def UnionNe(self, other):
+ return not self.__eq__(other)
+ dictionary['__ne__'] = UnionNe
+
+ def UnionStr(self):
+ return '<%s.%s(%s): %s>' % (
+ self.__class__.__name__,
+ self._cur_field.name,
+ self.tag,
+ self.data)
+ dictionary['__str__'] = UnionStr
+ dictionary['__repr__'] = UnionStr
+
+ def SetInternals(self, field, data):
+ self._cur_field = field
+ self._data = data
+ dictionary['SetInternals'] = SetInternals
+
+
+ return type.__new__(mcs, name, bases, dictionary)
+
+
class InterfaceRequest(object):
"""
An interface request allows to send a request for an interface to a remote
diff --git a/mojo/public/python/mojo_bindings/serialization.py b/mojo/public/python/mojo_bindings/serialization.py
index 32f60f0..b1a35ec 100644
--- a/mojo/public/python/mojo_bindings/serialization.py
+++ b/mojo/public/python/mojo_bindings/serialization.py
@@ -7,9 +7,12 @@
import struct
-# Format of a header for a struct or an array.
+# Format of a header for a struct, array or union.
HEADER_STRUCT = struct.Struct("<II")
+# Format for a pointer.
+POINTER_STRUCT = struct.Struct("<Q")
+
def Flatten(value):
"""Flattens nested lists/tuples into an one-level list. If value is not a
@@ -218,3 +221,89 @@
if alignment_needed:
codes.append('x' * alignment_needed)
return struct.Struct(''.join(codes))
+
+
+class UnionSerializer(object):
+ """
+ Helper class to serialize/deserialize a union.
+ """
+ def __init__(self, fields):
+ self._fields = {field.index: field for field in fields}
+
+ def SerializeInline(self, union, handle_offset):
+ data = bytearray()
+ field = self._fields[union.tag]
+
+ # If the union value is a simple type or a nested union, it is returned as
+ # entry.
+ # Otherwise, the serialized value is appended to data and the value of entry
+ # is -1. The caller will need to set entry to the location where the
+ # caller will append data.
+ (entry, handles) = field.field_type.Serialize(
+ union.data, -1, data, handle_offset)
+
+ # If the value contained in the union is itself a union, we append its
+ # serialized value to data and set entry to -1. The caller will need to set
+ # entry to the location where the caller will append data.
+ if field.field_type.IsUnion():
+ nested_union = bytearray(16)
+ HEADER_STRUCT.pack_into(nested_union, 0, entry[0], entry[1])
+ POINTER_STRUCT.pack_into(nested_union, 8, entry[2])
+
+ data = nested_union + data
+
+ # Since we do not know where the caller will append the nested union,
+ # we set entry to an invalid value and let the caller figure out the right
+ # value.
+ entry = -1
+
+ return (16, union.tag, entry, data), handles
+
+ def Serialize(self, union, handle_offset):
+ (size, tag, entry, extra_data), handles = self.SerializeInline(
+ union, handle_offset)
+ data = bytearray(16)
+ if extra_data:
+ entry = 8
+ data.extend(extra_data)
+
+ field = self._fields[union.tag]
+
+ HEADER_STRUCT.pack_into(data, 0, size, tag)
+ typecode = field.GetTypeCode()
+
+ # If the value is a nested union, we store a 64 bits pointer to it.
+ if field.field_type.IsUnion():
+ typecode = 'Q'
+
+ struct.pack_into('<%s' % typecode, data, 8, entry)
+ return data, handles
+
+ def Deserialize(self, context, union_class):
+ if len(context.data) < HEADER_STRUCT.size:
+ raise DeserializationException(
+ 'Available data too short to contain header.')
+ (size, tag) = HEADER_STRUCT.unpack_from(context.data)
+
+ if size == 0:
+ return None
+
+ if size != 16:
+ raise DeserializationException('Invalid union size %s' % size)
+
+ union = union_class.__new__(union_class)
+ if tag not in self._fields:
+ union.SetInternals(None, None)
+ return union
+
+ field = self._fields[tag]
+ if field.field_type.IsUnion():
+ ptr = POINTER_STRUCT.unpack_from(context.data, 8)[0]
+ value = field.field_type.Deserialize(ptr, context.GetSubContext(ptr+8))
+ else:
+ raw_value = struct.unpack_from(
+ field.GetTypeCode(), context.data, 8)[0]
+ value = field.field_type.Deserialize(raw_value, context.GetSubContext(8))
+
+ union.SetInternals(field, value)
+ return union
diff --git a/mojo/public/tools/bindings/generators/mojom_python_generator.py b/mojo/public/tools/bindings/generators/mojom_python_generator.py
index 1f726b6..a12f5e1 100644
--- a/mojo/public/tools/bindings/generators/mojom_python_generator.py
+++ b/mojo/public/tools/bindings/generators/mojom_python_generator.py
@@ -146,7 +146,13 @@
arguments.append('nullable=True')
return '_descriptor.MapType(%s)' % ', '.join(arguments)
- if mojom.IsStructKind(kind) or mojom.IsUnionKind(kind):
+ if mojom.IsUnionKind(kind):
+ arguments = [ 'lambda: %s' % GetFullyQualifiedName(kind) ]
+ if mojom.IsNullableKind(kind):
+ arguments.append('nullable=True')
+ return '_descriptor.UnionType(%s)' % ', '.join(arguments)
+
+ if mojom.IsStructKind(kind):
arguments = [ 'lambda: %s' % GetFullyQualifiedName(kind) ]
if mojom.IsNullableKind(kind):
arguments.append('nullable=True')
@@ -169,15 +175,14 @@
return _kind_to_type[kind]
-def GetFieldDescriptor(packed_field):
- field = packed_field.field
+def GetFieldDescriptor(field, index, min_version):
class_name = 'SingleFieldGroup'
if field.kind == mojom.BOOL:
class_name = 'FieldDescriptor'
arguments = [ '%r' % GetNameForElement(field) ]
arguments.append(GetFieldType(field.kind, field))
- arguments.append(str(packed_field.index))
- arguments.append(str(packed_field.min_version))
+ arguments.append(str(index))
+ arguments.append(str(min_version))
if field.default:
if mojom.IsStructKind(field.kind):
arguments.append('default_value=True')
@@ -185,12 +190,19 @@
arguments.append('default_value=%s' % ExpressionToText(field.default))
return '_descriptor.%s(%s)' % (class_name, ', '.join(arguments))
+def GetStructFieldDescriptor(packed_field):
+ return GetFieldDescriptor(
+ packed_field.field, packed_field.index, packed_field.min_version)
+
+def GetUnionFieldDescriptor(field):
+ return GetFieldDescriptor(field, field.ordinal, 0)
+
def GetFieldGroup(byte):
if byte.packed_fields[0].field.kind == mojom.BOOL:
- descriptors = map(GetFieldDescriptor, byte.packed_fields)
+ descriptors = map(GetStructFieldDescriptor, byte.packed_fields)
return '_descriptor.BooleanGroup([%s])' % ', '.join(descriptors)
assert len(byte.packed_fields) == 1
- return GetFieldDescriptor(byte.packed_fields[0])
+ return GetStructFieldDescriptor(byte.packed_fields[0])
def MojomToPythonImport(mojom):
return mojom.replace('.mojom', '_mojom')
@@ -200,6 +212,7 @@
python_filters = {
'expression_to_text': ExpressionToText,
'field_group': GetFieldGroup,
+ 'union_field_descriptor': GetUnionFieldDescriptor,
'fully_qualified_name': GetFullyQualifiedName,
'name': GetNameForElement,
}
@@ -213,6 +226,7 @@
'module': resolver.ResolveConstants(self.module, ExpressionToText),
'namespace': self.module.namespace,
'structs': self.GetStructs(),
+ 'unions': self.GetUnions(),
}
def GenerateFiles(self, args):
diff --git a/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl b/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
index 2a22932..9d57600 100644
--- a/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
+++ b/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
@@ -1,5 +1,6 @@
{% from "module_macros.tmpl" import enum_values %}
{% from "module_macros.tmpl" import struct_descriptor %}
+{% from "module_macros.tmpl" import union_descriptor %}
# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
@@ -34,6 +35,13 @@
__metaclass__ = _reflection.MojoStructType
DESCRIPTOR = {{struct_descriptor(struct)|indent(2)}}
{% endfor %}
+{% for union in unions %}
+
+class {{union|name}}(object):
+ __metaclass__ = _reflection.MojoUnionType
+ DESCRIPTOR = {{union_descriptor(union)|indent(2)}}
+{% endfor %}
+
{% for interface in interfaces %}
class {{interface|name}}(object):
diff --git a/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl b/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
index 21c70db..c979f59 100644
--- a/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
+++ b/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
@@ -37,3 +37,13 @@
{% endif %}
}
{%- endmacro -%}
+
+{%- macro union_descriptor(union) -%}
+{
+ 'fields': [
+{% for field in union.fields %}
+ {{field|union_field_descriptor}},
+{% endfor %}
+ ],
+ }
+{%- endmacro -%}
diff --git a/mojo/python/tests/bindings_unions_unittest.py b/mojo/python/tests/bindings_unions_unittest.py
new file mode 100644
index 0000000..b1d742f
--- /dev/null
+++ b/mojo/python/tests/bindings_unions_unittest.py
@@ -0,0 +1,188 @@
+# Copyright 2015 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import struct
+import unittest
+
+# Generated files
+# pylint: disable=F0401
+import test_unions_mojom
+import mojo_bindings.serialization as serialization
+
+class UnionBindingsTest(unittest.TestCase):
+
+ def testBasics(self):
+ u = test_unions_mojom.PodUnion()
+ self.assertTrue(u.IsUnknown())
+
+ u.f_uint32 = 32
+ self.assertEquals(u.f_uint32, 32)
+ self.assertEquals(u.data, 32)
+ self.assertEquals(test_unions_mojom.PodUnion.Tags.f_uint32, u.tag)
+ self.assertFalse(u.IsUnknown())
+
+ u = test_unions_mojom.PodUnion(f_uint8=8)
+ self.assertEquals(u.f_uint8, 8)
+ self.assertEquals(u.data, 8)
+ self.assertEquals(test_unions_mojom.PodUnion.Tags.f_uint8, u.tag)
+
+ with self.assertRaises(TypeError):
+ test_unions_mojom.PodUnion(f_uint8=8, f_int16=10)
+
+ with self.assertRaises(AttributeError):
+ test_unions_mojom.PodUnion(bad_field=10)
+
+ with self.assertRaises(AttributeError):
+ u = test_unions_mojom.PodUnion()
+ u.bad_field = 32
+
+ with self.assertRaises(AttributeError):
+ _ = u.f_uint16
+
+ def testPodUnionSerialization(self):
+ u = test_unions_mojom.PodUnion(f_uint32=32)
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.PodUnion.Deserialize(context)
+
+ self.assertFalse(decoded.IsUnknown())
+ self.assertEquals(u, decoded)
+
+ def testUnionUnknownTag(self):
+ u = test_unions_mojom.NewUnion(f_int16=10)
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.OldUnion.Deserialize(context)
+
+ self.assertTrue(decoded.IsUnknown())
+
+ def testObjectInUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(
+ f_dummy=test_unions_mojom.DummyStruct())
+ u.f_dummy.f_int8 = 8
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+ self.assertEquals(u, decoded)
+
+ def testObjectInUnionInObjectSerialization(self):
+ s = test_unions_mojom.SmallObjStruct()
+ s.obj_union = test_unions_mojom.ObjectUnion(
+ f_dummy=test_unions_mojom.DummyStruct())
+ s.obj_union.f_dummy.f_int8 = 25
+ (data, handles) = s.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.SmallObjStruct.Deserialize(context)
+
+ self.assertEquals(s, decoded)
+
+ def testNestedUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(
+ f_pod_union=test_unions_mojom.PodUnion(f_int32=32))
+
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+ self.assertEquals(u, decoded)
+
+ def testNullableNullObjectInUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(f_nullable=None)
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+ self.assertEquals(u, decoded)
+
+ def testNonNullableNullObjectInUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(f_dummy=None)
+ with self.assertRaises(serialization.SerializationException):
+ u.Serialize()
+
+ def testArrayInUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(
+ f_array_int8=[1, 2, 3, 4, 5])
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+ self.assertEquals(u, decoded)
+
+ def testMapInUnionSerialization(self):
+ u = test_unions_mojom.ObjectUnion(
+ f_map_int8={'one': 1, 'two': 2, 'three': 3})
+ (data, handles) = u.Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+ self.assertEquals(u, decoded)
+
+ def testUnionInObject(self):
+ s = test_unions_mojom.SmallStruct()
+ s.pod_union = test_unions_mojom.PodUnion(f_uint32=32)
+ (data, handles) = s.Serialize()
+
+ # This is where the data should be serialized to.
+ size, tag, value = struct.unpack_from('<IIQ', buffer(data), 16)
+ self.assertEquals(16, size)
+ self.assertEquals(6, tag)
+ self.assertEquals(32, value)
+
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+ self.assertEquals(s, decoded)
+
+ def testUnionInArray(self):
+ s = test_unions_mojom.SmallStruct()
+ s.pod_union_array = [
+ test_unions_mojom.PodUnion(f_uint32=32),
+ test_unions_mojom.PodUnion(f_uint16=16),
+ test_unions_mojom.PodUnion(f_uint64=64),
+ ]
+ (data, handles) = s.Serialize()
+
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+ self.assertEquals(s, decoded)
+
+ def testNonNullableNullUnionInArray(self):
+ s = test_unions_mojom.SmallStruct()
+ s.pod_union_array = [
+ test_unions_mojom.PodUnion(f_uint32=32),
+ None,
+ test_unions_mojom.PodUnion(f_uint64=64),
+ ]
+ with self.assertRaises(serialization.SerializationException):
+ s.Serialize()
+
+ def testNullableNullUnionInArray(self):
+ s = test_unions_mojom.SmallStruct()
+ s.nullable_pod_union_array = [
+ test_unions_mojom.PodUnion(f_uint32=32),
+ None,
+ test_unions_mojom.PodUnion(f_uint64=64),
+ ]
+ (data, handles) = s.Serialize()
+
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+ self.assertEquals(s, decoded)
+
+ def testUnionInMap(self):
+ s = test_unions_mojom.SmallStruct()
+ s.pod_union_map = {
+ 'f_uint32': test_unions_mojom.PodUnion(f_uint32=32),
+ 'f_uint16': test_unions_mojom.PodUnion(f_uint16=16),
+ 'f_uint64': test_unions_mojom.PodUnion(f_uint64=64),
+ }
+ (data, handles) = s.Serialize()
+
+ context = serialization.RootDeserializationContext(data, handles)
+ decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+ self.assertEquals(s, decoded)