config: Add optional setting to config values and improve tests.

This commit is contained in:
Thomas Adamcik 2013-04-01 23:21:56 +02:00
parent c416893fb3
commit 98269f4ed1
2 changed files with 112 additions and 49 deletions

View File

@ -7,9 +7,16 @@ import socket
from mopidy import exceptions
def validate_required(value, required):
"""Required validation, normally called in config value's validate() on the
raw string, _not_ the converted value."""
if required and not value.strip():
raise ValueError('must be set.')
def validate_choice(value, choices):
"""Choice validation, normally called in config value's validate()."""
if choices is not None and value not in choices :
if choices is not None and value not in choices:
names = ', '.join(repr(c) for c in choices)
raise ValueError('must be one of %s, not %s.' % (names, value))
@ -55,14 +62,18 @@ class ConfigValue(object):
#: :function:`validate_maximum` in :method:`validate` do any thing.
maximum = None
#: Indicate if this field is required.
opitional = None
#: Indicate if we should mask the when printing for human consumption.
secret = None
def __init__(self, choices=None, minimum=None, maximum=None, secret=None):
self.choices = choices
self.minimum = minimum
self.maximum = maximum
self.secret = secret
def __init__(self, **kwargs):
self.choices = kwargs.get('choices')
self.minimum = kwargs.get('minimum')
self.maximum = kwargs.get('maximum')
self.optional = kwargs.get('optional')
self.secret = kwargs.get('secret')
def deserialize(self, value):
"""Cast raw string to appropriate type."""
@ -70,8 +81,6 @@ class ConfigValue(object):
def serialize(self, value):
"""Convert value back to string for saving."""
if value is None:
return ''
return str(value)
def format(self, value):
@ -84,7 +93,10 @@ class ConfigValue(object):
class String(ConfigValue):
def deserialize(self, value):
value = value.strip()
validate_required(value, not self.optional)
validate_choice(value, self.choices)
if not value:
return None
return value
def serialize(self, value):
@ -93,7 +105,7 @@ class String(ConfigValue):
class Integer(ConfigValue):
def deserialize(self, value):
value = int(value.strip())
value = int(value)
validate_choice(value, self.choices)
validate_minimum(value, self.minimum)
validate_maximum(value, self.maximum)
@ -121,10 +133,12 @@ class Boolean(ConfigValue):
class List(ConfigValue):
def deserialize(self, value):
validate_required(value, not self.optional)
if '\n' in value:
return re.split(r'\s*\n\s*', value.strip())
values = re.split(r'\s*\n\s*', value.strip())
else:
return re.split(r'\s*,\s*', value.strip())
values = re.split(r'\s*,\s*', value.strip())
return [v for v in values if v]
def serialize(self, value):
return '\n '.join(v.encode('utf-8') for v in value)
@ -138,8 +152,7 @@ class LogLevel(ConfigValue):
'debug' : logging.DEBUG}
def deserialize(self, value):
if value.lower() not in self.levels:
raise ValueError('%r must be one of %s.' % (value, ', '.join(self.levels)))
validate_choice(value.lower(), self.levels.keys())
return self.levels.get(value.lower())
def serialize(self, value):
@ -148,6 +161,9 @@ class LogLevel(ConfigValue):
class Hostname(ConfigValue):
def deserialize(self, value):
validate_required(value, not self.optional)
if not value.strip():
return None
try:
socket.getaddrinfo(value, None)
except socket.error:
@ -156,6 +172,7 @@ class Hostname(ConfigValue):
class Port(Integer):
# TODO: consider probing if port is free or not?
def __init__(self, **kwargs):
super(Port, self).__init__(**kwargs)
self.minimum = 1
@ -197,10 +214,7 @@ class ConfigSchema(object):
for key, value in items:
try:
if value.strip():
values[key] = self._schema[key].deserialize(value)
else: # treat blank entries as none
values[key] = None
values[key] = self._schema[key].deserialize(value)
except KeyError: # not in our schema
errors[key] = 'unknown config key.'
except ValueError as e: # deserialization failed

View File

@ -55,20 +55,38 @@ class ValidateMaximumTest(unittest.TestCase):
self.assertRaises(ValueError, config.validate_maximum, 5, 0)
class ValidateRequiredTest(unittest.TestCase):
def test_passes_when_false(self):
config.validate_required('foo', False)
config.validate_required('', False)
config.validate_required(' ', False)
def test_passes_when_required_and_set(self):
config.validate_required('foo', True)
config.validate_required(' foo ', True)
def test_blocks_when_required_and_emtpy(self):
self.assertRaises(ValueError, config.validate_required, '', True)
self.assertRaises(ValueError, config.validate_required, ' ', True)
class ConfigValueTest(unittest.TestCase):
def test_init(self):
value = config.ConfigValue()
self.assertIsNone(value.choices)
self.assertIsNone(value.minimum)
self.assertIsNone(value.maximum)
self.assertIsNone(value.minimum)
self.assertIsNone(value.optional)
self.assertIsNone(value.secret)
def test_init_with_params(self):
value = config.ConfigValue(
choices=['foo'], minimum=0, maximum=10, secret=True)
kwargs = {'choices': ['foo'], 'minimum': 0, 'maximum': 10,
'secret': True, 'optional': True}
value = config.ConfigValue(**kwargs)
self.assertEqual(['foo'], value.choices)
self.assertEqual(0, value.minimum)
self.assertEqual(10, value.maximum)
self.assertEqual(True, value.optional)
self.assertEqual(True, value.secret)
def test_deserialize_passes_through(self):
@ -91,7 +109,7 @@ class ConfigValueTest(unittest.TestCase):
class StringTest(unittest.TestCase):
def test_deserialize_strips_whitespace(self):
def test_deserialize_converts_success(self):
value = config.String()
self.assertEqual('foo', value.deserialize(' foo '))
@ -100,22 +118,34 @@ class StringTest(unittest.TestCase):
self.assertEqual('foo', value.deserialize('foo'))
self.assertRaises(ValueError, value.deserialize, 'foobar')
def test_deserialize_enforces_required(self):
value = config.String()
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
def test_deserialize_respects_optional(self):
value = config.String(optional=True)
self.assertIsNone(value.deserialize(''))
self.assertIsNone(value.deserialize(' '))
def test_format_masks_secrets(self):
value = config.String(secret=True)
self.assertEqual('********', value.format('s3cret'))
class IntegerTest(unittest.TestCase):
def test_deserialize_converts_to_int(self):
def test_deserialize_converts_success(self):
value = config.Integer()
self.assertEqual(123, value.deserialize('123'))
self.assertEqual(0, value.deserialize('0'))
self.assertEqual(-10, value.deserialize('-10'))
def test_deserialize_fails_on_bad_data(self):
def test_deserialize_conversion_failure(self):
value = config.Integer()
self.assertRaises(ValueError, value.deserialize, 'asd')
self.assertRaises(ValueError, value.deserialize, '3.14')
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
def test_deserialize_enforces_choices(self):
value = config.Integer(choices=[1, 2, 3])
@ -138,7 +168,7 @@ class IntegerTest(unittest.TestCase):
class BooleanTest(unittest.TestCase):
def test_deserialize_converts_to_bool(self):
def test_deserialize_converts_success(self):
value = config.Boolean()
for true in ('1', 'yes', 'true', 'on'):
self.assertIs(value.deserialize(true), True)
@ -149,12 +179,13 @@ class BooleanTest(unittest.TestCase):
self.assertIs(value.deserialize(false.upper()), False)
self.assertIs(value.deserialize(false.capitalize()), False)
def test_deserialize_fails_on_bad_data(self):
def test_deserialize_conversion_failure(self):
value = config.Boolean()
self.assertRaises(ValueError, value.deserialize, 'nope')
self.assertRaises(ValueError, value.deserialize, 'sure')
self.assertRaises(ValueError, value.deserialize, '')
def test_serialize_normalises_strings(self):
def test_serialize(self):
value = config.Boolean()
self.assertEqual('true', value.serialize(True))
self.assertEqual('false', value.serialize(False))
@ -165,20 +196,29 @@ class BooleanTest(unittest.TestCase):
class ListTest(unittest.TestCase):
def test_deserialize_splits_commas(self):
def test_deserialize_converts_success(self):
value = config.List()
self.assertEqual(['foo', 'bar', 'baz'],
value.deserialize('foo, bar,baz'))
def test_deserialize_splits_newlines(self):
value = config.List()
self.assertEqual(['foo,bar', 'bar', 'baz'],
value.deserialize('foo,bar\nbar\nbaz'))
expected = ['foo', 'bar', 'baz']
self.assertEqual(expected, value.deserialize('foo, bar ,baz '))
def test_serialize_joins_by_newlines(self):
expected = ['foo,bar', 'bar', 'baz']
self.assertEqual(expected, value.deserialize(' foo,bar\nbar\nbaz'))
def test_deserialize_enforces_required(self):
value = config.List()
self.assertRegexpMatches(value.serialize(['foo', 'bar', 'baz']),
r'foo\n\s*bar\n\s*baz')
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
def test_deserialize_respects_optional(self):
value = config.List(optional=True)
self.assertEqual([], value.deserialize(''))
self.assertEqual([], value.deserialize(' '))
def test_serialize(self):
value = config.List()
result = value.serialize(['foo', 'bar', 'baz'])
self.assertRegexpMatches(result, r'foo\n\s*bar\n\s*baz')
class BooleanTest(unittest.TestCase):
@ -188,41 +228,54 @@ class BooleanTest(unittest.TestCase):
'info': logging.INFO,
'debug': logging.DEBUG}
def test_deserialize_converts_to_numeric_loglevel(self):
def test_deserialize_converts_success(self):
value = config.LogLevel()
for name, level in self.levels.items():
self.assertEqual(level, value.deserialize(name))
self.assertEqual(level, value.deserialize(name.upper()))
self.assertEqual(level, value.deserialize(name.capitalize()))
def test_deserialize_fails_on_bad_data(self):
def test_deserialize_conversion_failure(self):
value = config.LogLevel()
self.assertRaises(ValueError, value.deserialize, 'nope')
self.assertRaises(ValueError, value.deserialize, 'sure')
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
def test_serialize_converts_to_string(self):
def test_serialize(self):
value = config.LogLevel()
for name, level in self.levels.items():
self.assertEqual(name, value.serialize(level))
def test_serialize_unknown_level(self):
value = config.LogLevel()
self.assertIsNone(value.serialize(1337))
class HostnameTest(unittest.TestCase):
@mock.patch('socket.getaddrinfo')
def test_deserialize_checks_addrinfo(self, getaddrinfo_mock):
def test_deserialize_converts_success(self, getaddrinfo_mock):
value = config.Hostname()
value.deserialize('example.com')
getaddrinfo_mock.assert_called_once_with('example.com', None)
@mock.patch('socket.getaddrinfo')
def test_deserialize_handles_failures(self, getaddrinfo_mock):
def test_deserialize_conversion_failure(self, getaddrinfo_mock):
value = config.Hostname()
getaddrinfo_mock.side_effect = socket.error
self.assertRaises(ValueError, value.deserialize, 'example.com')
@mock.patch('socket.getaddrinfo')
def test_deserialize_enforces_required(self, getaddrinfo_mock):
value = config.Hostname()
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
self.assertEqual(0, getaddrinfo_mock.call_count)
@mock.patch('socket.getaddrinfo')
def test_deserialize_respects_optional(self, getaddrinfo_mock):
value = config.Hostname(optional=True)
self.assertIsNone(value.deserialize(''))
self.assertIsNone(value.deserialize(' '))
self.assertEqual(0, getaddrinfo_mock.call_count)
class PortTest(unittest.TestCase):
def test_valid_ports(self):
@ -238,6 +291,7 @@ class PortTest(unittest.TestCase):
self.assertRaises(ValueError, value.deserialize, '100000')
self.assertRaises(ValueError, value.deserialize, '0')
self.assertRaises(ValueError, value.deserialize, '-1')
self.assertRaises(ValueError, value.deserialize, '')
class ConfigSchemaTest(unittest.TestCase):
@ -285,11 +339,6 @@ class ConfigSchemaTest(unittest.TestCase):
self.assertIn('unknown', cm.exception['extra'])
def test_convert_with_blank_value(self):
self.values['foo'] = ''
result = self.schema.convert(self.values.items())
self.assertIsNone(result['foo'])
def test_convert_with_deserialization_error(self):
self.schema['foo'].deserialize.side_effect = ValueError('failure')