Merge pull request #414 from adamcik/feature/config-value-kwargs

Make config value init kwargs strict.
This commit is contained in:
Stein Magnus Jodal 2013-04-15 09:15:08 -07:00
commit 8939167e88
10 changed files with 102 additions and 132 deletions

View File

@ -230,12 +230,12 @@ and ``password``.
version = __version__ version = __version__
def get_default_config(self): def get_default_config(self):
return default_config return bytes(default_config)
def get_config_schema(self): def get_config_schema(self):
schema = super(Extension, self).get_config_schema() schema = super(Extension, self).get_config_schema()
schema['username'] = config.String(required=True) schema['username'] = config.String()
schema['password'] = config.String(required=True, secret=True) schema['password'] = config.Secret()
return schema return schema
def validate_environment(self): def validate_environment(self):

View File

@ -19,7 +19,7 @@ class Extension(ext.Extension):
def get_config_schema(self): def get_config_schema(self):
schema = super(Extension, self).get_config_schema() schema = super(Extension, self).get_config_schema()
schema['username'] = config.String() schema['username'] = config.String()
schema['password'] = config.String(secret=True) schema['password'] = config.Secret()
schema['bitrate'] = config.Integer(choices=(96, 160, 320)) schema['bitrate'] = config.Integer(choices=(96, 160, 320))
schema['timeout'] = config.Integer(minimum=0) schema['timeout'] = config.Integer(minimum=0)
schema['cache_dir'] = config.Path() schema['cache_dir'] = config.Path()

View File

@ -27,7 +27,7 @@ _audio_schema['output'] = String()
_proxy_schema = ConfigSchema('proxy') _proxy_schema = ConfigSchema('proxy')
_proxy_schema['hostname'] = Hostname(optional=True) _proxy_schema['hostname'] = Hostname(optional=True)
_proxy_schema['username'] = String(optional=True) _proxy_schema['username'] = String(optional=True)
_proxy_schema['password'] = String(optional=True, secret=True) _proxy_schema['password'] = Secret(optional=True)
# NOTE: if multiple outputs ever comes something like LogLevelConfigSchema # NOTE: if multiple outputs ever comes something like LogLevelConfigSchema
#_outputs_schema = config.AudioOutputConfigSchema() #_outputs_schema = config.AudioOutputConfigSchema()

View File

@ -23,6 +23,15 @@ def encode(value):
return value.encode('utf-8') return value.encode('utf-8')
class ExpandedPath(bytes):
def __new__(self, value):
expanded = path.expand_path(value)
return super(ExpandedPath, self).__new__(self, expanded)
def __init__(self, value):
self.original = value
class ConfigValue(object): class ConfigValue(object):
"""Represents a config key's value and how to handle it. """Represents a config key's value and how to handle it.
@ -40,64 +49,32 @@ class ConfigValue(object):
the code interacting with the config should simply skip None config values. the code interacting with the config should simply skip None config values.
""" """
choices = None
"""
Collection of valid choices for converted value. Must be combined with
:func:`~mopidy.config.validators.validate_choice` in :meth:`deserialize`
do any thing.
"""
minimum = None
"""
Minimum of converted value. Must be combined with
:func:`~mopidy.config.validators.validate_minimum` in :meth:`deserialize`
do any thing.
"""
maximum = None
"""
Maximum of converted value. Must be combined with
:func:`~mopidy.config.validators.validate_maximum` in :meth:`deserialize`
do any thing.
"""
optional = None
"""Indicate if this field is required."""
secret = None
"""Indicate if we should mask the when printing for human consumption."""
def __init__(self, **kwargs):
self.choices = kwargs.get('choices')
self.minimum = kwargs.get('minimum')
self.maximum = kwargs.get('maximum')
self.optional = kwargs.get('optional')
self.secret = kwargs.get('secret')
def deserialize(self, value): def deserialize(self, value):
"""Cast raw string to appropriate type.""" """Cast raw string to appropriate type."""
return value return value
def serialize(self, value): def serialize(self, value):
"""Convert value back to string for saving.""" """Convert value back to string for saving."""
return str(value) return bytes(value)
def format(self, value): def format(self, value):
"""Format value for display.""" """Format value for display."""
if self.secret and value is not None:
return '********'
return self.serialize(value) return self.serialize(value)
class String(ConfigValue): class String(ConfigValue):
"""String value """String value.
Supported kwargs: ``optional``, ``choices``, and ``secret``. Is decoded as utf-8 and \\n \\t escapes should work and be preserved.
""" """
def __init__(self, optional=False, choices=None):
self._required = not optional
self._choices = choices
def deserialize(self, value): def deserialize(self, value):
value = decode(value).strip() value = decode(value).strip()
validators.validate_required(value, not self.optional) validators.validate_required(value, self._required)
validators.validate_choice(value, self.choices) validators.validate_choice(value, self._choices)
if not value: if not value:
return None return None
return value return value
@ -106,29 +83,46 @@ class String(ConfigValue):
return encode(value) return encode(value)
class Integer(ConfigValue): class Secret(ConfigValue):
"""Integer value """String value.
Supported kwargs: ``choices``, ``minimum``, ``maximum``, and ``secret`` Masked when being displayed, and is not decoded.
""" """
def __init__(self, optional=False, choices=None):
self._required = not optional
def deserialize(self, value):
validators.validate_required(value, self._required)
return value
def format(self, value):
return '********'
class Integer(ConfigValue):
"""Integer value."""
def __init__(self, minimum=None, maximum=None, choices=None):
self._minimum = minimum
self._maximum = maximum
self._choices = choices
def deserialize(self, value): def deserialize(self, value):
value = int(value) value = int(value)
validators.validate_choice(value, self.choices) validators.validate_choice(value, self._choices)
validators.validate_minimum(value, self.minimum) validators.validate_minimum(value, self._minimum)
validators.validate_maximum(value, self.maximum) validators.validate_maximum(value, self._maximum)
return value return value
class Boolean(ConfigValue): class Boolean(ConfigValue):
"""Boolean value """Boolean value.
Accepts ``1``, ``yes``, ``true``, and ``on`` with any casing as Accepts ``1``, ``yes``, ``true``, and ``on`` with any casing as
:class:`True`. :class:`True`.
Accepts ``0``, ``no``, ``false``, and ``off`` with any casing as Accepts ``0``, ``no``, ``false``, and ``off`` with any casing as
:class:`False`. :class:`False`.
Supported kwargs: ``secret``
""" """
true_values = ('1', 'yes', 'true', 'on') true_values = ('1', 'yes', 'true', 'on')
false_values = ('0', 'no', 'false', 'off') false_values = ('0', 'no', 'false', 'off')
@ -138,7 +132,6 @@ class Boolean(ConfigValue):
return True return True
elif value.lower() in self.false_values: elif value.lower() in self.false_values:
return False return False
raise ValueError('invalid value for boolean: %r' % value) raise ValueError('invalid value for boolean: %r' % value)
def serialize(self, value): def serialize(self, value):
@ -149,32 +142,33 @@ class Boolean(ConfigValue):
class List(ConfigValue): class List(ConfigValue):
"""List value """List value.
Supports elements split by commas or newlines. Supports elements split by commas or newlines. Newlines take presedence and
empty list items will be filtered out.
Supported kwargs: ``optional`` and ``secret``
""" """
def __init__(self, optional=False):
self._required = not optional
def deserialize(self, value): def deserialize(self, value):
validators.validate_required(value, not self.optional)
if b'\n' in value: if b'\n' in value:
values = re.split(r'\s*\n\s*', value) values = re.split(r'\s*\n\s*', value)
else: else:
values = re.split(r'\s*,\s*', value) values = re.split(r'\s*,\s*', value)
values = (decode(v).strip() for v in values) values = (decode(v).strip() for v in values)
return tuple(v for v in values if v) values = filter(None, values)
validators.validate_required(values, self._required)
return tuple(values)
def serialize(self, value): def serialize(self, value):
return b'\n ' + b'\n '.join(encode(v) for v in value if v) return b'\n ' + b'\n '.join(encode(v) for v in value if v)
class LogLevel(ConfigValue): class LogLevel(ConfigValue):
"""Log level value """Log level value.
Expects one of ``critical``, ``error``, ``warning``, ``info``, ``debug`` Expects one of ``critical``, ``error``, ``warning``, ``info``, ``debug``
with any casing. with any casing.
Supported kwargs: ``secret``
""" """
levels = { levels = {
'critical': logging.CRITICAL, 'critical': logging.CRITICAL,
@ -193,12 +187,13 @@ class LogLevel(ConfigValue):
class Hostname(ConfigValue): class Hostname(ConfigValue):
"""Hostname value """Network hostname value."""
def __init__(self, optional=False):
self._required = not optional
Supported kwargs: ``optional`` and ``secret``
"""
def deserialize(self, value): def deserialize(self, value):
validators.validate_required(value, not self.optional) validators.validate_required(value, self._required)
if not value.strip(): if not value.strip():
return None return None
try: try:
@ -209,26 +204,14 @@ class Hostname(ConfigValue):
class Port(Integer): class Port(Integer):
"""Port value """Network port value.
Expects integer in the range 1-65535 Expects integer in the range 0-65535, zero tells the kernel to simply
allocate a port for us.
Supported kwargs: ``choices`` and ``secret``
""" """
# TODO: consider probing if port is free or not? # TODO: consider probing if port is free or not?
def __init__(self, **kwargs): def __init__(self, choices=None):
super(Port, self).__init__(**kwargs) super(Port, self).__init__(minimum=0, maximum=2**16-1, choices=choices)
self.minimum = 1
self.maximum = 2 ** 16 - 1
class ExpandedPath(bytes):
def __new__(self, value):
expanded = path.expand_path(value)
return super(ExpandedPath, self).__new__(self, expanded)
def __init__(self, value):
self.original = value
class Path(ConfigValue): class Path(ConfigValue):
@ -248,10 +231,14 @@ class Path(ConfigValue):
Supported kwargs: ``optional``, ``choices``, and ``secret`` Supported kwargs: ``optional``, ``choices``, and ``secret``
""" """
def __init__(self, optional=False, choices=None):
self._required = not optional
self._choices = choices
def deserialize(self, value): def deserialize(self, value):
value = value.strip() value = value.strip()
validators.validate_required(value, not self.optional) validators.validate_required(value, self._required)
validators.validate_choice(value, self.choices) validators.validate_choice(value, self._choices)
if not value: if not value:
return None return None
return ExpandedPath(value) return ExpandedPath(value)

View File

@ -9,7 +9,7 @@ def validate_required(value, required):
Normally called in :meth:`~mopidy.config.types.ConfigValue.deserialize` on Normally called in :meth:`~mopidy.config.types.ConfigValue.deserialize` on
the raw string, _not_ the converted value. the raw string, _not_ the converted value.
""" """
if required and not value.strip(): if required and not value:
raise ValueError('must be set.') raise ValueError('must be set.')

View File

@ -34,9 +34,9 @@ class Extension(object):
""" """
def get_default_config(self): def get_default_config(self):
"""The extension's default config as a string """The extension's default config as a bytestring
:returns: string :returns: bytes
""" """
raise NotImplementedError( raise NotImplementedError(
'Add at least a config section with "enabled = true"') 'Add at least a config section with "enabled = true"')

View File

@ -20,7 +20,7 @@ class Extension(ext.Extension):
schema = super(Extension, self).get_config_schema() schema = super(Extension, self).get_config_schema()
schema['hostname'] = config.Hostname() schema['hostname'] = config.Hostname()
schema['port'] = config.Port() schema['port'] = config.Port()
schema['password'] = config.String(optional=True, secret=True) schema['password'] = config.Secret(optional=True)
schema['max_connections'] = config.Integer(minimum=1) schema['max_connections'] = config.Integer(minimum=1)
schema['connection_timeout'] = config.Integer(minimum=1) schema['connection_timeout'] = config.Integer(minimum=1)
return schema return schema

View File

@ -19,7 +19,7 @@ class Extension(ext.Extension):
def get_config_schema(self): def get_config_schema(self):
schema = super(Extension, self).get_config_schema() schema = super(Extension, self).get_config_schema()
schema['username'] = config.String() schema['username'] = config.String()
schema['password'] = config.String(secret=True) schema['password'] = config.Secret()
return schema return schema
def validate_environment(self): def validate_environment(self):

View File

@ -14,24 +14,6 @@ from tests import unittest
class ConfigValueTest(unittest.TestCase): class ConfigValueTest(unittest.TestCase):
def test_init(self):
value = types.ConfigValue()
self.assertIsNone(value.choices)
self.assertIsNone(value.maximum)
self.assertIsNone(value.minimum)
self.assertIsNone(value.optional)
self.assertIsNone(value.secret)
def test_init_with_params(self):
kwargs = {'choices': ['foo'], 'minimum': 0, 'maximum': 10,
'secret': True, 'optional': True}
value = types.ConfigValue(**kwargs)
self.assertEqual(['foo'], value.choices)
self.assertEqual(0, value.minimum)
self.assertEqual(10, value.maximum)
self.assertEqual(True, value.optional)
self.assertEqual(True, value.secret)
def test_deserialize_passes_through(self): def test_deserialize_passes_through(self):
value = types.ConfigValue() value = types.ConfigValue()
sentinel = object() sentinel = object()
@ -46,10 +28,6 @@ class ConfigValueTest(unittest.TestCase):
obj = object() obj = object()
self.assertEqual(value.serialize(obj), value.format(obj)) self.assertEqual(value.serialize(obj), value.format(obj))
def test_format_masks_secrets(self):
value = types.ConfigValue(secret=True)
self.assertEqual('********', value.format(object()))
class StringTest(unittest.TestCase): class StringTest(unittest.TestCase):
def test_deserialize_conversion_success(self): def test_deserialize_conversion_success(self):
@ -80,7 +58,6 @@ class StringTest(unittest.TestCase):
def test_deserialize_enforces_required(self): def test_deserialize_enforces_required(self):
value = types.String() value = types.String()
self.assertRaises(ValueError, value.deserialize, b'') self.assertRaises(ValueError, value.deserialize, b'')
self.assertRaises(ValueError, value.deserialize, b' ')
def test_deserialize_respects_optional(self): def test_deserialize_respects_optional(self):
value = types.String(optional=True) value = types.String(optional=True)
@ -111,8 +88,24 @@ class StringTest(unittest.TestCase):
self.assertIsInstance(result, bytes) self.assertIsInstance(result, bytes)
self.assertEqual(r'a\n\tb'.encode('utf-8'), result) self.assertEqual(r'a\n\tb'.encode('utf-8'), result)
def test_format_masks_secrets(self):
value = types.String(secret=True) class SecretTest(unittest.TestCase):
def test_deserialize_passes_through(self):
value = types.Secret()
result = value.deserialize(b'foo')
self.assertIsInstance(result, bytes)
self.assertEqual(b'foo', result)
def test_deserialize_enforces_required(self):
value = types.Secret()
self.assertRaises(ValueError, value.deserialize, b'')
def test_serialize_conversion_to_string(self):
value = types.Secret()
self.assertIsInstance(value.serialize(object()), bytes)
def test_format_masks_value(self):
value = types.Secret()
self.assertEqual('********', value.format('s3cret')) self.assertEqual('********', value.format('s3cret'))
@ -145,10 +138,6 @@ class IntegerTest(unittest.TestCase):
self.assertEqual(5, value.deserialize('5')) self.assertEqual(5, value.deserialize('5'))
self.assertRaises(ValueError, value.deserialize, '15') self.assertRaises(ValueError, value.deserialize, '15')
def test_format_masks_secrets(self):
value = types.Integer(secret=True)
self.assertEqual('********', value.format('1337'))
class BooleanTest(unittest.TestCase): class BooleanTest(unittest.TestCase):
def test_deserialize_conversion_success(self): def test_deserialize_conversion_success(self):
@ -173,10 +162,6 @@ class BooleanTest(unittest.TestCase):
self.assertEqual('true', value.serialize(True)) self.assertEqual('true', value.serialize(True))
self.assertEqual('false', value.serialize(False)) self.assertEqual('false', value.serialize(False))
def test_format_masks_secrets(self):
value = types.Boolean(secret=True)
self.assertEqual('********', value.format('true'))
class ListTest(unittest.TestCase): class ListTest(unittest.TestCase):
# TODO: add test_deserialize_ignores_blank # TODO: add test_deserialize_ignores_blank
@ -218,12 +203,10 @@ class ListTest(unittest.TestCase):
def test_deserialize_enforces_required(self): def test_deserialize_enforces_required(self):
value = types.List() value = types.List()
self.assertRaises(ValueError, value.deserialize, b'') self.assertRaises(ValueError, value.deserialize, b'')
self.assertRaises(ValueError, value.deserialize, b' ')
def test_deserialize_respects_optional(self): def test_deserialize_respects_optional(self):
value = types.List(optional=True) value = types.List(optional=True)
self.assertEqual(tuple(), value.deserialize(b'')) self.assertEqual(tuple(), value.deserialize(b''))
self.assertEqual(tuple(), value.deserialize(b' '))
def test_serialize(self): def test_serialize(self):
value = types.List() value = types.List()
@ -277,7 +260,6 @@ class HostnameTest(unittest.TestCase):
def test_deserialize_enforces_required(self, getaddrinfo_mock): def test_deserialize_enforces_required(self, getaddrinfo_mock):
value = types.Hostname() value = types.Hostname()
self.assertRaises(ValueError, value.deserialize, '') self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
self.assertEqual(0, getaddrinfo_mock.call_count) self.assertEqual(0, getaddrinfo_mock.call_count)
@mock.patch('socket.getaddrinfo') @mock.patch('socket.getaddrinfo')
@ -291,6 +273,7 @@ class HostnameTest(unittest.TestCase):
class PortTest(unittest.TestCase): class PortTest(unittest.TestCase):
def test_valid_ports(self): def test_valid_ports(self):
value = types.Port() value = types.Port()
self.assertEqual(0, value.deserialize('0'))
self.assertEqual(1, value.deserialize('1')) self.assertEqual(1, value.deserialize('1'))
self.assertEqual(80, value.deserialize('80')) self.assertEqual(80, value.deserialize('80'))
self.assertEqual(6600, value.deserialize('6600')) self.assertEqual(6600, value.deserialize('6600'))
@ -300,7 +283,6 @@ class PortTest(unittest.TestCase):
value = types.Port() value = types.Port()
self.assertRaises(ValueError, value.deserialize, '65536') self.assertRaises(ValueError, value.deserialize, '65536')
self.assertRaises(ValueError, value.deserialize, '100000') self.assertRaises(ValueError, value.deserialize, '100000')
self.assertRaises(ValueError, value.deserialize, '0')
self.assertRaises(ValueError, value.deserialize, '-1') self.assertRaises(ValueError, value.deserialize, '-1')
self.assertRaises(ValueError, value.deserialize, '') self.assertRaises(ValueError, value.deserialize, '')
@ -334,7 +316,6 @@ class PathTest(unittest.TestCase):
def test_deserialize_enforces_required(self): def test_deserialize_enforces_required(self):
value = types.Path() value = types.Path()
self.assertRaises(ValueError, value.deserialize, '') self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
def test_deserialize_respects_optional(self): def test_deserialize_respects_optional(self):
value = types.Path(optional=True) value = types.Path(optional=True)

View File

@ -57,11 +57,13 @@ class ValidateRequiredTest(unittest.TestCase):
validators.validate_required('foo', False) validators.validate_required('foo', False)
validators.validate_required('', False) validators.validate_required('', False)
validators.validate_required(' ', False) validators.validate_required(' ', False)
validators.validate_required([], False)
def test_passes_when_required_and_set(self): def test_passes_when_required_and_set(self):
validators.validate_required('foo', True) validators.validate_required('foo', True)
validators.validate_required(' foo ', True) validators.validate_required(' foo ', True)
validators.validate_required([1], True)
def test_blocks_when_required_and_emtpy(self): def test_blocks_when_required_and_emtpy(self):
self.assertRaises(ValueError, validators.validate_required, '', True) self.assertRaises(ValueError, validators.validate_required, '', True)
self.assertRaises(ValueError, validators.validate_required, ' ', True) self.assertRaises(ValueError, validators.validate_required, [], True)