config: Add optional setting to config values and improve tests.
This commit is contained in:
parent
c416893fb3
commit
98269f4ed1
@ -7,9 +7,16 @@ import socket
|
||||
from mopidy import exceptions
|
||||
|
||||
|
||||
def validate_required(value, required):
|
||||
"""Required validation, normally called in config value's validate() on the
|
||||
raw string, _not_ the converted value."""
|
||||
if required and not value.strip():
|
||||
raise ValueError('must be set.')
|
||||
|
||||
|
||||
def validate_choice(value, choices):
|
||||
"""Choice validation, normally called in config value's validate()."""
|
||||
if choices is not None and value not in choices :
|
||||
if choices is not None and value not in choices:
|
||||
names = ', '.join(repr(c) for c in choices)
|
||||
raise ValueError('must be one of %s, not %s.' % (names, value))
|
||||
|
||||
@ -55,14 +62,18 @@ class ConfigValue(object):
|
||||
#: :function:`validate_maximum` in :method:`validate` do any thing.
|
||||
maximum = None
|
||||
|
||||
#: Indicate if this field is required.
|
||||
opitional = None
|
||||
|
||||
#: Indicate if we should mask the when printing for human consumption.
|
||||
secret = None
|
||||
|
||||
def __init__(self, choices=None, minimum=None, maximum=None, secret=None):
|
||||
self.choices = choices
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.secret = secret
|
||||
def __init__(self, **kwargs):
|
||||
self.choices = kwargs.get('choices')
|
||||
self.minimum = kwargs.get('minimum')
|
||||
self.maximum = kwargs.get('maximum')
|
||||
self.optional = kwargs.get('optional')
|
||||
self.secret = kwargs.get('secret')
|
||||
|
||||
def deserialize(self, value):
|
||||
"""Cast raw string to appropriate type."""
|
||||
@ -70,8 +81,6 @@ class ConfigValue(object):
|
||||
|
||||
def serialize(self, value):
|
||||
"""Convert value back to string for saving."""
|
||||
if value is None:
|
||||
return ''
|
||||
return str(value)
|
||||
|
||||
def format(self, value):
|
||||
@ -84,7 +93,10 @@ class ConfigValue(object):
|
||||
class String(ConfigValue):
|
||||
def deserialize(self, value):
|
||||
value = value.strip()
|
||||
validate_required(value, not self.optional)
|
||||
validate_choice(value, self.choices)
|
||||
if not value:
|
||||
return None
|
||||
return value
|
||||
|
||||
def serialize(self, value):
|
||||
@ -93,7 +105,7 @@ class String(ConfigValue):
|
||||
|
||||
class Integer(ConfigValue):
|
||||
def deserialize(self, value):
|
||||
value = int(value.strip())
|
||||
value = int(value)
|
||||
validate_choice(value, self.choices)
|
||||
validate_minimum(value, self.minimum)
|
||||
validate_maximum(value, self.maximum)
|
||||
@ -121,10 +133,12 @@ class Boolean(ConfigValue):
|
||||
|
||||
class List(ConfigValue):
|
||||
def deserialize(self, value):
|
||||
validate_required(value, not self.optional)
|
||||
if '\n' in value:
|
||||
return re.split(r'\s*\n\s*', value.strip())
|
||||
values = re.split(r'\s*\n\s*', value.strip())
|
||||
else:
|
||||
return re.split(r'\s*,\s*', value.strip())
|
||||
values = re.split(r'\s*,\s*', value.strip())
|
||||
return [v for v in values if v]
|
||||
|
||||
def serialize(self, value):
|
||||
return '\n '.join(v.encode('utf-8') for v in value)
|
||||
@ -138,8 +152,7 @@ class LogLevel(ConfigValue):
|
||||
'debug' : logging.DEBUG}
|
||||
|
||||
def deserialize(self, value):
|
||||
if value.lower() not in self.levels:
|
||||
raise ValueError('%r must be one of %s.' % (value, ', '.join(self.levels)))
|
||||
validate_choice(value.lower(), self.levels.keys())
|
||||
return self.levels.get(value.lower())
|
||||
|
||||
def serialize(self, value):
|
||||
@ -148,6 +161,9 @@ class LogLevel(ConfigValue):
|
||||
|
||||
class Hostname(ConfigValue):
|
||||
def deserialize(self, value):
|
||||
validate_required(value, not self.optional)
|
||||
if not value.strip():
|
||||
return None
|
||||
try:
|
||||
socket.getaddrinfo(value, None)
|
||||
except socket.error:
|
||||
@ -156,6 +172,7 @@ class Hostname(ConfigValue):
|
||||
|
||||
|
||||
class Port(Integer):
|
||||
# TODO: consider probing if port is free or not?
|
||||
def __init__(self, **kwargs):
|
||||
super(Port, self).__init__(**kwargs)
|
||||
self.minimum = 1
|
||||
@ -197,10 +214,7 @@ class ConfigSchema(object):
|
||||
|
||||
for key, value in items:
|
||||
try:
|
||||
if value.strip():
|
||||
values[key] = self._schema[key].deserialize(value)
|
||||
else: # treat blank entries as none
|
||||
values[key] = None
|
||||
values[key] = self._schema[key].deserialize(value)
|
||||
except KeyError: # not in our schema
|
||||
errors[key] = 'unknown config key.'
|
||||
except ValueError as e: # deserialization failed
|
||||
|
||||
@ -55,20 +55,38 @@ class ValidateMaximumTest(unittest.TestCase):
|
||||
self.assertRaises(ValueError, config.validate_maximum, 5, 0)
|
||||
|
||||
|
||||
class ValidateRequiredTest(unittest.TestCase):
|
||||
def test_passes_when_false(self):
|
||||
config.validate_required('foo', False)
|
||||
config.validate_required('', False)
|
||||
config.validate_required(' ', False)
|
||||
|
||||
def test_passes_when_required_and_set(self):
|
||||
config.validate_required('foo', True)
|
||||
config.validate_required(' foo ', True)
|
||||
|
||||
def test_blocks_when_required_and_emtpy(self):
|
||||
self.assertRaises(ValueError, config.validate_required, '', True)
|
||||
self.assertRaises(ValueError, config.validate_required, ' ', True)
|
||||
|
||||
|
||||
class ConfigValueTest(unittest.TestCase):
|
||||
def test_init(self):
|
||||
value = config.ConfigValue()
|
||||
self.assertIsNone(value.choices)
|
||||
self.assertIsNone(value.minimum)
|
||||
self.assertIsNone(value.maximum)
|
||||
self.assertIsNone(value.minimum)
|
||||
self.assertIsNone(value.optional)
|
||||
self.assertIsNone(value.secret)
|
||||
|
||||
def test_init_with_params(self):
|
||||
value = config.ConfigValue(
|
||||
choices=['foo'], minimum=0, maximum=10, secret=True)
|
||||
kwargs = {'choices': ['foo'], 'minimum': 0, 'maximum': 10,
|
||||
'secret': True, 'optional': True}
|
||||
value = config.ConfigValue(**kwargs)
|
||||
self.assertEqual(['foo'], value.choices)
|
||||
self.assertEqual(0, value.minimum)
|
||||
self.assertEqual(10, value.maximum)
|
||||
self.assertEqual(True, value.optional)
|
||||
self.assertEqual(True, value.secret)
|
||||
|
||||
def test_deserialize_passes_through(self):
|
||||
@ -91,7 +109,7 @@ class ConfigValueTest(unittest.TestCase):
|
||||
|
||||
|
||||
class StringTest(unittest.TestCase):
|
||||
def test_deserialize_strips_whitespace(self):
|
||||
def test_deserialize_converts_success(self):
|
||||
value = config.String()
|
||||
self.assertEqual('foo', value.deserialize(' foo '))
|
||||
|
||||
@ -100,22 +118,34 @@ class StringTest(unittest.TestCase):
|
||||
self.assertEqual('foo', value.deserialize('foo'))
|
||||
self.assertRaises(ValueError, value.deserialize, 'foobar')
|
||||
|
||||
def test_deserialize_enforces_required(self):
|
||||
value = config.String()
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
self.assertRaises(ValueError, value.deserialize, ' ')
|
||||
|
||||
def test_deserialize_respects_optional(self):
|
||||
value = config.String(optional=True)
|
||||
self.assertIsNone(value.deserialize(''))
|
||||
self.assertIsNone(value.deserialize(' '))
|
||||
|
||||
def test_format_masks_secrets(self):
|
||||
value = config.String(secret=True)
|
||||
self.assertEqual('********', value.format('s3cret'))
|
||||
|
||||
|
||||
class IntegerTest(unittest.TestCase):
|
||||
def test_deserialize_converts_to_int(self):
|
||||
def test_deserialize_converts_success(self):
|
||||
value = config.Integer()
|
||||
self.assertEqual(123, value.deserialize('123'))
|
||||
self.assertEqual(0, value.deserialize('0'))
|
||||
self.assertEqual(-10, value.deserialize('-10'))
|
||||
|
||||
def test_deserialize_fails_on_bad_data(self):
|
||||
def test_deserialize_conversion_failure(self):
|
||||
value = config.Integer()
|
||||
self.assertRaises(ValueError, value.deserialize, 'asd')
|
||||
self.assertRaises(ValueError, value.deserialize, '3.14')
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
self.assertRaises(ValueError, value.deserialize, ' ')
|
||||
|
||||
def test_deserialize_enforces_choices(self):
|
||||
value = config.Integer(choices=[1, 2, 3])
|
||||
@ -138,7 +168,7 @@ class IntegerTest(unittest.TestCase):
|
||||
|
||||
|
||||
class BooleanTest(unittest.TestCase):
|
||||
def test_deserialize_converts_to_bool(self):
|
||||
def test_deserialize_converts_success(self):
|
||||
value = config.Boolean()
|
||||
for true in ('1', 'yes', 'true', 'on'):
|
||||
self.assertIs(value.deserialize(true), True)
|
||||
@ -149,12 +179,13 @@ class BooleanTest(unittest.TestCase):
|
||||
self.assertIs(value.deserialize(false.upper()), False)
|
||||
self.assertIs(value.deserialize(false.capitalize()), False)
|
||||
|
||||
def test_deserialize_fails_on_bad_data(self):
|
||||
def test_deserialize_conversion_failure(self):
|
||||
value = config.Boolean()
|
||||
self.assertRaises(ValueError, value.deserialize, 'nope')
|
||||
self.assertRaises(ValueError, value.deserialize, 'sure')
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
|
||||
def test_serialize_normalises_strings(self):
|
||||
def test_serialize(self):
|
||||
value = config.Boolean()
|
||||
self.assertEqual('true', value.serialize(True))
|
||||
self.assertEqual('false', value.serialize(False))
|
||||
@ -165,20 +196,29 @@ class BooleanTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ListTest(unittest.TestCase):
|
||||
def test_deserialize_splits_commas(self):
|
||||
def test_deserialize_converts_success(self):
|
||||
value = config.List()
|
||||
self.assertEqual(['foo', 'bar', 'baz'],
|
||||
value.deserialize('foo, bar,baz'))
|
||||
|
||||
def test_deserialize_splits_newlines(self):
|
||||
value = config.List()
|
||||
self.assertEqual(['foo,bar', 'bar', 'baz'],
|
||||
value.deserialize('foo,bar\nbar\nbaz'))
|
||||
expected = ['foo', 'bar', 'baz']
|
||||
self.assertEqual(expected, value.deserialize('foo, bar ,baz '))
|
||||
|
||||
def test_serialize_joins_by_newlines(self):
|
||||
expected = ['foo,bar', 'bar', 'baz']
|
||||
self.assertEqual(expected, value.deserialize(' foo,bar\nbar\nbaz'))
|
||||
|
||||
def test_deserialize_enforces_required(self):
|
||||
value = config.List()
|
||||
self.assertRegexpMatches(value.serialize(['foo', 'bar', 'baz']),
|
||||
r'foo\n\s*bar\n\s*baz')
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
self.assertRaises(ValueError, value.deserialize, ' ')
|
||||
|
||||
def test_deserialize_respects_optional(self):
|
||||
value = config.List(optional=True)
|
||||
self.assertEqual([], value.deserialize(''))
|
||||
self.assertEqual([], value.deserialize(' '))
|
||||
|
||||
def test_serialize(self):
|
||||
value = config.List()
|
||||
result = value.serialize(['foo', 'bar', 'baz'])
|
||||
self.assertRegexpMatches(result, r'foo\n\s*bar\n\s*baz')
|
||||
|
||||
|
||||
class BooleanTest(unittest.TestCase):
|
||||
@ -188,41 +228,54 @@ class BooleanTest(unittest.TestCase):
|
||||
'info': logging.INFO,
|
||||
'debug': logging.DEBUG}
|
||||
|
||||
def test_deserialize_converts_to_numeric_loglevel(self):
|
||||
def test_deserialize_converts_success(self):
|
||||
value = config.LogLevel()
|
||||
for name, level in self.levels.items():
|
||||
self.assertEqual(level, value.deserialize(name))
|
||||
self.assertEqual(level, value.deserialize(name.upper()))
|
||||
self.assertEqual(level, value.deserialize(name.capitalize()))
|
||||
|
||||
def test_deserialize_fails_on_bad_data(self):
|
||||
def test_deserialize_conversion_failure(self):
|
||||
value = config.LogLevel()
|
||||
self.assertRaises(ValueError, value.deserialize, 'nope')
|
||||
self.assertRaises(ValueError, value.deserialize, 'sure')
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
self.assertRaises(ValueError, value.deserialize, ' ')
|
||||
|
||||
def test_serialize_converts_to_string(self):
|
||||
def test_serialize(self):
|
||||
value = config.LogLevel()
|
||||
for name, level in self.levels.items():
|
||||
self.assertEqual(name, value.serialize(level))
|
||||
|
||||
def test_serialize_unknown_level(self):
|
||||
value = config.LogLevel()
|
||||
self.assertIsNone(value.serialize(1337))
|
||||
|
||||
|
||||
class HostnameTest(unittest.TestCase):
|
||||
@mock.patch('socket.getaddrinfo')
|
||||
def test_deserialize_checks_addrinfo(self, getaddrinfo_mock):
|
||||
def test_deserialize_converts_success(self, getaddrinfo_mock):
|
||||
value = config.Hostname()
|
||||
value.deserialize('example.com')
|
||||
getaddrinfo_mock.assert_called_once_with('example.com', None)
|
||||
|
||||
@mock.patch('socket.getaddrinfo')
|
||||
def test_deserialize_handles_failures(self, getaddrinfo_mock):
|
||||
def test_deserialize_conversion_failure(self, getaddrinfo_mock):
|
||||
value = config.Hostname()
|
||||
getaddrinfo_mock.side_effect = socket.error
|
||||
self.assertRaises(ValueError, value.deserialize, 'example.com')
|
||||
|
||||
@mock.patch('socket.getaddrinfo')
|
||||
def test_deserialize_enforces_required(self, getaddrinfo_mock):
|
||||
value = config.Hostname()
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
self.assertRaises(ValueError, value.deserialize, ' ')
|
||||
self.assertEqual(0, getaddrinfo_mock.call_count)
|
||||
|
||||
@mock.patch('socket.getaddrinfo')
|
||||
def test_deserialize_respects_optional(self, getaddrinfo_mock):
|
||||
value = config.Hostname(optional=True)
|
||||
self.assertIsNone(value.deserialize(''))
|
||||
self.assertIsNone(value.deserialize(' '))
|
||||
self.assertEqual(0, getaddrinfo_mock.call_count)
|
||||
|
||||
|
||||
class PortTest(unittest.TestCase):
|
||||
def test_valid_ports(self):
|
||||
@ -238,6 +291,7 @@ class PortTest(unittest.TestCase):
|
||||
self.assertRaises(ValueError, value.deserialize, '100000')
|
||||
self.assertRaises(ValueError, value.deserialize, '0')
|
||||
self.assertRaises(ValueError, value.deserialize, '-1')
|
||||
self.assertRaises(ValueError, value.deserialize, '')
|
||||
|
||||
|
||||
class ConfigSchemaTest(unittest.TestCase):
|
||||
@ -285,11 +339,6 @@ class ConfigSchemaTest(unittest.TestCase):
|
||||
|
||||
self.assertIn('unknown', cm.exception['extra'])
|
||||
|
||||
def test_convert_with_blank_value(self):
|
||||
self.values['foo'] = ''
|
||||
result = self.schema.convert(self.values.items())
|
||||
self.assertIsNone(result['foo'])
|
||||
|
||||
def test_convert_with_deserialization_error(self):
|
||||
self.schema['foo'].deserialize.side_effect = ValueError('failure')
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user