From 98269f4ed1093afa66986d4a80f82649e5c6cb8d Mon Sep 17 00:00:00 2001 From: Thomas Adamcik Date: Mon, 1 Apr 2013 23:21:56 +0200 Subject: [PATCH] config: Add optional setting to config values and improve tests. --- mopidy/utils/config.py | 48 ++++++++++------ tests/utils/config_test.py | 113 ++++++++++++++++++++++++++----------- 2 files changed, 112 insertions(+), 49 deletions(-) diff --git a/mopidy/utils/config.py b/mopidy/utils/config.py index fad641f1..e20afd26 100644 --- a/mopidy/utils/config.py +++ b/mopidy/utils/config.py @@ -7,9 +7,16 @@ import socket from mopidy import exceptions +def validate_required(value, required): + """Required validation, normally called in config value's validate() on the + raw string, _not_ the converted value.""" + if required and not value.strip(): + raise ValueError('must be set.') + + def validate_choice(value, choices): """Choice validation, normally called in config value's validate().""" - if choices is not None and value not in choices : + if choices is not None and value not in choices: names = ', '.join(repr(c) for c in choices) raise ValueError('must be one of %s, not %s.' % (names, value)) @@ -55,14 +62,18 @@ class ConfigValue(object): #: :function:`validate_maximum` in :method:`validate` do any thing. maximum = None + #: Indicate if this field is required. + opitional = None + #: Indicate if we should mask the when printing for human consumption. secret = None - def __init__(self, choices=None, minimum=None, maximum=None, secret=None): - self.choices = choices - self.minimum = minimum - self.maximum = maximum - self.secret = secret + def __init__(self, **kwargs): + self.choices = kwargs.get('choices') + self.minimum = kwargs.get('minimum') + self.maximum = kwargs.get('maximum') + self.optional = kwargs.get('optional') + self.secret = kwargs.get('secret') def deserialize(self, value): """Cast raw string to appropriate type.""" @@ -70,8 +81,6 @@ class ConfigValue(object): def serialize(self, value): """Convert value back to string for saving.""" - if value is None: - return '' return str(value) def format(self, value): @@ -84,7 +93,10 @@ class ConfigValue(object): class String(ConfigValue): def deserialize(self, value): value = value.strip() + validate_required(value, not self.optional) validate_choice(value, self.choices) + if not value: + return None return value def serialize(self, value): @@ -93,7 +105,7 @@ class String(ConfigValue): class Integer(ConfigValue): def deserialize(self, value): - value = int(value.strip()) + value = int(value) validate_choice(value, self.choices) validate_minimum(value, self.minimum) validate_maximum(value, self.maximum) @@ -121,10 +133,12 @@ class Boolean(ConfigValue): class List(ConfigValue): def deserialize(self, value): + validate_required(value, not self.optional) if '\n' in value: - return re.split(r'\s*\n\s*', value.strip()) + values = re.split(r'\s*\n\s*', value.strip()) else: - return re.split(r'\s*,\s*', value.strip()) + values = re.split(r'\s*,\s*', value.strip()) + return [v for v in values if v] def serialize(self, value): return '\n '.join(v.encode('utf-8') for v in value) @@ -138,8 +152,7 @@ class LogLevel(ConfigValue): 'debug' : logging.DEBUG} def deserialize(self, value): - if value.lower() not in self.levels: - raise ValueError('%r must be one of %s.' % (value, ', '.join(self.levels))) + validate_choice(value.lower(), self.levels.keys()) return self.levels.get(value.lower()) def serialize(self, value): @@ -148,6 +161,9 @@ class LogLevel(ConfigValue): class Hostname(ConfigValue): def deserialize(self, value): + validate_required(value, not self.optional) + if not value.strip(): + return None try: socket.getaddrinfo(value, None) except socket.error: @@ -156,6 +172,7 @@ class Hostname(ConfigValue): class Port(Integer): + # TODO: consider probing if port is free or not? def __init__(self, **kwargs): super(Port, self).__init__(**kwargs) self.minimum = 1 @@ -197,10 +214,7 @@ class ConfigSchema(object): for key, value in items: try: - if value.strip(): - values[key] = self._schema[key].deserialize(value) - else: # treat blank entries as none - values[key] = None + values[key] = self._schema[key].deserialize(value) except KeyError: # not in our schema errors[key] = 'unknown config key.' except ValueError as e: # deserialization failed diff --git a/tests/utils/config_test.py b/tests/utils/config_test.py index a98c37b5..527ec8d3 100644 --- a/tests/utils/config_test.py +++ b/tests/utils/config_test.py @@ -55,20 +55,38 @@ class ValidateMaximumTest(unittest.TestCase): self.assertRaises(ValueError, config.validate_maximum, 5, 0) +class ValidateRequiredTest(unittest.TestCase): + def test_passes_when_false(self): + config.validate_required('foo', False) + config.validate_required('', False) + config.validate_required(' ', False) + + def test_passes_when_required_and_set(self): + config.validate_required('foo', True) + config.validate_required(' foo ', True) + + def test_blocks_when_required_and_emtpy(self): + self.assertRaises(ValueError, config.validate_required, '', True) + self.assertRaises(ValueError, config.validate_required, ' ', True) + + class ConfigValueTest(unittest.TestCase): def test_init(self): value = config.ConfigValue() self.assertIsNone(value.choices) - self.assertIsNone(value.minimum) self.assertIsNone(value.maximum) + self.assertIsNone(value.minimum) + self.assertIsNone(value.optional) self.assertIsNone(value.secret) def test_init_with_params(self): - value = config.ConfigValue( - choices=['foo'], minimum=0, maximum=10, secret=True) + kwargs = {'choices': ['foo'], 'minimum': 0, 'maximum': 10, + 'secret': True, 'optional': True} + value = config.ConfigValue(**kwargs) self.assertEqual(['foo'], value.choices) self.assertEqual(0, value.minimum) self.assertEqual(10, value.maximum) + self.assertEqual(True, value.optional) self.assertEqual(True, value.secret) def test_deserialize_passes_through(self): @@ -91,7 +109,7 @@ class ConfigValueTest(unittest.TestCase): class StringTest(unittest.TestCase): - def test_deserialize_strips_whitespace(self): + def test_deserialize_converts_success(self): value = config.String() self.assertEqual('foo', value.deserialize(' foo ')) @@ -100,22 +118,34 @@ class StringTest(unittest.TestCase): self.assertEqual('foo', value.deserialize('foo')) self.assertRaises(ValueError, value.deserialize, 'foobar') + def test_deserialize_enforces_required(self): + value = config.String() + self.assertRaises(ValueError, value.deserialize, '') + self.assertRaises(ValueError, value.deserialize, ' ') + + def test_deserialize_respects_optional(self): + value = config.String(optional=True) + self.assertIsNone(value.deserialize('')) + self.assertIsNone(value.deserialize(' ')) + def test_format_masks_secrets(self): value = config.String(secret=True) self.assertEqual('********', value.format('s3cret')) class IntegerTest(unittest.TestCase): - def test_deserialize_converts_to_int(self): + def test_deserialize_converts_success(self): value = config.Integer() self.assertEqual(123, value.deserialize('123')) self.assertEqual(0, value.deserialize('0')) self.assertEqual(-10, value.deserialize('-10')) - def test_deserialize_fails_on_bad_data(self): + def test_deserialize_conversion_failure(self): value = config.Integer() self.assertRaises(ValueError, value.deserialize, 'asd') self.assertRaises(ValueError, value.deserialize, '3.14') + self.assertRaises(ValueError, value.deserialize, '') + self.assertRaises(ValueError, value.deserialize, ' ') def test_deserialize_enforces_choices(self): value = config.Integer(choices=[1, 2, 3]) @@ -138,7 +168,7 @@ class IntegerTest(unittest.TestCase): class BooleanTest(unittest.TestCase): - def test_deserialize_converts_to_bool(self): + def test_deserialize_converts_success(self): value = config.Boolean() for true in ('1', 'yes', 'true', 'on'): self.assertIs(value.deserialize(true), True) @@ -149,12 +179,13 @@ class BooleanTest(unittest.TestCase): self.assertIs(value.deserialize(false.upper()), False) self.assertIs(value.deserialize(false.capitalize()), False) - def test_deserialize_fails_on_bad_data(self): + def test_deserialize_conversion_failure(self): value = config.Boolean() self.assertRaises(ValueError, value.deserialize, 'nope') self.assertRaises(ValueError, value.deserialize, 'sure') + self.assertRaises(ValueError, value.deserialize, '') - def test_serialize_normalises_strings(self): + def test_serialize(self): value = config.Boolean() self.assertEqual('true', value.serialize(True)) self.assertEqual('false', value.serialize(False)) @@ -165,20 +196,29 @@ class BooleanTest(unittest.TestCase): class ListTest(unittest.TestCase): - def test_deserialize_splits_commas(self): + def test_deserialize_converts_success(self): value = config.List() - self.assertEqual(['foo', 'bar', 'baz'], - value.deserialize('foo, bar,baz')) - def test_deserialize_splits_newlines(self): - value = config.List() - self.assertEqual(['foo,bar', 'bar', 'baz'], - value.deserialize('foo,bar\nbar\nbaz')) + expected = ['foo', 'bar', 'baz'] + self.assertEqual(expected, value.deserialize('foo, bar ,baz ')) - def test_serialize_joins_by_newlines(self): + expected = ['foo,bar', 'bar', 'baz'] + self.assertEqual(expected, value.deserialize(' foo,bar\nbar\nbaz')) + + def test_deserialize_enforces_required(self): value = config.List() - self.assertRegexpMatches(value.serialize(['foo', 'bar', 'baz']), - r'foo\n\s*bar\n\s*baz') + self.assertRaises(ValueError, value.deserialize, '') + self.assertRaises(ValueError, value.deserialize, ' ') + + def test_deserialize_respects_optional(self): + value = config.List(optional=True) + self.assertEqual([], value.deserialize('')) + self.assertEqual([], value.deserialize(' ')) + + def test_serialize(self): + value = config.List() + result = value.serialize(['foo', 'bar', 'baz']) + self.assertRegexpMatches(result, r'foo\n\s*bar\n\s*baz') class BooleanTest(unittest.TestCase): @@ -188,41 +228,54 @@ class BooleanTest(unittest.TestCase): 'info': logging.INFO, 'debug': logging.DEBUG} - def test_deserialize_converts_to_numeric_loglevel(self): + def test_deserialize_converts_success(self): value = config.LogLevel() for name, level in self.levels.items(): self.assertEqual(level, value.deserialize(name)) self.assertEqual(level, value.deserialize(name.upper())) self.assertEqual(level, value.deserialize(name.capitalize())) - def test_deserialize_fails_on_bad_data(self): + def test_deserialize_conversion_failure(self): value = config.LogLevel() self.assertRaises(ValueError, value.deserialize, 'nope') self.assertRaises(ValueError, value.deserialize, 'sure') + self.assertRaises(ValueError, value.deserialize, '') + self.assertRaises(ValueError, value.deserialize, ' ') - def test_serialize_converts_to_string(self): + def test_serialize(self): value = config.LogLevel() for name, level in self.levels.items(): self.assertEqual(name, value.serialize(level)) - - def test_serialize_unknown_level(self): - value = config.LogLevel() self.assertIsNone(value.serialize(1337)) class HostnameTest(unittest.TestCase): @mock.patch('socket.getaddrinfo') - def test_deserialize_checks_addrinfo(self, getaddrinfo_mock): + def test_deserialize_converts_success(self, getaddrinfo_mock): value = config.Hostname() value.deserialize('example.com') getaddrinfo_mock.assert_called_once_with('example.com', None) @mock.patch('socket.getaddrinfo') - def test_deserialize_handles_failures(self, getaddrinfo_mock): + def test_deserialize_conversion_failure(self, getaddrinfo_mock): value = config.Hostname() getaddrinfo_mock.side_effect = socket.error self.assertRaises(ValueError, value.deserialize, 'example.com') + @mock.patch('socket.getaddrinfo') + def test_deserialize_enforces_required(self, getaddrinfo_mock): + value = config.Hostname() + self.assertRaises(ValueError, value.deserialize, '') + self.assertRaises(ValueError, value.deserialize, ' ') + self.assertEqual(0, getaddrinfo_mock.call_count) + + @mock.patch('socket.getaddrinfo') + def test_deserialize_respects_optional(self, getaddrinfo_mock): + value = config.Hostname(optional=True) + self.assertIsNone(value.deserialize('')) + self.assertIsNone(value.deserialize(' ')) + self.assertEqual(0, getaddrinfo_mock.call_count) + class PortTest(unittest.TestCase): def test_valid_ports(self): @@ -238,6 +291,7 @@ class PortTest(unittest.TestCase): self.assertRaises(ValueError, value.deserialize, '100000') self.assertRaises(ValueError, value.deserialize, '0') self.assertRaises(ValueError, value.deserialize, '-1') + self.assertRaises(ValueError, value.deserialize, '') class ConfigSchemaTest(unittest.TestCase): @@ -285,11 +339,6 @@ class ConfigSchemaTest(unittest.TestCase): self.assertIn('unknown', cm.exception['extra']) - def test_convert_with_blank_value(self): - self.values['foo'] = '' - result = self.schema.convert(self.values.items()) - self.assertIsNone(result['foo']) - def test_convert_with_deserialization_error(self): self.schema['foo'].deserialize.side_effect = ValueError('failure')