diff --git a/mopidy/utils/command.py b/mopidy/utils/command.py index 596adcdb..b6645c51 100644 --- a/mopidy/utils/command.py +++ b/mopidy/utils/command.py @@ -5,7 +5,9 @@ import sys class CommandError(Exception): - pass + def __init__(self, message, usage=None): + self.message = message + self.usage = usage class ArgumentParser(argparse.ArgumentParser): @@ -42,9 +44,12 @@ class Command(object): def format_usage(self, prog=None): actions = self._build()[1] prog = prog or os.path.basename(sys.argv[0]) + return self._usage(actions, prog) + '\n' + + def _usage(self, actions, prog): formatter = argparse.HelpFormatter(prog) formatter.add_usage(None, actions, []) - return formatter.format_help() + return formatter.format_help().strip() def format_help(self, prog=None): actions = self._build()[1] @@ -89,16 +94,20 @@ class Command(object): for childname, child in self._children.items(): child._subhelp(' '.join((name, childname)), result) - def parse(self, args): - return self._parse(args, argparse.Namespace(), self._defaults.copy()) + def parse(self, args, prog=None): + prog = prog or os.path.basename(sys.argv[0]) + return self._parse( + args, argparse.Namespace(), self._defaults.copy(), prog) - def _parse(self, args, namespace, defaults): + def _parse(self, args, namespace, defaults, prog): defaults.update(self._defaults) - parser = self._build()[0] - result, unknown = parser.parse_known_args(args, namespace) + parser, actions = self._build() - if unknown: - raise CommandError('Unknown command options.') + try: + result = parser.parse_args(args, namespace) + except CommandError as e: + e.usage = self._usage(actions, prog) + raise if not result._args: for attr, value in defaults.items(): @@ -108,8 +117,10 @@ class Command(object): result.command = self return result - child = self._children.get(result._args[0]) - if not child: - raise CommandError('Invalid sub-command provided.') + child = result._args.pop(0) + if child not in self._children: + raise CommandError('unrecognized command: %s' % child, + usage=self._usage(actions, prog)) - return child._parse(result._args[1:], result, defaults) + return self._children[child]._parse( + result._args, result, defaults, ' '.join([prog, child])) diff --git a/tests/utils/command_test.py b/tests/utils/command_test.py index 39cc7ab1..14530cc6 100644 --- a/tests/utils/command_test.py +++ b/tests/utils/command_test.py @@ -95,28 +95,46 @@ class CommandParsingTest(unittest.TestCase): cmd.add_argument('--bar', type=int) with self.assertRaises(command.CommandError) as cm: - cmd.parse(['--bar', b'zero']) + cmd.parse(['--bar', b'zero'], prog='foo') self.assertEqual(cm.exception.message, "argument --bar: invalid int value: 'zero'") + self.assertEqual(cm.exception.usage, 'usage: foo [--bar BAR]') + + @mock.patch('sys.argv') + def test_command_error_usage_prog(self, argv_mock): + argv_mock.__getitem__.return_value = '/usr/bin/foo' + + cmd = command.Command() + cmd.add_argument('--bar', required=True) + + with self.assertRaises(command.CommandError) as cm: + cmd.parse([]) + self.assertEqual(cm.exception.usage, 'usage: foo --bar BAR') + + with self.assertRaises(command.CommandError) as cm: + cmd.parse([], prog='baz') + self.assertEqual(cm.exception.usage, 'usage: baz --bar BAR') def test_missing_required(self): cmd = command.Command() cmd.add_argument('--bar', required=True) with self.assertRaises(command.CommandError) as cm: - cmd.parse([]) + cmd.parse([], prog='foo') self.assertEqual(cm.exception.message, 'argument --bar is required') + self.assertEqual(cm.exception.usage, 'usage: foo --bar BAR') def test_missing_positionals(self): cmd = command.Command() cmd.add_argument('bar') with self.assertRaises(command.CommandError) as cm: - cmd.parse([]) + cmd.parse([], prog='foo') self.assertEqual(cm.exception.message, 'too few arguments') + self.assertEqual(cm.exception.usage, 'usage: foo bar') def test_missing_positionals_subcommand(self): child = command.Command() @@ -126,9 +144,30 @@ class CommandParsingTest(unittest.TestCase): cmd.add_child('bar', child) with self.assertRaises(command.CommandError) as cm: - cmd.parse(['bar']) + cmd.parse(['bar'], prog='foo') self.assertEqual(cm.exception.message, 'too few arguments') + self.assertEqual(cm.exception.usage, 'usage: foo bar baz') + + def test_unknown_command(self): + cmd = command.Command() + + with self.assertRaises(command.CommandError) as cm: + cmd.parse(['--help'], prog='foo') + + self.assertEqual( + cm.exception.message, 'unrecognized arguments: --help') + self.assertEqual(cm.exception.usage, 'usage: foo') + + def test_invalid_subcommand(self): + cmd = command.Command() + cmd.add_child('baz', command.Command()) + + with self.assertRaises(command.CommandError) as cm: + cmd.parse(['bar'], prog='foo') + + self.assertEqual(cm.exception.message, 'unrecognized command: bar') + self.assertEqual(cm.exception.usage, 'usage: foo') def test_set_defaults(self): cmd = command.Command()