diff --git a/mopidy/__main__.py b/mopidy/__main__.py index 6584146f..ea1cab6b 100644 --- a/mopidy/__main__.py +++ b/mopidy/__main__.py @@ -66,21 +66,23 @@ def main(): root_cmd.add_child('config', config_cmd) root_cmd.add_child('deps', deps_cmd) - installed_extensions = ext.load_extensions() + extensions_data = ext.load_extensions() - for extension in installed_extensions: - ext_cmd = extension.get_command() - if ext_cmd: - ext_cmd.set(extension=extension) - root_cmd.add_child(extension.ext_name, ext_cmd) + for data in extensions_data: + if data.command: # TODO: check isinstance? + data.command.set(extension=data.extension) + root_cmd.add_child(data.extension.ext_name, data.command) args = root_cmd.parse(mopidy_args) - create_file_structures_and_config(args, installed_extensions) + create_file_structures_and_config(args, extensions_data) check_old_locations() config, config_errors = config_lib.load( - args.config_files, installed_extensions, args.config_overrides) + args.config_files, + [d.config_schema for d in extensions_data], + [d.config_defaults for d in extensions_data], + args.config_overrides) verbosity_level = args.base_verbosity_level if args.verbosity_level: @@ -90,8 +92,11 @@ def main(): extensions = { 'validate': [], 'config': [], 'disabled': [], 'enabled': []} - for extension in installed_extensions: - if not ext.validate_extension(extension): + for data in extensions_data: + extension = data.extension + + # TODO: factor out all of this to a helper that can be tested + if not ext.validate_extension_data(data): config[extension.ext_name] = {'enabled': False} config_errors[extension.ext_name] = { 'enabled': 'extension disabled by self check.'} @@ -109,12 +114,13 @@ def main(): else: extensions['enabled'].append(extension) - log_extension_info(installed_extensions, extensions['enabled']) + log_extension_info([d.extension for d in extensions_data], + extensions['enabled']) # Config and deps commands are simply special cased for now. if args.command == config_cmd: - return args.command.run( - config, config_errors, installed_extensions) + schemas = [d.config_schema for d in extensions_data] + return args.command.run(config, config_errors, schemas) elif args.command == deps_cmd: return args.command.run() diff --git a/mopidy/commands.py b/mopidy/commands.py index 29564779..24acfb7d 100644 --- a/mopidy/commands.py +++ b/mopidy/commands.py @@ -415,8 +415,8 @@ class ConfigCommand(Command): super(ConfigCommand, self).__init__() self.set(base_verbosity_level=-1) - def run(self, config, errors, extensions): - print(config_lib.format(config, extensions, errors)) + def run(self, config, errors, schemas): + print(config_lib.format(config, schemas, errors)) return 0 diff --git a/mopidy/config/__init__.py b/mopidy/config/__init__.py index fc6dcb60..3f1f978c 100644 --- a/mopidy/config/__init__.py +++ b/mopidy/config/__init__.py @@ -65,24 +65,20 @@ def read(config_file): return filehandle.read() -def load(files, extensions, overrides): - # Helper to get configs, as the rest of our config system should not need - # to know about extensions. +def load(files, ext_schemas, ext_defaults, overrides): config_dir = os.path.dirname(__file__) defaults = [read(os.path.join(config_dir, 'default.conf'))] - defaults.extend(e.get_default_config() for e in extensions) + defaults.extend(ext_defaults) raw_config = _load(files, defaults, keyring.fetch() + (overrides or [])) schemas = _schemas[:] - schemas.extend(e.get_config_schema() for e in extensions) + schemas.extend(ext_schemas) return _validate(raw_config, schemas) -def format(config, extensions, comments=None, display=True): - # Helper to format configs, as the rest of our config system should not - # need to know about extensions. +def format(config, ext_schemas, comments=None, display=True): schemas = _schemas[:] - schemas.extend(e.get_config_schema() for e in extensions) + schemas.extend(ext_schemas) return _format(config, comments or {}, schemas, display, False) diff --git a/mopidy/ext.py b/mopidy/ext.py index 3122611f..ab35008a 100644 --- a/mopidy/ext.py +++ b/mopidy/ext.py @@ -11,6 +11,12 @@ from mopidy import config as config_lib, exceptions logger = logging.getLogger(__name__) +_extension_data_fields = ['extension', 'entry_point', 'config_schema', + 'config_defaults', 'command'] + +ExtensionData = collections.namedtuple('ExtensionData', _extension_data_fields) + + class Extension(object): """Base class for Mopidy extensions""" @@ -148,55 +154,100 @@ def load_extensions(): for entry_point in pkg_resources.iter_entry_points('mopidy.ext'): logger.debug('Loading entry point: %s', entry_point) extension_class = entry_point.load(require=False) - extension = extension_class() - extension.entry_point = entry_point - installed_extensions.append(extension) + + try: + if not issubclass(extension_class, Extension): + raise TypeError # issubclass raises TypeError on non-class + except TypeError: + logger.error('Entry point %s did not contain a valid extension' + 'class: %r', entry_point.name, extension_class) + continue + + try: + extension = extension_class() + config_schema = extension.get_config_schema() + default_config = extension.get_default_config() + command = extension.get_command() + except Exception: + logger.exception('Setup of extension from entry point %s failed, ' + 'ignoring extension.', entry_point.name) + continue + + installed_extensions.append(ExtensionData( + extension, entry_point, config_schema, default_config, command)) + logger.debug( 'Loaded extension: %s %s', extension.dist_name, extension.version) - names = (e.ext_name for e in installed_extensions) + names = (ed.extension.ext_name for ed in installed_extensions) logger.debug('Discovered extensions: %s', ', '.join(names)) return installed_extensions -def validate_extension(extension): +def validate_extension_data(data): """Verify extension's dependencies and environment. :param extensions: an extension to check :returns: if extension should be run """ - logger.debug('Validating extension: %s', extension.ext_name) + logger.debug('Validating extension: %s', data.extension.ext_name) - if extension.ext_name != extension.entry_point.name: + if data.extension.ext_name != data.entry_point.name: logger.warning( 'Disabled extension %(ep)s: entry point name (%(ep)s) ' 'does not match extension name (%(ext)s)', - {'ep': extension.entry_point.name, 'ext': extension.ext_name}) + {'ep': data.entry_point.name, 'ext': data.extension.ext_name}) return False try: - extension.entry_point.require() + data.entry_point.require() except pkg_resources.DistributionNotFound as ex: logger.info( 'Disabled extension %s: Dependency %s not found', - extension.ext_name, ex) + data.extension.ext_name, ex) return False except pkg_resources.VersionConflict as ex: if len(ex.args) == 2: found, required = ex.args logger.info( 'Disabled extension %s: %s required, but found %s at %s', - extension.ext_name, required, found, found.location) + data.extension.ext_name, required, found, found.location) else: - logger.info('Disabled extension %s: %s', extension.ext_name, ex) + logger.info( + 'Disabled extension %s: %s', data.extension.ext_name, ex) return False try: - extension.validate_environment() + data.extension.validate_environment() except exceptions.ExtensionError as ex: logger.info( - 'Disabled extension %s: %s', extension.ext_name, ex.message) + 'Disabled extension %s: %s', data.extension.ext_name, ex.message) + return False + except Exception: + logger.exception('Validating extension %s failed with an exception.', + data.extension.ext_name) + return False + + if not data.config_schema: + logger.error('Extension %s does not have a config schema, disabling.', + data.extension.ext_name) + return False + elif not isinstance(data.config_schema.get('enabled'), config_lib.Boolean): + logger.error('Extension %s does not have the required "enabled" config' + ' option, disabling.', data.extension.ext_name) + return False + + for key, value in data.config_schema.items(): + if not isinstance(value, config_lib.ConfigValue): + logger.error('Extension %s config schema contains an invalid value' + ' for the option "%s", disabling.', + data.extension.ext_name, key) + return False + + if not data.config_defaults: + logger.error('Extension %s does not have a default config, disabling.', + data.extension.ext_name) return False return True diff --git a/tests/test_ext.py b/tests/test_ext.py index c58f6b20..748aebb3 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -1,35 +1,223 @@ from __future__ import absolute_import, unicode_literals -import unittest +import mock -from mopidy import config, ext +import pkg_resources + +import pytest + +from mopidy import config, exceptions, ext + +from tests import IsA, any_unicode -class ExtensionTest(unittest.TestCase): +class TestExtension(ext.Extension): + dist_name = 'Mopidy-Foobar' + ext_name = 'foobar' + version = '1.2.3' - def setUp(self): # noqa: N802 - self.ext = ext.Extension() + def get_default_config(self): + return '[foobar]\nenabled = true' - def test_dist_name_is_none(self): - self.assertIsNone(self.ext.dist_name) - def test_ext_name_is_none(self): - self.assertIsNone(self.ext.ext_name) +any_testextension = IsA(TestExtension) - def test_version_is_none(self): - self.assertIsNone(self.ext.version) - def test_get_default_config_raises_not_implemented(self): - with self.assertRaises(NotImplementedError): - self.ext.get_default_config() +class ExtensionTest(object): - def test_get_config_schema_returns_extension_schema(self): - schema = self.ext.get_config_schema() - self.assertIsInstance(schema['enabled'], config.Boolean) + @pytest.fixture + def extension(self): + return ext.Extension() - def test_validate_environment_does_nothing_by_default(self): - self.assertIsNone(self.ext.validate_environment()) + def test_dist_name_is_none(self, extension): + assert extension.dist_name is None - def test_setup_raises_not_implemented(self): - with self.assertRaises(NotImplementedError): - self.ext.setup(None) + def test_ext_name_is_none(self, extension): + assert extension.ext_name is None + + def test_version_is_none(self, extension): + assert extension.version is None + + def test_get_default_config_raises_not_implemented(self, extension): + with pytest.raises(NotImplementedError): + extension.get_default_config() + + def test_get_config_schema_returns_extension_schema(self, extension): + schema = extension.get_config_schema() + assert isinstance(schema['enabled'], config.Boolean) + + def test_validate_environment_does_nothing_by_default(self, extension): + assert extension.validate_environment() is None + + def test_setup_raises_not_implemented(self, extension): + with pytest.raises(NotImplementedError): + extension.setup(None) + + +class LoadExtensionsTest(object): + + @pytest.yield_fixture + def iter_entry_points_mock(self, request): + patcher = mock.patch('pkg_resources.iter_entry_points') + iter_entry_points = patcher.start() + iter_entry_points.return_value = [] + yield iter_entry_points + patcher.stop() + + def test_no_extensions(self, iter_entry_points_mock): + iter_entry_points_mock.return_value = [] + assert ext.load_extensions() == [] + + def test_load_extensions(self, iter_entry_points_mock): + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = TestExtension + + iter_entry_points_mock.return_value = [mock_entry_point] + + expected = ext.ExtensionData( + any_testextension, mock_entry_point, IsA(config.ConfigSchema), + any_unicode, None) + + assert ext.load_extensions() == [expected] + + def test_gets_wrong_class(self, iter_entry_points_mock): + + class WrongClass(object): + pass + + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = WrongClass + + iter_entry_points_mock.return_value = [mock_entry_point] + + assert ext.load_extensions() == [] + + def test_gets_instance(self, iter_entry_points_mock): + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = TestExtension() + + iter_entry_points_mock.return_value = [mock_entry_point] + + assert ext.load_extensions() == [] + + def test_creating_instance_fails(self, iter_entry_points_mock): + mock_extension = mock.Mock(spec=ext.Extension) + mock_extension.side_effect = Exception + + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = mock_extension + + iter_entry_points_mock.return_value = [mock_entry_point] + + assert ext.load_extensions() == [] + + def test_get_config_schema_fails(self, iter_entry_points_mock): + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = TestExtension + + iter_entry_points_mock.return_value = [mock_entry_point] + + with mock.patch.object(TestExtension, 'get_config_schema') as get: + get.side_effect = Exception + + assert ext.load_extensions() == [] + get.assert_called_once_with() + + def test_get_default_config_fails(self, iter_entry_points_mock): + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = TestExtension + + iter_entry_points_mock.return_value = [mock_entry_point] + + with mock.patch.object(TestExtension, 'get_default_config') as get: + get.side_effect = Exception + + assert ext.load_extensions() == [] + get.assert_called_once_with() + + def test_get_command_fails(self, iter_entry_points_mock): + mock_entry_point = mock.Mock() + mock_entry_point.load.return_value = TestExtension + + iter_entry_points_mock.return_value = [mock_entry_point] + + with mock.patch.object(TestExtension, 'get_command') as get: + get.side_effect = Exception + + assert ext.load_extensions() == [] + get.assert_called_once_with() + + +class ValidateExtensionDataTest(object): + + @pytest.fixture + def ext_data(self): + extension = TestExtension() + + entry_point = mock.Mock() + entry_point.name = extension.ext_name + + schema = extension.get_config_schema() + defaults = extension.get_default_config() + command = extension.get_command() + + return ext.ExtensionData( + extension, entry_point, schema, defaults, command) + + def test_name_mismatch(self, ext_data): + ext_data.entry_point.name = 'barfoo' + assert not ext.validate_extension_data(ext_data) + + def test_distribution_not_found(self, ext_data): + error = pkg_resources.DistributionNotFound + ext_data.entry_point.require.side_effect = error + assert not ext.validate_extension_data(ext_data) + + def test_version_conflict(self, ext_data): + error = pkg_resources.VersionConflict + ext_data.entry_point.require.side_effect = error + assert not ext.validate_extension_data(ext_data) + + def test_entry_point_require_exception(self, ext_data): + ext_data.entry_point.require.side_effect = Exception + + # Hope that entry points are well behaved, so exception will bubble. + with pytest.raises(Exception): + assert not ext.validate_extension_data(ext_data) + + def test_extenions_validate_environment_error(self, ext_data): + extension = ext_data.extension + with mock.patch.object(extension, 'validate_environment') as validate: + validate.side_effect = exceptions.ExtensionError('error') + + assert not ext.validate_extension_data(ext_data) + validate.assert_called_once_with() + + def test_extenions_validate_environment_exception(self, ext_data): + extension = ext_data.extension + with mock.patch.object(extension, 'validate_environment') as validate: + validate.side_effect = Exception + + assert not ext.validate_extension_data(ext_data) + validate.assert_called_once_with() + + def test_missing_schema(self, ext_data): + ext_data = ext_data._replace(config_schema=None) + assert not ext.validate_extension_data(ext_data) + + def test_schema_that_is_missing_enabled(self, ext_data): + del ext_data.config_schema['enabled'] + ext_data.config_schema['baz'] = config.String() + assert not ext.validate_extension_data(ext_data) + + def test_schema_with_wrong_types(self, ext_data): + ext_data.config_schema['enabled'] = 123 + assert not ext.validate_extension_data(ext_data) + + def test_schema_with_invalid_type(self, ext_data): + ext_data.config_schema['baz'] = 123 + assert not ext.validate_extension_data(ext_data) + + def test_no_default_config(self, ext_data): + ext_data = ext_data._replace(config_defaults=None) + assert not ext.validate_extension_data(ext_data)