Add validation tests to python bindings.
R=sdefresne@chromium.org
Review URL: https://codereview.chromium.org/761553003
diff --git a/mojo/public/python/mojo/bindings/descriptor.py b/mojo/public/python/mojo/bindings/descriptor.py
index f190d2b..0df0bd6 100644
--- a/mojo/public/python/mojo/bindings/descriptor.py
+++ b/mojo/public/python/mojo/bindings/descriptor.py
@@ -72,7 +72,7 @@
"""
raise NotImplementedError()
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
"""
Deserialize a value of this type.
@@ -106,7 +106,7 @@
def Serialize(self, value, data_offset, data, handle_offset):
return (value, [])
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
return value
@@ -161,21 +161,31 @@
return (0, [])
return self.SerializePointer(value, data_offset, data, handle_offset)
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
if value == 0:
if not self.nullable:
raise serialization.DeserializationException(
'Trying to deserialize null for non nullable type.')
return None
- pointed_data = buffer(data, value)
- (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from(pointed_data)
- return self.DeserializePointer(size, nb_elements, pointed_data, handles)
+ if value % 8 != 0:
+ raise serialization.DeserializationException(
+ 'Pointer alignment is incorrect.')
+ sub_context = context.GetSubContext(value)
+ if len(sub_context.data) < serialization.HEADER_STRUCT.size:
+ raise serialization.DeserializationException(
+ 'Available data too short to contain header.')
+ (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from(
+ sub_context.data)
+ if len(sub_context.data) < size or size < serialization.HEADER_STRUCT.size:
+ raise serialization.DeserializationException('Header size is incorrect.')
+ sub_context.ClaimMemory(0, size)
+ return self.DeserializePointer(size, nb_elements, sub_context)
def SerializePointer(self, value, data_offset, data, handle_offset):
"""Serialize the not null value."""
raise NotImplementedError()
- def DeserializePointer(self, size, nb_elements, data, handles):
+ def DeserializePointer(self, size, nb_elements, context):
raise NotImplementedError()
@@ -204,9 +214,8 @@
return self._array_type.SerializeArray(
string_array, data_offset, data, handle_offset)
- def DeserializePointer(self, size, nb_elements, data, handles):
- string_array = self._array_type.DeserializeArray(
- size, nb_elements, data, handles)
+ def DeserializePointer(self, size, nb_elements, context):
+ string_array = self._array_type.DeserializeArray(size, nb_elements, context)
return unicode(string_array.tostring(), 'utf8')
@@ -226,14 +235,13 @@
return (-1, [])
return (handle_offset, [handle])
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
if value == -1:
if not self.nullable:
raise serialization.DeserializationException(
'Trying to deserialize null for non nullable type.')
return self.FromHandle(mojo.system.Handle())
- # TODO(qsr) validate handle order
- return self.FromHandle(handles[value])
+ return self.FromHandle(context.ClaimHandle(value))
def FromHandle(self, handle):
raise NotImplementedError()
@@ -326,12 +334,18 @@
"""Serialize the not null array."""
raise NotImplementedError()
- def DeserializePointer(self, size, nb_elements, data, handles):
- if self.length != 0 and size != self.length:
+ def DeserializePointer(self, size, nb_elements, context):
+ if self.length != 0 and nb_elements != self.length:
raise serialization.DeserializationException('Incorrect array size')
- return self.DeserializeArray(size, nb_elements, data, handles)
+ if (size <
+ serialization.HEADER_STRUCT.size + self.SizeForLength(nb_elements)):
+ raise serialization.DeserializationException('Incorrect array size')
+ return self.DeserializeArray(size, nb_elements, context)
- def DeserializeArray(self, size, nb_elements, data, handles):
+ def DeserializeArray(self, size, nb_elements, context):
+ raise NotImplementedError()
+
+ def SizeForLength(self, nb_elements):
raise NotImplementedError()
@@ -351,9 +365,8 @@
converted = array.array('B', [_ConvertBooleansToByte(x) for x in groups])
return _SerializeNativeArray(converted, data_offset, data, len(value))
- def DeserializeArray(self, size, nb_elements, data, handles):
- converted = self._array_type.DeserializeArray(
- size, nb_elements, data, handles)
+ def DeserializeArray(self, size, nb_elements, context):
+ converted = self._array_type.DeserializeArray(size, nb_elements, context)
elements = list(itertools.islice(
itertools.chain.from_iterable(
[_ConvertByteToBooleans(x, 8) for x in converted]),
@@ -361,6 +374,9 @@
nb_elements))
return elements
+ def SizeForLength(self, nb_elements):
+ return (nb_elements + 7) // 8
+
class GenericArrayType(BaseArrayType):
"""Type object for arrays of pointers."""
@@ -400,18 +416,22 @@
*to_pack)
return (data_offset, returned_handles)
- def DeserializeArray(self, size, nb_elements, data, handles):
+ def DeserializeArray(self, size, nb_elements, context):
values = struct.unpack_from(
'%d%s' % (nb_elements, self.sub_type.GetTypeCode()),
- buffer(data, serialization.HEADER_STRUCT.size))
+ buffer(context.data, serialization.HEADER_STRUCT.size))
result = []
- position = serialization.HEADER_STRUCT.size
+ sub_context = context.GetSubContext(serialization.HEADER_STRUCT.size)
for value in values:
- result.append(
- self.sub_type.Deserialize(value, buffer(data, position), handles))
- position += self.sub_type.GetByteSize()
+ result.append(self.sub_type.Deserialize(
+ value,
+ sub_context))
+ sub_context = sub_context.GetSubContext(self.sub_type.GetByteSize())
return result
+ def SizeForLength(self, nb_elements):
+ return nb_elements * self.sub_type.GetByteSize();
+
class NativeArrayType(BaseArrayType):
"""Type object for arrays of native types."""
@@ -419,6 +439,7 @@
def __init__(self, typecode, nullable=False, length=0):
BaseArrayType.__init__(self, nullable, length)
self.array_typecode = typecode
+ self.element_size = struct.calcsize('<%s' % self.array_typecode)
def Convert(self, value):
if value is None:
@@ -431,13 +452,16 @@
def SerializeArray(self, value, data_offset, data, handle_offset):
return _SerializeNativeArray(value, data_offset, data, len(value))
- def DeserializeArray(self, size, nb_elements, data, handles):
+ def DeserializeArray(self, size, nb_elements, context):
result = array.array(self.array_typecode)
- result.fromstring(buffer(data,
+ result.fromstring(buffer(context.data,
serialization.HEADER_STRUCT.size,
size - serialization.HEADER_STRUCT.size))
return result
+ def SizeForLength(self, nb_elements):
+ return nb_elements * self.element_size
+
class StructType(PointerType):
"""Type object for structs."""
@@ -469,8 +493,8 @@
data.extend(new_data)
return (data_offset, new_handles)
- def DeserializePointer(self, size, nb_elements, data, handles):
- return self.struct_type.Deserialize(data, handles)
+ def DeserializePointer(self, size, nb_elements, context):
+ return self.struct_type.Deserialize(context)
class MapType(SerializableType):
@@ -511,8 +535,8 @@
s = self.struct(keys=keys, values=values)
return self.struct_type.Serialize(s, data_offset, data, handle_offset)
- def Deserialize(self, value, data, handles):
- s = self.struct_type.Deserialize(value, data, handles)
+ def Deserialize(self, value, context):
+ s = self.struct_type.Deserialize(value, context)
if s:
if len(s.keys) != len(s.values):
raise serialization.DeserializationException(
@@ -590,7 +614,7 @@
def Serialize(self, obj, data_offset, data, handle_offset):
raise NotImplementedError()
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
raise NotImplementedError()
@@ -615,8 +639,8 @@
value = getattr(obj, self.name)
return self.field_type.Serialize(value, data_offset, data, handle_offset)
- def Deserialize(self, value, data, handles):
- entity = self.field_type.Deserialize(value, data, handles)
+ def Deserialize(self, value, context):
+ entity = self.field_type.Deserialize(value, context)
return { self.name: entity }
@@ -640,7 +664,7 @@
[getattr(obj, field.name) for field in self.GetDescriptors()])
return (value, [])
- def Deserialize(self, value, data, handles):
+ def Deserialize(self, value, context):
values = itertools.izip_longest([x.name for x in self.descriptors],
_ConvertByteToBooleans(value),
fillvalue=False)
@@ -663,7 +687,7 @@
def _ConvertByteToBooleans(value, min_size=0):
- "Unpack an integer into a list of booleans."""
+ """Unpack an integer into a list of booleans."""
res = []
while value:
res.append(bool(value&1))
diff --git a/mojo/public/python/mojo/bindings/reflection.py b/mojo/public/python/mojo/bindings/reflection.py
index 5ca38bf..9668c2a 100644
--- a/mojo/public/python/mojo/bindings/reflection.py
+++ b/mojo/public/python/mojo/bindings/reflection.py
@@ -117,10 +117,10 @@
return self._fields
dictionary['AsDict'] = AsDict
- def Deserialize(cls, data, handles):
+ def Deserialize(cls, context):
result = cls.__new__(cls)
fields = {}
- serialization_object.Deserialize(fields, data, handles)
+ serialization_object.Deserialize(fields, context)
result._fields = fields
return result
dictionary['Deserialize'] = classmethod(Deserialize)
@@ -476,8 +476,9 @@
try:
assert message.header.message_type == method.ordinal
payload = message.payload
- response = method.response_struct.Deserialize(payload.data,
- payload.handles)
+ response = method.response_struct.Deserialize(
+ serialization.RootDeserializationContext(payload.data,
+ payload.handles))
as_dict = response.AsDict()
if len(as_dict) == 1:
value = as_dict.values()[0]
@@ -533,7 +534,8 @@
method = methods_by_ordinal[header.message_type]
payload = message.payload
parameters = method.parameters_struct.Deserialize(
- payload.data, payload.handles).AsDict()
+ serialization.RootDeserializationContext(
+ payload.data, payload.handles)).AsDict()
response = getattr(self.impl, method.name)(**parameters)
if header.expects_response:
def SendResponse(response):
diff --git a/mojo/public/python/mojo/bindings/serialization.py b/mojo/public/python/mojo/bindings/serialization.py
index 2c0478f..b5ea1bd 100644
--- a/mojo/public/python/mojo/bindings/serialization.py
+++ b/mojo/public/python/mojo/bindings/serialization.py
@@ -21,6 +21,68 @@
pass
+class DeserializationContext(object):
+
+ def ClaimHandle(self, handle):
+ raise NotImplementedError()
+
+ def ClaimMemory(self, start, size):
+ raise NotImplementedError()
+
+ def GetSubContext(self, offset):
+ raise NotImplementedError()
+
+ def IsInitialContext(self):
+ raise NotImplementedError()
+
+
+class RootDeserializationContext(DeserializationContext):
+ def __init__(self, data, handles):
+ if isinstance(data, buffer):
+ self.data = data
+ else:
+ self.data = buffer(data)
+ self._handles = handles
+ self._next_handle = 0;
+ self._next_memory = 0;
+
+ def ClaimHandle(self, handle):
+ if handle < self._next_handle:
+ raise DeserializationException('Accessing handles out of order.')
+ self._next_handle = handle + 1
+ return self._handles[handle]
+
+ def ClaimMemory(self, start, size):
+ if start < self._next_memory:
+ raise DeserializationException('Accessing buffer out of order.')
+ self._next_memory = start + size
+
+ def GetSubContext(self, offset):
+ return _ChildDeserializationContext(self, offset)
+
+ def IsInitialContext(self):
+ return True
+
+
+class _ChildDeserializationContext(DeserializationContext):
+ def __init__(self, parent, offset):
+ self._parent = parent
+ self._offset = offset
+ self.data = buffer(parent.data, offset)
+
+ def ClaimHandle(self, handle):
+ return self._parent.ClaimHandle(handle)
+
+ def ClaimMemory(self, start, size):
+ return self._parent.ClaimMemory(self._offset + start, size)
+
+ def GetSubContext(self, offset):
+ return self._parent.GetSubContext(self._offset + offset)
+
+ def IsInitialContext(self):
+ return False
+
+
class Serialization(object):
"""
Helper class to serialize/deserialize a struct.
@@ -78,18 +140,23 @@
self._GetMainStruct().pack_into(data, HEADER_STRUCT.size, *to_pack)
return (data, handles)
- def Deserialize(self, fields, data, handles):
- if not isinstance(data, buffer):
- data = buffer(data)
- (_, version) = HEADER_STRUCT.unpack_from(data)
+ def Deserialize(self, fields, context):
+ if len(context.data) < HEADER_STRUCT.size:
+ raise DeserializationException(
+ 'Available data too short to contain header.')
+ (size, version) = HEADER_STRUCT.unpack_from(context.data)
+ if len(context.data) < size or size < HEADER_STRUCT.size:
+ raise DeserializationException('Header size is incorrect.')
+ if context.IsInitialContext():
+ context.ClaimMemory(0, size)
version_struct = self._GetStruct(version)
- entitities = version_struct.unpack_from(data, HEADER_STRUCT.size)
+ entitities = version_struct.unpack_from(context.data, HEADER_STRUCT.size)
filtered_groups = self._GetGroups(version)
position = HEADER_STRUCT.size
for (group, value) in zip(filtered_groups, entitities):
position = position + NeededPaddingForAlignment(position,
group.GetByteSize())
- fields.update(group.Deserialize(value, buffer(data, position), handles))
+ fields.update(group.Deserialize(value, context.GetSubContext(position)))
position += group.GetByteSize()
diff --git a/mojo/python/BUILD.gn b/mojo/python/BUILD.gn
index 7b07f4f..f96c2cc 100644
--- a/mojo/python/BUILD.gn
+++ b/mojo/python/BUILD.gn
@@ -8,6 +8,7 @@
group("python") {
deps = [
":embedder",
+ ":validation_util",
"//mojo/public/python",
]
}
@@ -25,3 +26,25 @@
"//mojo/public/python:system",
]
}
+
+copy("tests_module") {
+ sources = [
+ "system/mojo/tests/__init__.py",
+ ]
+ outputs = [
+ "$root_out_dir/python/mojo/tests/{{source_file_part}}",
+ ]
+}
+
+python_binary_module("validation_util") {
+ python_base_module = "mojo/tests"
+ sources = [
+ "system/mojo/tests/validation_util.pyx",
+ ]
+ deps = [
+ "//mojo/public/cpp/bindings/tests:mojo_public_bindings_test_utils",
+ ]
+ datadeps = [
+ ":tests_module",
+ ]
+}
diff --git a/mojo/python/system/mojo/tests/__init__.py b/mojo/python/system/mojo/tests/__init__.py
new file mode 100644
index 0000000..4d6aabb
--- /dev/null
+++ b/mojo/python/system/mojo/tests/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
diff --git a/mojo/python/system/mojo/tests/validation_util.pyx b/mojo/python/system/mojo/tests/validation_util.pyx
new file mode 100644
index 0000000..e7bdbcd
--- /dev/null
+++ b/mojo/python/system/mojo/tests/validation_util.pyx
@@ -0,0 +1,35 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+# distutils: language = c++
+
+from libc.stdint cimport uint8_t
+from libcpp cimport bool
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+
+cdef extern from "third_party/cython/python_export.h":
+ pass
+
+cdef extern from "mojo/public/cpp/bindings/tests/validation_test_input_parser.h":
+ cdef bool ParseValidationTestInput "mojo::test::ParseValidationTestInput"(
+ string input,
+ vector[uint8_t]* data,
+ size_t* num_handles,
+ string* error_message)
+
+class Data(object):
+ def __init__(self, data, num_handles, error_message):
+ self.data = data
+ self.num_handles = num_handles
+ self.error_message = error_message
+
+def ParseData(value):
+ cdef string value_as_string = value
+ cdef vector[uint8_t] data_as_vector
+ cdef size_t num_handles
+ cdef string error_message
+ ParseValidationTestInput(
+ value, &data_as_vector, &num_handles, &error_message)
+ return Data(bytearray(data_as_vector), num_handles, error_message)
diff --git a/mojo/python/tests/bindings_serialization_deserialization_unittest.py b/mojo/python/tests/bindings_serialization_deserialization_unittest.py
index 39685ea..da725c0 100644
--- a/mojo/python/tests/bindings_serialization_deserialization_unittest.py
+++ b/mojo/python/tests/bindings_serialization_deserialization_unittest.py
@@ -7,6 +7,7 @@
import mojo_unittest
# pylint: disable=E0611,F0401
+import mojo.bindings.serialization as serialization
import mojo.system
# Generated files
@@ -68,13 +69,15 @@
def testFooDeserialization(self):
(data, handles) = _NewFoo().Serialize()
+ context = serialization.RootDeserializationContext(data, handles)
self.assertTrue(
- sample_service_mojom.Foo.Deserialize(data, handles))
+ sample_service_mojom.Foo.Deserialize(context))
def testFooSerializationDeserialization(self):
foo1 = _NewFoo()
(data, handles) = foo1.Serialize()
- foo2 = sample_service_mojom.Foo.Deserialize(data, handles)
+ context = serialization.RootDeserializationContext(data, handles)
+ foo2 = sample_service_mojom.Foo.Deserialize(context)
self.assertEquals(foo1, foo2)
def testDefaultsTestSerializationDeserialization(self):
@@ -85,7 +88,8 @@
v1.a22.location = sample_import_mojom.Point()
v1.a22.size = sample_import2_mojom.Size()
(data, handles) = v1.Serialize()
- v2 = sample_service_mojom.DefaultsTest.Deserialize(data, handles)
+ context = serialization.RootDeserializationContext(data, handles)
+ v2 = sample_service_mojom.DefaultsTest.Deserialize(context)
# NaN needs to be a special case.
self.assertNotEquals(v1, v2)
self.assertTrue(math.isnan(v2.a28))
diff --git a/mojo/python/tests/mojo_unittest.py b/mojo/python/tests/mojo_unittest.py
index 101f19c..f80ffac 100644
--- a/mojo/python/tests/mojo_unittest.py
+++ b/mojo/python/tests/mojo_unittest.py
@@ -10,11 +10,15 @@
class MojoTestCase(unittest.TestCase):
-
- def setUp(self):
- mojo.embedder.Init()
- self.loop = mojo.system.RunLoop()
-
- def tearDown(self):
+ def __init__(self, *args, **kwargs):
+ unittest.TestCase.__init__(self, *args, **kwargs)
self.loop = None
- assert mojo.embedder.ShutdownForTest()
+
+ def run(self, *args, **kwargs):
+ try:
+ mojo.embedder.Init()
+ self.loop = mojo.system.RunLoop()
+ unittest.TestCase.run(self, *args, **kwargs)
+ finally:
+ self.loop = None
+ assert mojo.embedder.ShutdownForTest()
diff --git a/mojo/python/tests/validation_unittest.py b/mojo/python/tests/validation_unittest.py
new file mode 100644
index 0000000..753ff52
--- /dev/null
+++ b/mojo/python/tests/validation_unittest.py
@@ -0,0 +1,85 @@
+# Copyright 2014 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import logging
+import os
+import os.path
+
+import mojo_unittest
+import validation_test_interfaces_mojom
+
+# pylint: disable=E0611
+from mojo import system
+from mojo.bindings import messaging
+from mojo.tests import validation_util
+from mopy.paths import Paths
+
+logging.basicConfig(level=logging.ERROR)
+paths = Paths()
+
+
+class RoutingMessageReceiver(messaging.MessageReceiver):
+ def __init__(self, request, response):
+ self.request = request
+ self.response = response
+
+ def Accept(self, message):
+ if message.header.is_response:
+ return self.response.Accept(message)
+ else:
+ return self.request.Accept(message)
+
+
+class SinkMessageReceiver(messaging.MessageReceiverWithResponder):
+
+ def Accept(self, message):
+ return False
+
+ def AcceptWithResponder(self, message, responder):
+ return False
+
+ def Close(self):
+ pass
+
+
+class HandleMock(object):
+ def IsValid(self):
+ return True
+
+ def Close(self):
+ pass
+
+
+class ValidationTest(mojo_unittest.MojoTestCase):
+
+ @staticmethod
+ def ParseData(data_dir, filename):
+ data = validation_util.ParseData(
+ open(os.path.join(data_dir, filename), 'r').read())
+ expect_file = filename[:-4] + 'expected'
+ expected_error = open(
+ os.path.join(data_dir, expect_file), 'r').read().strip();
+ success = expected_error == 'PASS'
+ return (filename, data, success)
+
+ @staticmethod
+ def GetData(prefix):
+ data_dir = os.path.join(paths.src_root, 'mojo', 'public', 'interfaces',
+ 'bindings', 'tests', 'data', 'validation')
+ return [ValidationTest.ParseData(data_dir, x) for x in os.listdir(data_dir)
+ if x.startswith(prefix) and x.endswith('.data')]
+
+ def runTest(self, prefix, message_receiver):
+ for (filename, data, expected) in ValidationTest.GetData(prefix):
+ self.assertEquals(len(data.error_message), 0)
+ handles = [HandleMock() for _ in xrange(data.num_handles)]
+ message = messaging.Message(data.data, handles)
+ self.assertEquals(message_receiver.Accept(message), expected,
+ 'Unexpected result for test: %s' % filename)
+
+ def testConformance(self):
+ manager = validation_test_interfaces_mojom.ConformanceTestInterface.manager
+ proxy = manager._InternalProxy(SinkMessageReceiver(), None)
+ stub = manager._Stub(proxy)
+ self.runTest('conformance_', stub)