Add validation tests to python bindings. R=sdefresne@chromium.org Review URL: https://codereview.chromium.org/761553003
diff --git a/mojo/public/python/mojo/bindings/descriptor.py b/mojo/public/python/mojo/bindings/descriptor.py index f190d2b..0df0bd6 100644 --- a/mojo/public/python/mojo/bindings/descriptor.py +++ b/mojo/public/python/mojo/bindings/descriptor.py
@@ -72,7 +72,7 @@ """ raise NotImplementedError() - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): """ Deserialize a value of this type. @@ -106,7 +106,7 @@ def Serialize(self, value, data_offset, data, handle_offset): return (value, []) - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): return value @@ -161,21 +161,31 @@ return (0, []) return self.SerializePointer(value, data_offset, data, handle_offset) - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): if value == 0: if not self.nullable: raise serialization.DeserializationException( 'Trying to deserialize null for non nullable type.') return None - pointed_data = buffer(data, value) - (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from(pointed_data) - return self.DeserializePointer(size, nb_elements, pointed_data, handles) + if value % 8 != 0: + raise serialization.DeserializationException( + 'Pointer alignment is incorrect.') + sub_context = context.GetSubContext(value) + if len(sub_context.data) < serialization.HEADER_STRUCT.size: + raise serialization.DeserializationException( + 'Available data too short to contain header.') + (size, nb_elements) = serialization.HEADER_STRUCT.unpack_from( + sub_context.data) + if len(sub_context.data) < size or size < serialization.HEADER_STRUCT.size: + raise serialization.DeserializationException('Header size is incorrect.') + sub_context.ClaimMemory(0, size) + return self.DeserializePointer(size, nb_elements, sub_context) def SerializePointer(self, value, data_offset, data, handle_offset): """Serialize the not null value.""" raise NotImplementedError() - def DeserializePointer(self, size, nb_elements, data, handles): + def DeserializePointer(self, size, nb_elements, context): raise NotImplementedError() @@ -204,9 +214,8 @@ return self._array_type.SerializeArray( string_array, data_offset, data, handle_offset) - def DeserializePointer(self, size, nb_elements, data, handles): - string_array = self._array_type.DeserializeArray( - size, nb_elements, data, handles) + def DeserializePointer(self, size, nb_elements, context): + string_array = self._array_type.DeserializeArray(size, nb_elements, context) return unicode(string_array.tostring(), 'utf8') @@ -226,14 +235,13 @@ return (-1, []) return (handle_offset, [handle]) - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): if value == -1: if not self.nullable: raise serialization.DeserializationException( 'Trying to deserialize null for non nullable type.') return self.FromHandle(mojo.system.Handle()) - # TODO(qsr) validate handle order - return self.FromHandle(handles[value]) + return self.FromHandle(context.ClaimHandle(value)) def FromHandle(self, handle): raise NotImplementedError() @@ -326,12 +334,18 @@ """Serialize the not null array.""" raise NotImplementedError() - def DeserializePointer(self, size, nb_elements, data, handles): - if self.length != 0 and size != self.length: + def DeserializePointer(self, size, nb_elements, context): + if self.length != 0 and nb_elements != self.length: raise serialization.DeserializationException('Incorrect array size') - return self.DeserializeArray(size, nb_elements, data, handles) + if (size < + serialization.HEADER_STRUCT.size + self.SizeForLength(nb_elements)): + raise serialization.DeserializationException('Incorrect array size') + return self.DeserializeArray(size, nb_elements, context) - def DeserializeArray(self, size, nb_elements, data, handles): + def DeserializeArray(self, size, nb_elements, context): + raise NotImplementedError() + + def SizeForLength(self, nb_elements): raise NotImplementedError() @@ -351,9 +365,8 @@ converted = array.array('B', [_ConvertBooleansToByte(x) for x in groups]) return _SerializeNativeArray(converted, data_offset, data, len(value)) - def DeserializeArray(self, size, nb_elements, data, handles): - converted = self._array_type.DeserializeArray( - size, nb_elements, data, handles) + def DeserializeArray(self, size, nb_elements, context): + converted = self._array_type.DeserializeArray(size, nb_elements, context) elements = list(itertools.islice( itertools.chain.from_iterable( [_ConvertByteToBooleans(x, 8) for x in converted]), @@ -361,6 +374,9 @@ nb_elements)) return elements + def SizeForLength(self, nb_elements): + return (nb_elements + 7) // 8 + class GenericArrayType(BaseArrayType): """Type object for arrays of pointers.""" @@ -400,18 +416,22 @@ *to_pack) return (data_offset, returned_handles) - def DeserializeArray(self, size, nb_elements, data, handles): + def DeserializeArray(self, size, nb_elements, context): values = struct.unpack_from( '%d%s' % (nb_elements, self.sub_type.GetTypeCode()), - buffer(data, serialization.HEADER_STRUCT.size)) + buffer(context.data, serialization.HEADER_STRUCT.size)) result = [] - position = serialization.HEADER_STRUCT.size + sub_context = context.GetSubContext(serialization.HEADER_STRUCT.size) for value in values: - result.append( - self.sub_type.Deserialize(value, buffer(data, position), handles)) - position += self.sub_type.GetByteSize() + result.append(self.sub_type.Deserialize( + value, + sub_context)) + sub_context = sub_context.GetSubContext(self.sub_type.GetByteSize()) return result + def SizeForLength(self, nb_elements): + return nb_elements * self.sub_type.GetByteSize(); + class NativeArrayType(BaseArrayType): """Type object for arrays of native types.""" @@ -419,6 +439,7 @@ def __init__(self, typecode, nullable=False, length=0): BaseArrayType.__init__(self, nullable, length) self.array_typecode = typecode + self.element_size = struct.calcsize('<%s' % self.array_typecode) def Convert(self, value): if value is None: @@ -431,13 +452,16 @@ def SerializeArray(self, value, data_offset, data, handle_offset): return _SerializeNativeArray(value, data_offset, data, len(value)) - def DeserializeArray(self, size, nb_elements, data, handles): + def DeserializeArray(self, size, nb_elements, context): result = array.array(self.array_typecode) - result.fromstring(buffer(data, + result.fromstring(buffer(context.data, serialization.HEADER_STRUCT.size, size - serialization.HEADER_STRUCT.size)) return result + def SizeForLength(self, nb_elements): + return nb_elements * self.element_size + class StructType(PointerType): """Type object for structs.""" @@ -469,8 +493,8 @@ data.extend(new_data) return (data_offset, new_handles) - def DeserializePointer(self, size, nb_elements, data, handles): - return self.struct_type.Deserialize(data, handles) + def DeserializePointer(self, size, nb_elements, context): + return self.struct_type.Deserialize(context) class MapType(SerializableType): @@ -511,8 +535,8 @@ s = self.struct(keys=keys, values=values) return self.struct_type.Serialize(s, data_offset, data, handle_offset) - def Deserialize(self, value, data, handles): - s = self.struct_type.Deserialize(value, data, handles) + def Deserialize(self, value, context): + s = self.struct_type.Deserialize(value, context) if s: if len(s.keys) != len(s.values): raise serialization.DeserializationException( @@ -590,7 +614,7 @@ def Serialize(self, obj, data_offset, data, handle_offset): raise NotImplementedError() - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): raise NotImplementedError() @@ -615,8 +639,8 @@ value = getattr(obj, self.name) return self.field_type.Serialize(value, data_offset, data, handle_offset) - def Deserialize(self, value, data, handles): - entity = self.field_type.Deserialize(value, data, handles) + def Deserialize(self, value, context): + entity = self.field_type.Deserialize(value, context) return { self.name: entity } @@ -640,7 +664,7 @@ [getattr(obj, field.name) for field in self.GetDescriptors()]) return (value, []) - def Deserialize(self, value, data, handles): + def Deserialize(self, value, context): values = itertools.izip_longest([x.name for x in self.descriptors], _ConvertByteToBooleans(value), fillvalue=False) @@ -663,7 +687,7 @@ def _ConvertByteToBooleans(value, min_size=0): - "Unpack an integer into a list of booleans.""" + """Unpack an integer into a list of booleans.""" res = [] while value: res.append(bool(value&1))
diff --git a/mojo/public/python/mojo/bindings/reflection.py b/mojo/public/python/mojo/bindings/reflection.py index 5ca38bf..9668c2a 100644 --- a/mojo/public/python/mojo/bindings/reflection.py +++ b/mojo/public/python/mojo/bindings/reflection.py
@@ -117,10 +117,10 @@ return self._fields dictionary['AsDict'] = AsDict - def Deserialize(cls, data, handles): + def Deserialize(cls, context): result = cls.__new__(cls) fields = {} - serialization_object.Deserialize(fields, data, handles) + serialization_object.Deserialize(fields, context) result._fields = fields return result dictionary['Deserialize'] = classmethod(Deserialize) @@ -476,8 +476,9 @@ try: assert message.header.message_type == method.ordinal payload = message.payload - response = method.response_struct.Deserialize(payload.data, - payload.handles) + response = method.response_struct.Deserialize( + serialization.RootDeserializationContext(payload.data, + payload.handles)) as_dict = response.AsDict() if len(as_dict) == 1: value = as_dict.values()[0] @@ -533,7 +534,8 @@ method = methods_by_ordinal[header.message_type] payload = message.payload parameters = method.parameters_struct.Deserialize( - payload.data, payload.handles).AsDict() + serialization.RootDeserializationContext( + payload.data, payload.handles)).AsDict() response = getattr(self.impl, method.name)(**parameters) if header.expects_response: def SendResponse(response):
diff --git a/mojo/public/python/mojo/bindings/serialization.py b/mojo/public/python/mojo/bindings/serialization.py index 2c0478f..b5ea1bd 100644 --- a/mojo/public/python/mojo/bindings/serialization.py +++ b/mojo/public/python/mojo/bindings/serialization.py
@@ -21,6 +21,68 @@ pass +class DeserializationContext(object): + + def ClaimHandle(self, handle): + raise NotImplementedError() + + def ClaimMemory(self, start, size): + raise NotImplementedError() + + def GetSubContext(self, offset): + raise NotImplementedError() + + def IsInitialContext(self): + raise NotImplementedError() + + +class RootDeserializationContext(DeserializationContext): + def __init__(self, data, handles): + if isinstance(data, buffer): + self.data = data + else: + self.data = buffer(data) + self._handles = handles + self._next_handle = 0; + self._next_memory = 0; + + def ClaimHandle(self, handle): + if handle < self._next_handle: + raise DeserializationException('Accessing handles out of order.') + self._next_handle = handle + 1 + return self._handles[handle] + + def ClaimMemory(self, start, size): + if start < self._next_memory: + raise DeserializationException('Accessing buffer out of order.') + self._next_memory = start + size + + def GetSubContext(self, offset): + return _ChildDeserializationContext(self, offset) + + def IsInitialContext(self): + return True + + +class _ChildDeserializationContext(DeserializationContext): + def __init__(self, parent, offset): + self._parent = parent + self._offset = offset + self.data = buffer(parent.data, offset) + + def ClaimHandle(self, handle): + return self._parent.ClaimHandle(handle) + + def ClaimMemory(self, start, size): + return self._parent.ClaimMemory(self._offset + start, size) + + def GetSubContext(self, offset): + return self._parent.GetSubContext(self._offset + offset) + + def IsInitialContext(self): + return False + + class Serialization(object): """ Helper class to serialize/deserialize a struct. @@ -78,18 +140,23 @@ self._GetMainStruct().pack_into(data, HEADER_STRUCT.size, *to_pack) return (data, handles) - def Deserialize(self, fields, data, handles): - if not isinstance(data, buffer): - data = buffer(data) - (_, version) = HEADER_STRUCT.unpack_from(data) + def Deserialize(self, fields, context): + if len(context.data) < HEADER_STRUCT.size: + raise DeserializationException( + 'Available data too short to contain header.') + (size, version) = HEADER_STRUCT.unpack_from(context.data) + if len(context.data) < size or size < HEADER_STRUCT.size: + raise DeserializationException('Header size is incorrect.') + if context.IsInitialContext(): + context.ClaimMemory(0, size) version_struct = self._GetStruct(version) - entitities = version_struct.unpack_from(data, HEADER_STRUCT.size) + entitities = version_struct.unpack_from(context.data, HEADER_STRUCT.size) filtered_groups = self._GetGroups(version) position = HEADER_STRUCT.size for (group, value) in zip(filtered_groups, entitities): position = position + NeededPaddingForAlignment(position, group.GetByteSize()) - fields.update(group.Deserialize(value, buffer(data, position), handles)) + fields.update(group.Deserialize(value, context.GetSubContext(position))) position += group.GetByteSize()
diff --git a/mojo/python/BUILD.gn b/mojo/python/BUILD.gn index 7b07f4f..f96c2cc 100644 --- a/mojo/python/BUILD.gn +++ b/mojo/python/BUILD.gn
@@ -8,6 +8,7 @@ group("python") { deps = [ ":embedder", + ":validation_util", "//mojo/public/python", ] } @@ -25,3 +26,25 @@ "//mojo/public/python:system", ] } + +copy("tests_module") { + sources = [ + "system/mojo/tests/__init__.py", + ] + outputs = [ + "$root_out_dir/python/mojo/tests/{{source_file_part}}", + ] +} + +python_binary_module("validation_util") { + python_base_module = "mojo/tests" + sources = [ + "system/mojo/tests/validation_util.pyx", + ] + deps = [ + "//mojo/public/cpp/bindings/tests:mojo_public_bindings_test_utils", + ] + datadeps = [ + ":tests_module", + ] +}
diff --git a/mojo/python/system/mojo/tests/__init__.py b/mojo/python/system/mojo/tests/__init__.py new file mode 100644 index 0000000..4d6aabb --- /dev/null +++ b/mojo/python/system/mojo/tests/__init__.py
@@ -0,0 +1,3 @@ +# Copyright 2014 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file.
diff --git a/mojo/python/system/mojo/tests/validation_util.pyx b/mojo/python/system/mojo/tests/validation_util.pyx new file mode 100644 index 0000000..e7bdbcd --- /dev/null +++ b/mojo/python/system/mojo/tests/validation_util.pyx
@@ -0,0 +1,35 @@ +# Copyright 2014 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +# distutils: language = c++ + +from libc.stdint cimport uint8_t +from libcpp cimport bool +from libcpp.string cimport string +from libcpp.vector cimport vector + +cdef extern from "third_party/cython/python_export.h": + pass + +cdef extern from "mojo/public/cpp/bindings/tests/validation_test_input_parser.h": + cdef bool ParseValidationTestInput "mojo::test::ParseValidationTestInput"( + string input, + vector[uint8_t]* data, + size_t* num_handles, + string* error_message) + +class Data(object): + def __init__(self, data, num_handles, error_message): + self.data = data + self.num_handles = num_handles + self.error_message = error_message + +def ParseData(value): + cdef string value_as_string = value + cdef vector[uint8_t] data_as_vector + cdef size_t num_handles + cdef string error_message + ParseValidationTestInput( + value, &data_as_vector, &num_handles, &error_message) + return Data(bytearray(data_as_vector), num_handles, error_message)
diff --git a/mojo/python/tests/bindings_serialization_deserialization_unittest.py b/mojo/python/tests/bindings_serialization_deserialization_unittest.py index 39685ea..da725c0 100644 --- a/mojo/python/tests/bindings_serialization_deserialization_unittest.py +++ b/mojo/python/tests/bindings_serialization_deserialization_unittest.py
@@ -7,6 +7,7 @@ import mojo_unittest # pylint: disable=E0611,F0401 +import mojo.bindings.serialization as serialization import mojo.system # Generated files @@ -68,13 +69,15 @@ def testFooDeserialization(self): (data, handles) = _NewFoo().Serialize() + context = serialization.RootDeserializationContext(data, handles) self.assertTrue( - sample_service_mojom.Foo.Deserialize(data, handles)) + sample_service_mojom.Foo.Deserialize(context)) def testFooSerializationDeserialization(self): foo1 = _NewFoo() (data, handles) = foo1.Serialize() - foo2 = sample_service_mojom.Foo.Deserialize(data, handles) + context = serialization.RootDeserializationContext(data, handles) + foo2 = sample_service_mojom.Foo.Deserialize(context) self.assertEquals(foo1, foo2) def testDefaultsTestSerializationDeserialization(self): @@ -85,7 +88,8 @@ v1.a22.location = sample_import_mojom.Point() v1.a22.size = sample_import2_mojom.Size() (data, handles) = v1.Serialize() - v2 = sample_service_mojom.DefaultsTest.Deserialize(data, handles) + context = serialization.RootDeserializationContext(data, handles) + v2 = sample_service_mojom.DefaultsTest.Deserialize(context) # NaN needs to be a special case. self.assertNotEquals(v1, v2) self.assertTrue(math.isnan(v2.a28))
diff --git a/mojo/python/tests/mojo_unittest.py b/mojo/python/tests/mojo_unittest.py index 101f19c..f80ffac 100644 --- a/mojo/python/tests/mojo_unittest.py +++ b/mojo/python/tests/mojo_unittest.py
@@ -10,11 +10,15 @@ class MojoTestCase(unittest.TestCase): - - def setUp(self): - mojo.embedder.Init() - self.loop = mojo.system.RunLoop() - - def tearDown(self): + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) self.loop = None - assert mojo.embedder.ShutdownForTest() + + def run(self, *args, **kwargs): + try: + mojo.embedder.Init() + self.loop = mojo.system.RunLoop() + unittest.TestCase.run(self, *args, **kwargs) + finally: + self.loop = None + assert mojo.embedder.ShutdownForTest()
diff --git a/mojo/python/tests/validation_unittest.py b/mojo/python/tests/validation_unittest.py new file mode 100644 index 0000000..753ff52 --- /dev/null +++ b/mojo/python/tests/validation_unittest.py
@@ -0,0 +1,85 @@ +# Copyright 2014 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import logging +import os +import os.path + +import mojo_unittest +import validation_test_interfaces_mojom + +# pylint: disable=E0611 +from mojo import system +from mojo.bindings import messaging +from mojo.tests import validation_util +from mopy.paths import Paths + +logging.basicConfig(level=logging.ERROR) +paths = Paths() + + +class RoutingMessageReceiver(messaging.MessageReceiver): + def __init__(self, request, response): + self.request = request + self.response = response + + def Accept(self, message): + if message.header.is_response: + return self.response.Accept(message) + else: + return self.request.Accept(message) + + +class SinkMessageReceiver(messaging.MessageReceiverWithResponder): + + def Accept(self, message): + return False + + def AcceptWithResponder(self, message, responder): + return False + + def Close(self): + pass + + +class HandleMock(object): + def IsValid(self): + return True + + def Close(self): + pass + + +class ValidationTest(mojo_unittest.MojoTestCase): + + @staticmethod + def ParseData(data_dir, filename): + data = validation_util.ParseData( + open(os.path.join(data_dir, filename), 'r').read()) + expect_file = filename[:-4] + 'expected' + expected_error = open( + os.path.join(data_dir, expect_file), 'r').read().strip(); + success = expected_error == 'PASS' + return (filename, data, success) + + @staticmethod + def GetData(prefix): + data_dir = os.path.join(paths.src_root, 'mojo', 'public', 'interfaces', + 'bindings', 'tests', 'data', 'validation') + return [ValidationTest.ParseData(data_dir, x) for x in os.listdir(data_dir) + if x.startswith(prefix) and x.endswith('.data')] + + def runTest(self, prefix, message_receiver): + for (filename, data, expected) in ValidationTest.GetData(prefix): + self.assertEquals(len(data.error_message), 0) + handles = [HandleMock() for _ in xrange(data.num_handles)] + message = messaging.Message(data.data, handles) + self.assertEquals(message_receiver.Accept(message), expected, + 'Unexpected result for test: %s' % filename) + + def testConformance(self): + manager = validation_test_interfaces_mojom.ConformanceTestInterface.manager + proxy = manager._InternalProxy(SinkMessageReceiver(), None) + stub = manager._Stub(proxy) + self.runTest('conformance_', stub)