Implement python mojo bindings unions.

R=qsr@chromium.org

Review URL: https://codereview.chromium.org/1218023006 .
diff --git a/mojo/public/interfaces/bindings/tests/test_unions.mojom b/mojo/public/interfaces/bindings/tests/test_unions.mojom
index 6e631b2..e998f47 100644
--- a/mojo/public/interfaces/bindings/tests/test_unions.mojom
+++ b/mojo/public/interfaces/bindings/tests/test_unions.mojom
@@ -69,6 +69,7 @@
 
 struct SmallObjStruct {
   ObjectUnion obj_union;
+  int8 f_int8;
 };
 
 interface SmallCache {
@@ -84,3 +85,12 @@
   DummyStruct? nullable;
   DummyStruct non_nullable;
 };
+
+union OldUnion {
+  int8 f_int8;
+};
+
+union NewUnion {
+  int8 f_int8;
+  int16 f_int16;
+};
diff --git a/mojo/public/python/mojo_bindings/descriptor.py b/mojo/public/python/mojo_bindings/descriptor.py
index efba73d..c359e44 100644
--- a/mojo/public/python/mojo_bindings/descriptor.py
+++ b/mojo/public/python/mojo_bindings/descriptor.py
@@ -34,6 +34,13 @@
     """
     return self.Convert(value)
 
+  def IsUnion(self):
+    """
+    Returns true if the type is a union. This is necessary to be able to
+    identify a union when descriptor.py cannot be imported.
+    """
+    return False
+
 
 class SerializableType(Type):
   """Describe a type that can be serialized by itself."""
@@ -153,6 +160,47 @@
     return float(value)
 
 
+class UnionType(SerializableType):
+  """Base Type object for union."""
+
+  def __init__(self, union_type_getter, nullable=False):
+    SerializableType.__init__(self, 'IIQ')
+    self.nullable = nullable
+    self._union_type_getter = union_type_getter
+    self._union_type = None
+
+  def IsUnion(self):
+    return True
+
+  @property
+  def union_type(self):
+    if not self._union_type:
+      self._union_type = self._union_type_getter()
+    return self._union_type
+
+  def Serialize(self, value, data_offset, data, handle_offset):
+    if not value:
+      if not self.nullable:
+        raise serialization.SerializationException(
+            'Trying to serialize null for non nullable type.')
+      return ((0, 0, 0), [])
+
+    ((size, tag, entry, new_data), new_handles) = (
+        value.SerializeInline(handle_offset))
+    if len(new_data) > 0:
+      data.extend(new_data)
+      entry = data_offset - 8
+
+    return ((size, tag, entry), new_handles)
+
+  def Deserialize(self, value, context):
+    result = self.union_type.Deserialize(context)
+    if not result and not self.nullable:
+      raise serialization.DeserializationException(
+          'Trying to deserialize null for non nullable type.')
+    return result
+
+
 class PointerType(SerializableType):
   """Base Type object for pointers."""
 
@@ -434,16 +482,19 @@
       to_pack.extend(serialization.Flatten(new_data))
       returned_handles.extend(new_handles)
       position = position + self.sub_type.GetByteSize()
+
     serialization.HEADER_STRUCT.pack_into(data, data_end, size, len(value))
-    struct.pack_into('%d%s' % (len(value), self.sub_type.GetTypeCode()),
+    # TODO(azani): Refactor so we don't have to create big formatting strings.
+    struct.pack_into(('%s' % self.sub_type.GetTypeCode()) * len(value),
                      data,
                      data_end + serialization.HEADER_STRUCT.size,
                      *to_pack)
     return (data_offset, returned_handles)
 
   def DeserializeArray(self, size, nb_elements, context):
+    # TODO(azani): Refactor so the format string isn't so big.
     values = struct.unpack_from(
-        '%d%s' % (nb_elements, self.sub_type.GetTypeCode()),
+        nb_elements * self.sub_type.GetTypeCode(),
         buffer(context.data, serialization.HEADER_STRUCT.size))
     values_per_element = len(self.sub_type.GetTypeCode())
     assert nb_elements * values_per_element == len(values)
diff --git a/mojo/public/python/mojo_bindings/reflection.py b/mojo/public/python/mojo_bindings/reflection.py
index 6c4767b..cfd7e64 100644
--- a/mojo/public/python/mojo_bindings/reflection.py
+++ b/mojo/public/python/mojo_bindings/reflection.py
@@ -135,6 +135,103 @@
     raise AttributeError('can\'t delete attribute')
 
 
+class MojoUnionType(type):
+
+  def __new__(mcs, name, bases, dictionary):
+    dictionary['__slots__'] = ('_cur_field', '_data')
+    descriptor = dictionary.pop('DESCRIPTOR', {})
+
+    fields = descriptor.get('fields', [])
+    def _BuildUnionProperty(field):
+
+      # pylint: disable=W0212
+      def Get(self):
+        if self._cur_field != field:
+          raise AttributeError('%s is not currently set' % field.name,
+              field.name, self._cur_field.name)
+        return self._data
+
+      # pylint: disable=W0212
+      def Set(self, value):
+        self._cur_field = field
+        self._data = field.field_type.Convert(value)
+
+      return property(Get, Set)
+
+    for field in fields:
+      dictionary[field.name] = _BuildUnionProperty(field)
+
+    def UnionInit(self, **kwargs):
+      self.SetInternals(None, None)
+      items = kwargs.items()
+      if len(items) == 0:
+        return
+
+      if len(items) > 1:
+        raise TypeError('only 1 member may be set on a union.')
+
+      setattr(self, items[0][0], items[0][1])
+    dictionary['__init__'] = UnionInit
+
+    serializer = serialization.UnionSerializer(fields)
+    def SerializeUnionInline(self, handle_offset=0):
+      return serializer.SerializeInline(self, handle_offset)
+    dictionary['SerializeInline'] = SerializeUnionInline
+
+    def SerializeUnion(self, handle_offset=0):
+      return serializer.Serialize(self, handle_offset)
+    dictionary['Serialize'] = SerializeUnion
+
+    def DeserializeUnion(cls, context):
+      return serializer.Deserialize(context, cls)
+    dictionary['Deserialize'] = classmethod(DeserializeUnion)
+
+    class Tags(object):
+      __metaclass__ = MojoEnumType
+      VALUES = [(field.name, field.index) for field in fields]
+    dictionary['Tags'] = Tags
+
+    def GetTag(self):
+      return self._cur_field.index
+    dictionary['tag'] = property(GetTag, None)
+
+    def GetData(self):
+      return self._data
+    dictionary['data'] = property(GetData, None)
+
+    def IsUnknown(self):
+      return not self._cur_field
+    dictionary['IsUnknown'] = IsUnknown
+
+    def UnionEq(self, other):
+      return (
+          (type(self) is type(other))
+          and (self.tag == other.tag)
+          and (self.data == other.data))
+    dictionary['__eq__'] = UnionEq
+
+    def UnionNe(self, other):
+      return not self.__eq__(other)
+    dictionary['__ne__'] = UnionNe
+
+    def UnionStr(self):
+      return '<%s.%s(%s): %s>' % (
+          self.__class__.__name__,
+          self._cur_field.name,
+          self.tag,
+          self.data)
+    dictionary['__str__'] = UnionStr
+    dictionary['__repr__'] = UnionStr
+
+    def SetInternals(self, field, data):
+      self._cur_field = field
+      self._data = data
+    dictionary['SetInternals'] = SetInternals
+
+
+    return type.__new__(mcs, name, bases, dictionary)
+
+
 class InterfaceRequest(object):
   """
   An interface request allows to send a request for an interface to a remote
diff --git a/mojo/public/python/mojo_bindings/serialization.py b/mojo/public/python/mojo_bindings/serialization.py
index 32f60f0..b1a35ec 100644
--- a/mojo/public/python/mojo_bindings/serialization.py
+++ b/mojo/public/python/mojo_bindings/serialization.py
@@ -7,9 +7,12 @@
 import struct
 
 
-# Format of a header for a struct or an array.
+# Format of a header for a struct, array or union.
 HEADER_STRUCT = struct.Struct("<II")
 
+# Format for a pointer.
+POINTER_STRUCT = struct.Struct("<Q")
+
 
 def Flatten(value):
   """Flattens nested lists/tuples into an one-level list. If value is not a
@@ -218,3 +221,89 @@
   if alignment_needed:
     codes.append('x' * alignment_needed)
   return struct.Struct(''.join(codes))
+
+
+class UnionSerializer(object):
+  """
+  Helper class to serialize/deserialize a union.
+  """
+  def __init__(self, fields):
+    self._fields = {field.index: field for field in fields}
+
+  def SerializeInline(self, union, handle_offset):
+    data = bytearray()
+    field = self._fields[union.tag]
+
+    # If the union value is a simple type or a nested union, it is returned as
+    # entry.
+    # Otherwise, the serialized value is appended to data and the value of entry
+    # is -1. The caller will need to set entry to the location where the
+    # caller will append data.
+    (entry, handles) = field.field_type.Serialize(
+        union.data, -1, data, handle_offset)
+
+    # If the value contained in the union is itself a union, we append its
+    # serialized value to data and set entry to -1. The caller will need to set
+    # entry to the location where the caller will append data.
+    if field.field_type.IsUnion():
+      nested_union = bytearray(16)
+      HEADER_STRUCT.pack_into(nested_union, 0, entry[0], entry[1])
+      POINTER_STRUCT.pack_into(nested_union, 8, entry[2])
+
+      data = nested_union + data
+
+      # Since we do not know where the caller will append the nested union,
+      # we set entry to an invalid value and let the caller figure out the right
+      # value.
+      entry = -1
+
+    return (16, union.tag, entry, data), handles
+
+  def Serialize(self, union, handle_offset):
+    (size, tag, entry, extra_data), handles = self.SerializeInline(
+        union, handle_offset)
+    data = bytearray(16)
+    if extra_data:
+      entry = 8
+    data.extend(extra_data)
+
+    field = self._fields[union.tag]
+
+    HEADER_STRUCT.pack_into(data, 0, size, tag)
+    typecode = field.GetTypeCode()
+
+    # If the value is a nested union, we store a 64 bits pointer to it.
+    if field.field_type.IsUnion():
+      typecode = 'Q'
+
+    struct.pack_into('<%s' % typecode, data, 8, entry)
+    return data, handles
+
+  def Deserialize(self, context, union_class):
+    if len(context.data) < HEADER_STRUCT.size:
+      raise DeserializationException(
+          'Available data too short to contain header.')
+    (size, tag) = HEADER_STRUCT.unpack_from(context.data)
+
+    if size == 0:
+      return None
+
+    if size != 16:
+      raise DeserializationException('Invalid union size %s' % size)
+
+    union = union_class.__new__(union_class)
+    if tag not in self._fields:
+      union.SetInternals(None, None)
+      return union
+
+    field = self._fields[tag]
+    if field.field_type.IsUnion():
+      ptr = POINTER_STRUCT.unpack_from(context.data, 8)[0]
+      value = field.field_type.Deserialize(ptr, context.GetSubContext(ptr+8))
+    else:
+      raw_value = struct.unpack_from(
+          field.GetTypeCode(), context.data, 8)[0]
+      value = field.field_type.Deserialize(raw_value, context.GetSubContext(8))
+
+    union.SetInternals(field, value)
+    return union
diff --git a/mojo/public/tools/bindings/generators/mojom_python_generator.py b/mojo/public/tools/bindings/generators/mojom_python_generator.py
index 1f726b6..a12f5e1 100644
--- a/mojo/public/tools/bindings/generators/mojom_python_generator.py
+++ b/mojo/public/tools/bindings/generators/mojom_python_generator.py
@@ -146,7 +146,13 @@
       arguments.append('nullable=True')
     return '_descriptor.MapType(%s)' % ', '.join(arguments)
 
-  if mojom.IsStructKind(kind) or mojom.IsUnionKind(kind):
+  if mojom.IsUnionKind(kind):
+    arguments = [ 'lambda: %s' % GetFullyQualifiedName(kind) ]
+    if mojom.IsNullableKind(kind):
+      arguments.append('nullable=True')
+    return '_descriptor.UnionType(%s)' % ', '.join(arguments)
+
+  if mojom.IsStructKind(kind):
     arguments = [ 'lambda: %s' % GetFullyQualifiedName(kind) ]
     if mojom.IsNullableKind(kind):
       arguments.append('nullable=True')
@@ -169,15 +175,14 @@
 
   return _kind_to_type[kind]
 
-def GetFieldDescriptor(packed_field):
-  field = packed_field.field
+def GetFieldDescriptor(field, index, min_version):
   class_name = 'SingleFieldGroup'
   if field.kind == mojom.BOOL:
     class_name = 'FieldDescriptor'
   arguments = [ '%r' % GetNameForElement(field) ]
   arguments.append(GetFieldType(field.kind, field))
-  arguments.append(str(packed_field.index))
-  arguments.append(str(packed_field.min_version))
+  arguments.append(str(index))
+  arguments.append(str(min_version))
   if field.default:
     if mojom.IsStructKind(field.kind):
       arguments.append('default_value=True')
@@ -185,12 +190,19 @@
       arguments.append('default_value=%s' % ExpressionToText(field.default))
   return '_descriptor.%s(%s)' % (class_name, ', '.join(arguments))
 
+def GetStructFieldDescriptor(packed_field):
+  return GetFieldDescriptor(
+      packed_field.field, packed_field.index, packed_field.min_version)
+
+def GetUnionFieldDescriptor(field):
+  return GetFieldDescriptor(field, field.ordinal, 0)
+
 def GetFieldGroup(byte):
   if byte.packed_fields[0].field.kind == mojom.BOOL:
-    descriptors = map(GetFieldDescriptor, byte.packed_fields)
+    descriptors = map(GetStructFieldDescriptor, byte.packed_fields)
     return '_descriptor.BooleanGroup([%s])' % ', '.join(descriptors)
   assert len(byte.packed_fields) == 1
-  return GetFieldDescriptor(byte.packed_fields[0])
+  return GetStructFieldDescriptor(byte.packed_fields[0])
 
 def MojomToPythonImport(mojom):
   return mojom.replace('.mojom', '_mojom')
@@ -200,6 +212,7 @@
   python_filters = {
     'expression_to_text': ExpressionToText,
     'field_group': GetFieldGroup,
+    'union_field_descriptor': GetUnionFieldDescriptor,
     'fully_qualified_name': GetFullyQualifiedName,
     'name': GetNameForElement,
   }
@@ -213,6 +226,7 @@
       'module': resolver.ResolveConstants(self.module, ExpressionToText),
       'namespace': self.module.namespace,
       'structs': self.GetStructs(),
+      'unions': self.GetUnions(),
     }
 
   def GenerateFiles(self, args):
diff --git a/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl b/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
index 2a22932..9d57600 100644
--- a/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
+++ b/mojo/public/tools/bindings/generators/python_templates/module.py.tmpl
@@ -1,5 +1,6 @@
 {% from "module_macros.tmpl" import enum_values %}
 {% from "module_macros.tmpl" import struct_descriptor %}
+{% from "module_macros.tmpl" import union_descriptor %}
 # Copyright 2014 The Chromium Authors. All rights reserved.
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
@@ -34,6 +35,13 @@
   __metaclass__ = _reflection.MojoStructType
   DESCRIPTOR = {{struct_descriptor(struct)|indent(2)}}
 {% endfor %}
+{% for union in unions %}
+
+class {{union|name}}(object):
+  __metaclass__ = _reflection.MojoUnionType
+  DESCRIPTOR = {{union_descriptor(union)|indent(2)}}
+{% endfor %}
+
 {% for interface in interfaces %}
 
 class {{interface|name}}(object):
diff --git a/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl b/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
index 21c70db..c979f59 100644
--- a/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
+++ b/mojo/public/tools/bindings/generators/python_templates/module_macros.tmpl
@@ -37,3 +37,13 @@
 {% endif %}
 }
 {%- endmacro -%}
+
+{%- macro union_descriptor(union) -%}
+{
+  'fields': [
+{% for field in union.fields %}
+    {{field|union_field_descriptor}},
+{% endfor %}
+  ],
+ }
+{%- endmacro -%}
diff --git a/mojo/python/tests/bindings_unions_unittest.py b/mojo/python/tests/bindings_unions_unittest.py
new file mode 100644
index 0000000..b1d742f
--- /dev/null
+++ b/mojo/python/tests/bindings_unions_unittest.py
@@ -0,0 +1,188 @@
+# Copyright 2015 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import struct
+import unittest
+
+# Generated files
+# pylint: disable=F0401
+import test_unions_mojom
+import mojo_bindings.serialization as serialization
+
+class UnionBindingsTest(unittest.TestCase):
+
+  def testBasics(self):
+    u = test_unions_mojom.PodUnion()
+    self.assertTrue(u.IsUnknown())
+
+    u.f_uint32 = 32
+    self.assertEquals(u.f_uint32, 32)
+    self.assertEquals(u.data, 32)
+    self.assertEquals(test_unions_mojom.PodUnion.Tags.f_uint32, u.tag)
+    self.assertFalse(u.IsUnknown())
+
+    u = test_unions_mojom.PodUnion(f_uint8=8)
+    self.assertEquals(u.f_uint8, 8)
+    self.assertEquals(u.data, 8)
+    self.assertEquals(test_unions_mojom.PodUnion.Tags.f_uint8, u.tag)
+
+    with self.assertRaises(TypeError):
+      test_unions_mojom.PodUnion(f_uint8=8, f_int16=10)
+
+    with self.assertRaises(AttributeError):
+      test_unions_mojom.PodUnion(bad_field=10)
+
+    with self.assertRaises(AttributeError):
+      u = test_unions_mojom.PodUnion()
+      u.bad_field = 32
+
+    with self.assertRaises(AttributeError):
+      _ = u.f_uint16
+
+  def testPodUnionSerialization(self):
+    u = test_unions_mojom.PodUnion(f_uint32=32)
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.PodUnion.Deserialize(context)
+
+    self.assertFalse(decoded.IsUnknown())
+    self.assertEquals(u, decoded)
+
+  def testUnionUnknownTag(self):
+    u = test_unions_mojom.NewUnion(f_int16=10)
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.OldUnion.Deserialize(context)
+
+    self.assertTrue(decoded.IsUnknown())
+
+  def testObjectInUnionSerialization(self):
+    u = test_unions_mojom.ObjectUnion(
+        f_dummy=test_unions_mojom.DummyStruct())
+    u.f_dummy.f_int8 = 8
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+    self.assertEquals(u, decoded)
+
+  def testObjectInUnionInObjectSerialization(self):
+    s = test_unions_mojom.SmallObjStruct()
+    s.obj_union = test_unions_mojom.ObjectUnion(
+        f_dummy=test_unions_mojom.DummyStruct())
+    s.obj_union.f_dummy.f_int8 = 25
+    (data, handles) = s.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.SmallObjStruct.Deserialize(context)
+
+    self.assertEquals(s, decoded)
+
+  def testNestedUnionSerialization(self):
+    u = test_unions_mojom.ObjectUnion(
+        f_pod_union=test_unions_mojom.PodUnion(f_int32=32))
+
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+    self.assertEquals(u, decoded)
+
+  def testNullableNullObjectInUnionSerialization(self):
+    u =  test_unions_mojom.ObjectUnion(f_nullable=None)
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+    self.assertEquals(u, decoded)
+
+  def testNonNullableNullObjectInUnionSerialization(self):
+    u =  test_unions_mojom.ObjectUnion(f_dummy=None)
+    with self.assertRaises(serialization.SerializationException):
+      u.Serialize()
+
+  def testArrayInUnionSerialization(self):
+    u = test_unions_mojom.ObjectUnion(
+        f_array_int8=[1, 2, 3, 4, 5])
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+    self.assertEquals(u, decoded)
+
+  def testMapInUnionSerialization(self):
+    u = test_unions_mojom.ObjectUnion(
+        f_map_int8={'one': 1, 'two': 2, 'three': 3})
+    (data, handles) = u.Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.ObjectUnion.Deserialize(context)
+
+    self.assertEquals(u, decoded)
+
+  def testUnionInObject(self):
+    s = test_unions_mojom.SmallStruct()
+    s.pod_union = test_unions_mojom.PodUnion(f_uint32=32)
+    (data, handles) = s.Serialize()
+
+    # This is where the data should be serialized to.
+    size, tag, value = struct.unpack_from('<IIQ', buffer(data), 16)
+    self.assertEquals(16, size)
+    self.assertEquals(6, tag)
+    self.assertEquals(32, value)
+
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+    self.assertEquals(s, decoded)
+
+  def testUnionInArray(self):
+    s = test_unions_mojom.SmallStruct()
+    s.pod_union_array = [
+        test_unions_mojom.PodUnion(f_uint32=32),
+        test_unions_mojom.PodUnion(f_uint16=16),
+        test_unions_mojom.PodUnion(f_uint64=64),
+        ]
+    (data, handles) = s.Serialize()
+
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+    self.assertEquals(s, decoded)
+
+  def testNonNullableNullUnionInArray(self):
+    s = test_unions_mojom.SmallStruct()
+    s.pod_union_array = [
+        test_unions_mojom.PodUnion(f_uint32=32),
+        None,
+        test_unions_mojom.PodUnion(f_uint64=64),
+        ]
+    with self.assertRaises(serialization.SerializationException):
+      s.Serialize()
+
+  def testNullableNullUnionInArray(self):
+    s = test_unions_mojom.SmallStruct()
+    s.nullable_pod_union_array = [
+        test_unions_mojom.PodUnion(f_uint32=32),
+        None,
+        test_unions_mojom.PodUnion(f_uint64=64),
+        ]
+    (data, handles) = s.Serialize()
+
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+    self.assertEquals(s, decoded)
+
+  def testUnionInMap(self):
+    s = test_unions_mojom.SmallStruct()
+    s.pod_union_map = {
+        'f_uint32': test_unions_mojom.PodUnion(f_uint32=32),
+        'f_uint16': test_unions_mojom.PodUnion(f_uint16=16),
+        'f_uint64': test_unions_mojom.PodUnion(f_uint64=64),
+        }
+    (data, handles) = s.Serialize()
+
+    context = serialization.RootDeserializationContext(data, handles)
+    decoded = test_unions_mojom.SmallStruct.Deserialize(context)
+
+    self.assertEquals(s, decoded)