From 07912e1091cb1f17a55217f6d1204aefb7beeac8 Mon Sep 17 00:00:00 2001 From: Thomas Adamcik Date: Sat, 4 Apr 2015 15:17:31 +0200 Subject: [PATCH] models: Add fields for supporting validation of models Feature makes use of python descriptors to hook in type checking and other validation when fields get set. --- mopidy/models.py | 121 ++++++++++++++++++++++ tests/models/test_fields.py | 194 ++++++++++++++++++++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 tests/models/test_fields.py diff --git a/mopidy/models.py b/mopidy/models.py index 1ae26811..52af300d 100644 --- a/mopidy/models.py +++ b/mopidy/models.py @@ -2,6 +2,127 @@ from __future__ import absolute_import, unicode_literals import json +# TODO: split into base models, serialization and fields? + + +class Field(object): + def __init__(self, default=None, type=None, choices=None): + """ + Base field for use in :class:`ImmutableObject`. These fields are + responsible type checking and other data sanitation in our models. + + For simplicity fields use the Python descriptor protocol to store the + values in the instance dictionary. Also note that fields are mutable if + the object they are attached to allow it. + + Default values will be validated with the exception of :class:`None`. + + :param default: default value for field + :param type: if set the field value must be of this type + :param choices: if set the field value must be one of these + """ + self._name = None # Set by FieldMeta + self._choices = choices + self._default = default + self._type = type + + if self._default is not None: + self.validate(self._default) + + def validate(self, value): + """Validate and possibly modify the field value before assignment""" + if self._type and not isinstance(value, self._type): + raise TypeError('Expected %s to be a %s, not %r' % + (self._name, self._type, value)) + if self._choices and value not in self._choices: + raise TypeError('Expected %s to be a one of %s, not %r' % + (self._name, self._choices, value)) + return value + + def __get__(self, instance, owner): + if not instance: + return self + return instance.__dict__.get(self._name, self._default) + + def __set__(self, instance, value): + if value is None: + value = self._default + value = self.validate(value) + if value is not None: + instance.__dict__[self._name] = value + else: + self.__delete__(instance) + + def __delete__(self, instance): + instance.__dict__.pop(self._name, None) + + +class String(Field): + def __init__(self, default=None): + """ + Specialized :class:`Field` which is wired up for bytes and unicode. + + :param default: default value for field + """ + # TODO: normalize to unicode? + # TODO: only allow unicode? + # TODO: disallow empty strings? + super(String, self).__init__(type=basestring, default=default) + + +class Integer(Field): + def __init__(self, default=None, min=None, max=None): + """ + :class:`Field` for storing integer numbers. + + :param default: default value for field + :param min: if set the field value larger or equal to this value + :param max: if set the field value smaller or equal to this value + """ + self._min = min + self._max = max + super(Integer, self).__init__(type=(int, long), default=default) + + def validate(self, value): + value = super(Integer, self).validate(value) + if self._min is not None and value < self._min: + raise ValueError('Expected %s to be at least %d, not %d' % + (self._name, self._min, value)) + if self._max is not None and value > self._max: + raise ValueError('Expected %s to be at most %d, not %d' % + (self._name, self._max, value)) + return value + + +class Collection(Field): + def __init__(self, type, container=tuple): + """ + :class:`Field` for storing collections of a given type. + + :param type: all items stored in the collection must be of this type + :param container: the type to store the items in + """ + super(Collection, self).__init__(type=type, default=container()) + + def validate(self, value): + if isinstance(value, basestring): + raise TypeError('Expected %s to be a collection of %s, not %r' + % (self._name, self._type.__name__, value)) + for v in value: + if not isinstance(v, self._type): + raise TypeError('Expected %s to be a collection of %s, not %r' + % (self._name, self._type.__name__, value)) + return self._default.__class__(value) or None + + +class FieldOwner(type): + """Helper to automatically assign field names to descriptors.""" + def __new__(cls, name, bases, attrs): + for key, value in attrs.items(): + if isinstance(value, Field): + value._name = key + return super(FieldOwner, cls).__new__(cls, name, bases, attrs) + class ImmutableObject(object): diff --git a/tests/models/test_fields.py b/tests/models/test_fields.py new file mode 100644 index 00000000..f2b55f01 --- /dev/null +++ b/tests/models/test_fields.py @@ -0,0 +1,194 @@ +from __future__ import absolute_import, unicode_literals + +import unittest + +from mopidy.models import * # noqa: F403 + + +def create_instance(field): + """Create an instance of a dummy class for testing fields.""" + + class Dummy(object): + __metaclass__ = FieldOwner + attr = field + + return Dummy() + + +class FieldDescriptorTest(unittest.TestCase): + def test_raw_field_accesible_through_class(self): + field = Field() + instance = create_instance(field) + self.assertEqual(field, instance.__class__.attr) + + def test_field_knows_its_name(self): + instance = create_instance(Field()) + self.assertEqual('attr', instance.__class__.attr._name) + + def test_field_has_none_as_default(self): + instance = create_instance(Field()) + self.assertIsNone(instance.attr) + + def test_field_does_not_store_default_in_dict(self): + instance = create_instance(Field()) + self.assertNotIn('attr', instance.__dict__) + + def test_field_assigment_and_retrival(self): + instance = create_instance(Field()) + instance.attr = 1234 + self.assertEqual(1234, instance.attr) + self.assertEqual(1234, instance.__dict__['attr']) + + def test_field_can_be_reassigned(self): + instance = create_instance(Field()) + instance.attr = 1234 + instance.attr = 5678 + self.assertEqual(5678, instance.attr) + + def test_field_can_be_deleted(self): + instance = create_instance(Field()) + instance.attr = 1234 + del instance.attr + self.assertEqual(None, instance.attr) + self.assertNotIn('attr', instance.__dict__) + + def test_field_can_be_set_to_none(self): + instance = create_instance(Field()) + instance.attr = 1234 + instance.attr = None + self.assertEqual(None, instance.attr) + self.assertNotIn('attr', instance.__dict__) + + +class FieldTest(unittest.TestCase): + def test_default_handling(self): + instance = create_instance(Field(default=1234)) + self.assertEqual(1234, instance.attr) + + def test_type_checking(self): + instance = create_instance(Field(type=set)) + instance.attr = set() + + with self.assertRaises(TypeError): + instance.attr = 1234 + + def test_choices_checking(self): + instance = create_instance(Field(choices=(1, 2, 3))) + instance.attr = 1 + + with self.assertRaises(TypeError): + instance.attr = 4 + + def test_default_respects_type_check(self): + with self.assertRaises(TypeError): + create_instance(Field(type=int, default='123')) + + def test_default_respects_choices_check(self): + with self.assertRaises(TypeError): + create_instance(Field(choices=(1, 2, 3), default=5)) + + +class StringTest(unittest.TestCase): + def test_default_handling(self): + instance = create_instance(String(default='abc')) + self.assertEqual('abc', instance.attr) + + def test_str_allowed(self): + instance = create_instance(String()) + instance.attr = str('abc') + self.assertEqual(b'abc', instance.attr) + + def test_unicode_allowed(self): + instance = create_instance(String()) + instance.attr = unicode('abc') + self.assertEqual(u'abc', instance.attr) + + def test_other_disallowed(self): + instance = create_instance(String()) + with self.assertRaises(TypeError): + instance.attr = 1234 + + def test_empty_string(self): + instance = create_instance(String()) + instance.attr = '' + self.assertEqual('', instance.attr) + + +class IntegerTest(unittest.TestCase): + def test_default_handling(self): + instance = create_instance(Integer(default=1234)) + self.assertEqual(1234, instance.attr) + + def test_int_allowed(self): + instance = create_instance(Integer()) + instance.attr = int(123) + self.assertEqual(123, instance.attr) + + def test_long_allowed(self): + instance = create_instance(Integer()) + instance.attr = long(123) + self.assertEqual(123, instance.attr) + + def test_float_disallowed(self): + instance = create_instance(Integer()) + with self.assertRaises(TypeError): + instance.attr = 123.0 + + def test_numeric_string_disallowed(self): + instance = create_instance(Integer()) + with self.assertRaises(TypeError): + instance.attr = '123' + + def test_other_disallowed(self): + instance = create_instance(String()) + with self.assertRaises(TypeError): + instance.attr = tuple() + + def test_min_validation(self): + instance = create_instance(Integer(min=0)) + instance.attr = 0 + self.assertEqual(0, instance.attr) + + with self.assertRaises(ValueError): + instance.attr = -1 + + def test_max_validation(self): + instance = create_instance(Integer(max=10)) + instance.attr = 10 + self.assertEqual(10, instance.attr) + + with self.assertRaises(ValueError): + instance.attr = 11 + + +class CollectionTest(unittest.TestCase): + def test_container_instance_is_default(self): + instance = create_instance(Collection(type=int, container=frozenset)) + self.assertEqual(frozenset(), instance.attr) + + def test_empty_collection(self): + instance = create_instance(Collection(type=int, container=frozenset)) + instance.attr = [] + self.assertEqual(frozenset(), instance.attr) + self.assertNotIn('attr', instance.__dict__) + + def test_collection_gets_stored_in_container(self): + instance = create_instance(Collection(type=int, container=frozenset)) + instance.attr = [1, 2, 3] + self.assertEqual(frozenset([1, 2, 3]), instance.attr) + self.assertEqual(frozenset([1, 2, 3]), instance.__dict__['attr']) + + def test_collection_with_wrong_type(self): + instance = create_instance(Collection(type=int, container=frozenset)) + with self.assertRaises(TypeError): + instance.attr = [1, '2', 3] + + def test_collection_with_string(self): + instance = create_instance(Collection(type=int, container=frozenset)) + with self.assertRaises(TypeError): + instance.attr = '123' + + def test_strings_should_not_be_considered_a_collection(self): + instance = create_instance(Collection(type=str, container=tuple)) + with self.assertRaises(TypeError): + instance.attr = b'123'