diff --git a/mopidy/mpd/protocol/__init__.py b/mopidy/mpd/protocol/__init__.py index 5504086a..e7ffb32c 100644 --- a/mopidy/mpd/protocol/__init__.py +++ b/mopidy/mpd/protocol/__init__.py @@ -12,7 +12,8 @@ implement our own MPD server which is compatible with the numerous existing from __future__ import unicode_literals -from collections import namedtuple +import collections +import inspect import re from mopidy.utils import formatting @@ -26,7 +27,7 @@ LINE_TERMINATOR = '\n' #: The MPD protocol version is 0.17.0. VERSION = '0.17.0' -MpdCommand = namedtuple('MpdCommand', ['name', 'auth_required']) +MpdCommand = collections.namedtuple('MpdCommand', ['name', 'auth_required']) #: Set of all available commands, represented as :class:`MpdCommand` objects. mpd_commands = set() @@ -133,3 +134,63 @@ def tokenize(line): result.append(unquoted or UNESCAPE_RE.sub(r'\g<1>', quoted)) return result + + +def integer(value): + if value is None: + raise ValueError('None is not a valid integer') + return int(value) + + +def boolean(value): + if value in ('1', '0'): + return bool(int(value)) + raise ValueError('%r is not 0 or 1' % value) + + +class Commands(object): + def __init__(self): + self.handlers = {} + + def add(self, command, **validators): + def wrapper(func): + if command in self.handlers: + raise Exception('%s already registered' % command) + + args, varargs, keywords, defaults = inspect.getargspec(func) + defaults = dict(zip(args[-len(defaults or []):], defaults or [])) + + if not args and not varargs: + raise TypeError('Handler must accept at least one argument.') + + if len(args) > 1 and varargs: + raise TypeError( + '*args may not be combined with regular argmuments') + + if not set(validators.keys()).issubset(args): + raise TypeError('Validator for non-existent arg passed') + + if keywords: + raise TypeError('**kwargs are not permitted') + + def validate(*args, **kwargs): + if varargs: + return func(*args, **kwargs) + callargs = inspect.getcallargs(func, *args, **kwargs) + for key, value in callargs.items(): + default = defaults.get(key, object()) + if key in validators and value != default: + callargs[key] = validators[key](value) + return func(**callargs) + + self.handlers[command] = validate + return func + return wrapper + + def call(self, args, context=None): + if not args: + raise TypeError('No args provided') + command = args.pop(0) + if command not in self.handlers: + raise LookupError('Unknown command') + return self.handlers[command](context, *args) diff --git a/tests/mpd/protocol/test_commands_decorator.py b/tests/mpd/protocol/test_commands_decorator.py new file mode 100644 index 00000000..0d2d2ad3 --- /dev/null +++ b/tests/mpd/protocol/test_commands_decorator.py @@ -0,0 +1,187 @@ +#encoding: utf-8 + +from __future__ import unicode_literals + +import unittest + +from mopidy.mpd import protocol + + +class TestConverts(unittest.TestCase): + def test_integer(self): + self.assertEqual(123, protocol.integer('123')) + self.assertEqual(-123, protocol.integer('-123')) + self.assertEqual(123, protocol.integer('+123')) + self.assertRaises(ValueError, protocol.integer, '3.14') + self.assertRaises(ValueError, protocol.integer, '') + self.assertRaises(ValueError, protocol.integer, 'abc') + self.assertRaises(ValueError, protocol.integer, '12 34') + + def test_boolean(self): + self.assertEqual(True, protocol.boolean('1')) + self.assertEqual(False, protocol.boolean('0')) + self.assertRaises(ValueError, protocol.boolean, '3.14') + self.assertRaises(ValueError, protocol.boolean, '') + self.assertRaises(ValueError, protocol.boolean, 'true') + self.assertRaises(ValueError, protocol.boolean, 'false') + self.assertRaises(ValueError, protocol.boolean, 'abc') + self.assertRaises(ValueError, protocol.boolean, '12 34') + + +class TestCommands(unittest.TestCase): + def setUp(self): + self.commands = protocol.Commands() + + def test_add_as_a_decorator(self): + @self.commands.add('test') + def test(context): + pass + + def test_register_second_command_to_same_name_fails(self): + func = lambda context: True + + self.commands.add('foo')(func) + with self.assertRaises(Exception): + self.commands.add('foo')(func) + + def test_function_only_takes_context_succeeds(self): + sentinel = object() + self.commands.add('bar')(lambda context: sentinel) + self.assertEqual(sentinel, self.commands.call(['bar'])) + + def test_function_has_required_arg_succeeds(self): + sentinel = object() + self.commands.add('bar')(lambda context, required: sentinel) + self.assertEqual(sentinel, self.commands.call(['bar', 'arg'])) + + def test_function_has_optional_args_succeeds(self): + sentinel = object() + self.commands.add('bar')(lambda context, optional=None: sentinel) + self.assertEqual(sentinel, self.commands.call(['bar'])) + self.assertEqual(sentinel, self.commands.call(['bar', 'arg'])) + + def test_function_has_required_and_optional_args_succeeds(self): + sentinel = object() + func = lambda context, required, optional=None: sentinel + self.commands.add('bar')(func) + self.assertEqual(sentinel, self.commands.call(['bar', 'arg'])) + self.assertEqual(sentinel, self.commands.call(['bar', 'arg', 'arg'])) + + def test_function_has_varargs_succeeds(self): + sentinel, args = object(), [] + self.commands.add('bar')(lambda context, *args: sentinel) + for i in range(10): + self.assertEqual(sentinel, self.commands.call(['bar'] + args)) + args.append('test') + + def test_function_has_only_varags_succeeds(self): + sentinel = object() + self.commands.add('baz')(lambda *args: sentinel) + self.assertEqual(sentinel, self.commands.call(['baz'])) + + def test_function_has_no_arguments_fails(self): + with self.assertRaises(TypeError): + self.commands.add('test')(lambda: True) + + def test_function_has_required_and_varargs_fails(self): + with self.assertRaises(TypeError): + func = lambda context, required, *args: True + self.commands.add('test')(func) + + def test_function_has_optional_and_varargs_fails(self): + with self.assertRaises(TypeError): + func = lambda context, optional=None, *args: True + self.commands.add('test')(func) + + def test_function_hash_keywordargs_fails(self): + with self.assertRaises(TypeError): + self.commands.add('test')(lambda context, **kwargs: True) + + def test_call_chooses_correct_handler(self): + sentinel1, sentinel2, sentinel3 = object(), object(), object() + self.commands.add('foo')(lambda context: sentinel1) + self.commands.add('bar')(lambda context: sentinel2) + self.commands.add('baz')(lambda context: sentinel3) + + self.assertEqual(sentinel1, self.commands.call(['foo'])) + self.assertEqual(sentinel2, self.commands.call(['bar'])) + self.assertEqual(sentinel3, self.commands.call(['baz'])) + + def test_call_with_nonexistent_handler(self): + with self.assertRaises(LookupError): + self.commands.call(['bar']) + + def test_call_passes_context(self): + sentinel = object() + self.commands.add('foo')(lambda context: context) + self.assertEqual( + sentinel, self.commands.call(['foo'], context=sentinel)) + + def test_call_without_args_fails(self): + with self.assertRaises(TypeError): + self.commands.call([]) + + def test_call_passes_required_argument(self): + self.commands.add('foo')(lambda context, required: required) + self.assertEqual('test123', self.commands.call(['foo', 'test123'])) + + def test_call_passes_optional_argument(self): + sentinel = object() + self.commands.add('foo')(lambda context, optional=sentinel: optional) + self.assertEqual(sentinel, self.commands.call(['foo'])) + self.assertEqual('test', self.commands.call(['foo', 'test'])) + + def test_call_passes_required_and_optional_argument(self): + func = lambda context, required, optional=None: (required, optional) + self.commands.add('foo')(func) + self.assertEqual(('arg', None), self.commands.call(['foo', 'arg'])) + self.assertEqual( + ('arg', 'kwarg'), self.commands.call(['foo', 'arg', 'kwarg'])) + + def test_call_passes_varargs(self): + self.commands.add('foo')(lambda context, *args: args) + + def test_call_incorrect_args(self): + self.commands.add('foo')(lambda context: context) + with self.assertRaises(TypeError): + self.commands.call(['foo', 'bar']) + + self.commands.add('bar')(lambda context, required: context) + with self.assertRaises(TypeError): + self.commands.call(['bar', 'bar', 'baz']) + + self.commands.add('baz')(lambda context, optional=None: context) + with self.assertRaises(TypeError): + self.commands.call(['baz', 'bar', 'baz']) + + def test_validator_gets_applied_to_required_arg(self): + sentinel = object() + func = lambda context, required: required + self.commands.add('test', required=lambda v: sentinel)(func) + self.assertEqual(sentinel, self.commands.call(['test', 'foo'])) + + def test_validator_gets_applied_to_optional_arg(self): + sentinel = object() + func = lambda context, optional=None: optional + self.commands.add('foo', optional=lambda v: sentinel)(func) + + self.assertEqual(sentinel, self.commands.call(['foo', '123'])) + + def test_validator_skips_optional_default(self): + sentinel = object() + func = lambda context, optional=sentinel: optional + self.commands.add('foo', optional=lambda v: None)(func) + + self.assertEqual(sentinel, self.commands.call(['foo'])) + + def test_validator_applied_to_non_existent_arg_fails(self): + self.commands.add('foo')(lambda context, arg: arg) + with self.assertRaises(TypeError): + func = lambda context, wrong_arg: wrong_arg + self.commands.add('bar', arg=lambda v: v)(func) + + def test_validator_called_context_fails(self): + return # TODO: how to handle this + with self.assertRaises(TypeError): + func = lambda context: True + self.commands.add('bar', context=lambda v: v)(func)