Add validation tests to python bindings.

R=sdefresne@chromium.org

Review URL: https://codereview.chromium.org/761553003
diff --git a/mojo/public/python/mojo/bindings/descriptor.py b/mojo/public/python/mojo/bindings/descriptor.py
index f190d2b..0df0bd6 100644
--- a/mojo/public/python/mojo/bindings/descriptor.py
+++ b/mojo/public/python/mojo/bindings/descriptor.py
@@ -72,7 +72,7 @@
     """
     raise NotImplementedError()
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     """
     Deserialize a value of this type.
 
@@ -106,7 +106,7 @@
   def Serialize(self, value, data_offset, data, handle_offset):
     return (value, [])
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     return value
 
 
@@ -161,21 +161,31 @@
       return (0, [])
     return self.SerializePointer(value, data_offset, data, handle_offset)
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     if value == 0:
       if not self.nullable:
         raise serialization.DeserializationException(
             'Trying to deserialize null for non nullable type.')
       return None
-    pointed_data = buffer(data, value)
-    (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from(pointed_data)
-    return self.DeserializePointer(size, nb_elements, pointed_data, handles)
+    if value % 8 != 0:
+      raise serialization.DeserializationException(
+          'Pointer alignment is incorrect.')
+    sub_context = context.GetSubContext(value)
+    if len(sub_context.data) < serialization.HEADER_STRUCT.size:
+      raise serialization.DeserializationException(
+          'Available data too short to contain header.')
+    (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from(
+        sub_context.data)
+    if len(sub_context.data) < size or size < serialization.HEADER_STRUCT.size:
+      raise serialization.DeserializationException('Header size is incorrect.')
+    sub_context.ClaimMemory(0, size)
+    return self.DeserializePointer(size, nb_elements, sub_context)
 
   def SerializePointer(self, value, data_offset, data, handle_offset):
     """Serialize the not null value."""
     raise NotImplementedError()
 
-  def DeserializePointer(self, size, nb_elements, data, handles):
+  def DeserializePointer(self, size, nb_elements, context):
     raise NotImplementedError()
 
 
@@ -204,9 +214,8 @@
     return self._array_type.SerializeArray(
         string_array, data_offset, data, handle_offset)
 
-  def DeserializePointer(self, size, nb_elements, data, handles):
-    string_array = self._array_type.DeserializeArray(
-        size, nb_elements, data, handles)
+  def DeserializePointer(self, size, nb_elements, context):
+    string_array = self._array_type.DeserializeArray(size, nb_elements, context)
     return unicode(string_array.tostring(), 'utf8')
 
 
@@ -226,14 +235,13 @@
       return (-1, [])
     return (handle_offset, [handle])
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     if value == -1:
       if not self.nullable:
         raise serialization.DeserializationException(
             'Trying to deserialize null for non nullable type.')
       return self.FromHandle(mojo.system.Handle())
-    # TODO(qsr) validate handle order
-    return self.FromHandle(handles[value])
+    return self.FromHandle(context.ClaimHandle(value))
 
   def FromHandle(self, handle):
     raise NotImplementedError()
@@ -326,12 +334,18 @@
     """Serialize the not null array."""
     raise NotImplementedError()
 
-  def DeserializePointer(self, size, nb_elements, data, handles):
-    if self.length != 0 and size != self.length:
+  def DeserializePointer(self, size, nb_elements, context):
+    if self.length != 0 and nb_elements != self.length:
       raise serialization.DeserializationException('Incorrect array size')
-    return self.DeserializeArray(size, nb_elements, data, handles)
+    if (size <
+        serialization.HEADER_STRUCT.size + self.SizeForLength(nb_elements)):
+      raise serialization.DeserializationException('Incorrect array size')
+    return self.DeserializeArray(size, nb_elements, context)
 
-  def DeserializeArray(self, size, nb_elements, data, handles):
+  def DeserializeArray(self, size, nb_elements, context):
+    raise NotImplementedError()
+
+  def SizeForLength(self, nb_elements):
     raise NotImplementedError()
 
 
@@ -351,9 +365,8 @@
     converted = array.array('B', [_ConvertBooleansToByte(x) for x in groups])
     return _SerializeNativeArray(converted, data_offset, data, len(value))
 
-  def DeserializeArray(self, size, nb_elements, data, handles):
-    converted = self._array_type.DeserializeArray(
-        size, nb_elements, data, handles)
+  def DeserializeArray(self, size, nb_elements, context):
+    converted = self._array_type.DeserializeArray(size, nb_elements, context)
     elements = list(itertools.islice(
         itertools.chain.from_iterable(
             [_ConvertByteToBooleans(x, 8) for x in converted]),
@@ -361,6 +374,9 @@
         nb_elements))
     return elements
 
+  def SizeForLength(self, nb_elements):
+    return (nb_elements + 7) // 8
+
 
 class GenericArrayType(BaseArrayType):
   """Type object for arrays of pointers."""
@@ -400,18 +416,22 @@
                      *to_pack)
     return (data_offset, returned_handles)
 
-  def DeserializeArray(self, size, nb_elements, data, handles):
+  def DeserializeArray(self, size, nb_elements, context):
     values = struct.unpack_from(
         '%d%s' % (nb_elements, self.sub_type.GetTypeCode()),
-        buffer(data, serialization.HEADER_STRUCT.size))
+        buffer(context.data, serialization.HEADER_STRUCT.size))
     result = []
-    position = serialization.HEADER_STRUCT.size
+    sub_context = context.GetSubContext(serialization.HEADER_STRUCT.size)
     for value in values:
-      result.append(
-          self.sub_type.Deserialize(value, buffer(data, position), handles))
-      position += self.sub_type.GetByteSize()
+      result.append(self.sub_type.Deserialize(
+          value,
+          sub_context))
+      sub_context = sub_context.GetSubContext(self.sub_type.GetByteSize())
     return result
 
+  def SizeForLength(self, nb_elements):
+    return nb_elements * self.sub_type.GetByteSize();
+
 
 class NativeArrayType(BaseArrayType):
   """Type object for arrays of native types."""
@@ -419,6 +439,7 @@
   def __init__(self, typecode, nullable=False, length=0):
     BaseArrayType.__init__(self, nullable, length)
     self.array_typecode = typecode
+    self.element_size = struct.calcsize('<%s' % self.array_typecode)
 
   def Convert(self, value):
     if value is None:
@@ -431,13 +452,16 @@
   def SerializeArray(self, value, data_offset, data, handle_offset):
     return _SerializeNativeArray(value, data_offset, data, len(value))
 
-  def DeserializeArray(self, size, nb_elements, data, handles):
+  def DeserializeArray(self, size, nb_elements, context):
     result = array.array(self.array_typecode)
-    result.fromstring(buffer(data,
+    result.fromstring(buffer(context.data,
                              serialization.HEADER_STRUCT.size,
                              size - serialization.HEADER_STRUCT.size))
     return result
 
+  def SizeForLength(self, nb_elements):
+    return nb_elements * self.element_size
+
 
 class StructType(PointerType):
   """Type object for structs."""
@@ -469,8 +493,8 @@
     data.extend(new_data)
     return (data_offset, new_handles)
 
-  def DeserializePointer(self, size, nb_elements, data, handles):
-    return self.struct_type.Deserialize(data, handles)
+  def DeserializePointer(self, size, nb_elements, context):
+    return self.struct_type.Deserialize(context)
 
 
 class MapType(SerializableType):
@@ -511,8 +535,8 @@
       s = self.struct(keys=keys, values=values)
     return self.struct_type.Serialize(s, data_offset, data, handle_offset)
 
-  def Deserialize(self, value, data, handles):
-    s = self.struct_type.Deserialize(value, data, handles)
+  def Deserialize(self, value, context):
+    s = self.struct_type.Deserialize(value, context)
     if s:
       if len(s.keys) != len(s.values):
         raise serialization.DeserializationException(
@@ -590,7 +614,7 @@
   def Serialize(self, obj, data_offset, data, handle_offset):
     raise NotImplementedError()
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     raise NotImplementedError()
 
 
@@ -615,8 +639,8 @@
     value = getattr(obj, self.name)
     return self.field_type.Serialize(value, data_offset, data, handle_offset)
 
-  def Deserialize(self, value, data, handles):
-    entity = self.field_type.Deserialize(value, data, handles)
+  def Deserialize(self, value, context):
+    entity = self.field_type.Deserialize(value, context)
     return { self.name: entity }
 
 
@@ -640,7 +664,7 @@
         [getattr(obj, field.name) for field in self.GetDescriptors()])
     return (value, [])
 
-  def Deserialize(self, value, data, handles):
+  def Deserialize(self, value, context):
     values =  itertools.izip_longest([x.name for x in self.descriptors],
                                       _ConvertByteToBooleans(value),
                                      fillvalue=False)
@@ -663,7 +687,7 @@
 
 
 def _ConvertByteToBooleans(value, min_size=0):
-  "Unpack an integer into a list of booleans."""
+  """Unpack an integer into a list of booleans."""
   res = []
   while value:
     res.append(bool(value&1))
diff --git a/mojo/public/python/mojo/bindings/reflection.py b/mojo/public/python/mojo/bindings/reflection.py
index 5ca38bf..9668c2a 100644
--- a/mojo/public/python/mojo/bindings/reflection.py
+++ b/mojo/public/python/mojo/bindings/reflection.py
@@ -117,10 +117,10 @@
       return self._fields
     dictionary['AsDict'] = AsDict
 
-    def Deserialize(cls, data, handles):
+    def Deserialize(cls, context):
       result = cls.__new__(cls)
       fields = {}
-      serialization_object.Deserialize(fields, data, handles)
+      serialization_object.Deserialize(fields, context)
       result._fields = fields
       return result
     dictionary['Deserialize'] = classmethod(Deserialize)
@@ -476,8 +476,9 @@
           try:
             assert message.header.message_type == method.ordinal
             payload = message.payload
-            response = method.response_struct.Deserialize(payload.data,
-                                                          payload.handles)
+            response = method.response_struct.Deserialize(
+                serialization.RootDeserializationContext(payload.data,
+                                                         payload.handles))
             as_dict = response.AsDict()
             if len(as_dict) == 1:
               value = as_dict.values()[0]
@@ -533,7 +534,8 @@
       method = methods_by_ordinal[header.message_type]
       payload = message.payload
       parameters = method.parameters_struct.Deserialize(
-          payload.data, payload.handles).AsDict()
+          serialization.RootDeserializationContext(
+              payload.data, payload.handles)).AsDict()
       response = getattr(self.impl, method.name)(**parameters)
       if header.expects_response:
         def SendResponse(response):
diff --git a/mojo/public/python/mojo/bindings/serialization.py b/mojo/public/python/mojo/bindings/serialization.py
index 2c0478f..b5ea1bd 100644
--- a/mojo/public/python/mojo/bindings/serialization.py
+++ b/mojo/public/python/mojo/bindings/serialization.py
@@ -21,6 +21,68 @@
   pass
 
 
+class DeserializationContext(object):
+
+  def ClaimHandle(self, handle):
+    raise NotImplementedError()
+
+  def ClaimMemory(self, start, size):
+    raise NotImplementedError()
+
+  def GetSubContext(self, offset):
+    raise NotImplementedError()
+
+  def IsInitialContext(self):
+    raise NotImplementedError()
+
+
+class RootDeserializationContext(DeserializationContext):
+  def __init__(self, data, handles):
+    if isinstance(data, buffer):
+      self.data = data
+    else:
+      self.data = buffer(data)
+    self._handles = handles
+    self._next_handle = 0;
+    self._next_memory = 0;
+
+  def ClaimHandle(self, handle):
+    if handle < self._next_handle:
+      raise DeserializationException('Accessing handles out of order.')
+    self._next_handle = handle + 1
+    return self._handles[handle]
+
+  def ClaimMemory(self, start, size):
+    if start < self._next_memory:
+      raise DeserializationException('Accessing buffer out of order.')
+    self._next_memory = start + size
+
+  def GetSubContext(self, offset):
+    return _ChildDeserializationContext(self, offset)
+
+  def IsInitialContext(self):
+    return True
+
+
+class _ChildDeserializationContext(DeserializationContext):
+  def __init__(self, parent, offset):
+    self._parent = parent
+    self._offset = offset
+    self.data = buffer(parent.data, offset)
+
+  def ClaimHandle(self, handle):
+    return self._parent.ClaimHandle(handle)
+
+  def ClaimMemory(self, start, size):
+    return self._parent.ClaimMemory(self._offset + start, size)
+
+  def GetSubContext(self, offset):
+    return self._parent.GetSubContext(self._offset + offset)
+
+  def IsInitialContext(self):
+    return False
+
+
 class Serialization(object):
   """
   Helper class to serialize/deserialize a struct.
@@ -78,18 +140,23 @@
     self._GetMainStruct().pack_into(data, HEADER_STRUCT.size, *to_pack)
     return (data, handles)
 
-  def Deserialize(self, fields, data, handles):
-    if not isinstance(data, buffer):
-      data = buffer(data)
-    (_, version) = HEADER_STRUCT.unpack_from(data)
+  def Deserialize(self, fields, context):
+    if len(context.data) < HEADER_STRUCT.size:
+      raise DeserializationException(
+          'Available data too short to contain header.')
+    (size, version) = HEADER_STRUCT.unpack_from(context.data)
+    if len(context.data) < size or size < HEADER_STRUCT.size:
+      raise DeserializationException('Header size is incorrect.')
+    if context.IsInitialContext():
+      context.ClaimMemory(0, size)
     version_struct = self._GetStruct(version)
-    entitities = version_struct.unpack_from(data, HEADER_STRUCT.size)
+    entitities = version_struct.unpack_from(context.data, HEADER_STRUCT.size)
     filtered_groups = self._GetGroups(version)
     position = HEADER_STRUCT.size
     for (group, value) in zip(filtered_groups, entitities):
       position = position + NeededPaddingForAlignment(position,
                                                       group.GetByteSize())
-      fields.update(group.Deserialize(value, buffer(data, position), handles))
+      fields.update(group.Deserialize(value, context.GetSubContext(position)))
       position += group.GetByteSize()
 
 
diff --git a/mojo/python/BUILD.gn b/mojo/python/BUILD.gn
index 7b07f4f..f96c2cc 100644
--- a/mojo/python/BUILD.gn
+++ b/mojo/python/BUILD.gn
@@ -8,6 +8,7 @@
 group("python") {
   deps = [
     ":embedder",
+    ":validation_util",
     "//mojo/public/python",
   ]
 }
@@ -25,3 +26,25 @@
     "//mojo/public/python:system",
   ]
 }
+
+copy("tests_module") {
+  sources = [
+    "system/mojo/tests/__init__.py",
+  ]
+  outputs = [
+    "$root_out_dir/python/mojo/tests/{{source_file_part}}",
+  ]
+}
+
+python_binary_module("validation_util") {
+  python_base_module = "mojo/tests"
+  sources = [
+    "system/mojo/tests/validation_util.pyx",
+  ]
+  deps = [
+    "//mojo/public/cpp/bindings/tests:mojo_public_bindings_test_utils",
+  ]
+  datadeps = [
+    ":tests_module",
+  ]
+}
diff --git a/mojo/python/system/mojo/tests/__init__.py b/mojo/python/system/mojo/tests/__init__.py
new file mode 100644
index 0000000..4d6aabb
--- /dev/null
+++ b/mojo/python/system/mojo/tests/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
diff --git a/mojo/python/system/mojo/tests/validation_util.pyx b/mojo/python/system/mojo/tests/validation_util.pyx
new file mode 100644
index 0000000..e7bdbcd
--- /dev/null
+++ b/mojo/python/system/mojo/tests/validation_util.pyx
@@ -0,0 +1,35 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+# distutils: language = c++
+
+from libc.stdint cimport uint8_t
+from libcpp cimport bool
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+
+cdef extern from "third_party/cython/python_export.h":
+  pass
+
+cdef extern from "mojo/public/cpp/bindings/tests/validation_test_input_parser.h":
+  cdef bool ParseValidationTestInput "mojo::test::ParseValidationTestInput"(
+      string input,
+      vector[uint8_t]* data,
+      size_t* num_handles,
+      string* error_message)
+
+class Data(object):
+  def __init__(self, data, num_handles, error_message):
+    self.data = data
+    self.num_handles = num_handles
+    self.error_message = error_message
+
+def ParseData(value):
+  cdef string value_as_string = value
+  cdef vector[uint8_t] data_as_vector
+  cdef size_t num_handles
+  cdef string error_message
+  ParseValidationTestInput(
+      value, &data_as_vector, &num_handles, &error_message)
+  return Data(bytearray(data_as_vector), num_handles, error_message)
diff --git a/mojo/python/tests/bindings_serialization_deserialization_unittest.py b/mojo/python/tests/bindings_serialization_deserialization_unittest.py
index 39685ea..da725c0 100644
--- a/mojo/python/tests/bindings_serialization_deserialization_unittest.py
+++ b/mojo/python/tests/bindings_serialization_deserialization_unittest.py
@@ -7,6 +7,7 @@
 import mojo_unittest
 
 # pylint: disable=E0611,F0401
+import mojo.bindings.serialization as serialization
 import mojo.system
 
 # Generated files
@@ -68,13 +69,15 @@
 
   def testFooDeserialization(self):
     (data, handles) = _NewFoo().Serialize()
+    context = serialization.RootDeserializationContext(data, handles)
     self.assertTrue(
-        sample_service_mojom.Foo.Deserialize(data, handles))
+        sample_service_mojom.Foo.Deserialize(context))
 
   def testFooSerializationDeserialization(self):
     foo1 = _NewFoo()
     (data, handles) = foo1.Serialize()
-    foo2 = sample_service_mojom.Foo.Deserialize(data, handles)
+    context = serialization.RootDeserializationContext(data, handles)
+    foo2 = sample_service_mojom.Foo.Deserialize(context)
     self.assertEquals(foo1, foo2)
 
   def testDefaultsTestSerializationDeserialization(self):
@@ -85,7 +88,8 @@
     v1.a22.location = sample_import_mojom.Point()
     v1.a22.size = sample_import2_mojom.Size()
     (data, handles) = v1.Serialize()
-    v2 = sample_service_mojom.DefaultsTest.Deserialize(data, handles)
+    context = serialization.RootDeserializationContext(data, handles)
+    v2 = sample_service_mojom.DefaultsTest.Deserialize(context)
     # NaN needs to be a special case.
     self.assertNotEquals(v1, v2)
     self.assertTrue(math.isnan(v2.a28))
diff --git a/mojo/python/tests/mojo_unittest.py b/mojo/python/tests/mojo_unittest.py
index 101f19c..f80ffac 100644
--- a/mojo/python/tests/mojo_unittest.py
+++ b/mojo/python/tests/mojo_unittest.py
@@ -10,11 +10,15 @@
 
 
 class MojoTestCase(unittest.TestCase):
-
-  def setUp(self):
-    mojo.embedder.Init()
-    self.loop = mojo.system.RunLoop()
-
-  def tearDown(self):
+  def __init__(self, *args, **kwargs):
+    unittest.TestCase.__init__(self, *args, **kwargs)
     self.loop = None
-    assert mojo.embedder.ShutdownForTest()
+
+  def run(self, *args, **kwargs):
+    try:
+      mojo.embedder.Init()
+      self.loop = mojo.system.RunLoop()
+      unittest.TestCase.run(self, *args, **kwargs)
+    finally:
+      self.loop = None
+      assert mojo.embedder.ShutdownForTest()
diff --git a/mojo/python/tests/validation_unittest.py b/mojo/python/tests/validation_unittest.py
new file mode 100644
index 0000000..753ff52
--- /dev/null
+++ b/mojo/python/tests/validation_unittest.py
@@ -0,0 +1,85 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import logging
+import os
+import os.path
+
+import mojo_unittest
+import validation_test_interfaces_mojom
+
+# pylint: disable=E0611
+from mojo import system
+from mojo.bindings import messaging
+from mojo.tests import validation_util
+from mopy.paths import Paths
+
+logging.basicConfig(level=logging.ERROR)
+paths = Paths()
+
+
+class RoutingMessageReceiver(messaging.MessageReceiver):
+  def __init__(self, request, response):
+    self.request = request
+    self.response = response
+
+  def Accept(self, message):
+    if message.header.is_response:
+      return self.response.Accept(message)
+    else:
+      return self.request.Accept(message)
+
+
+class SinkMessageReceiver(messaging.MessageReceiverWithResponder):
+
+  def Accept(self, message):
+    return False
+
+  def AcceptWithResponder(self, message, responder):
+    return False
+
+  def Close(self):
+    pass
+
+
+class HandleMock(object):
+  def IsValid(self):
+    return True
+
+  def Close(self):
+    pass
+
+
+class ValidationTest(mojo_unittest.MojoTestCase):
+
+  @staticmethod
+  def ParseData(data_dir, filename):
+    data = validation_util.ParseData(
+        open(os.path.join(data_dir, filename), 'r').read())
+    expect_file = filename[:-4] + 'expected'
+    expected_error = open(
+        os.path.join(data_dir, expect_file), 'r').read().strip();
+    success = expected_error == 'PASS'
+    return (filename, data, success)
+
+  @staticmethod
+  def GetData(prefix):
+    data_dir = os.path.join(paths.src_root, 'mojo', 'public', 'interfaces',
+                            'bindings', 'tests', 'data', 'validation')
+    return [ValidationTest.ParseData(data_dir, x) for x in os.listdir(data_dir)
+            if x.startswith(prefix) and x.endswith('.data')]
+
+  def runTest(self, prefix, message_receiver):
+    for (filename, data, expected) in ValidationTest.GetData(prefix):
+      self.assertEquals(len(data.error_message), 0)
+      handles = [HandleMock() for _ in xrange(data.num_handles)]
+      message = messaging.Message(data.data, handles)
+      self.assertEquals(message_receiver.Accept(message), expected,
+                        'Unexpected result for test: %s' % filename)
+
+  def testConformance(self):
+    manager = validation_test_interfaces_mojom.ConformanceTestInterface.manager
+    proxy = manager._InternalProxy(SinkMessageReceiver(), None)
+    stub = manager._Stub(proxy)
+    self.runTest('conformance_', stub)