commands: Add set_defaults() to the parser

- Removes mock based test that checks delegation, just check commands result
- Allow setting of defaults that propegate down the chain
- Move internal parser args to a _parse helper.
This commit is contained in:
Thomas Adamcik 2013-11-12 22:13:05 +01:00
parent 6bddcb7875
commit 001c8e0bdc
2 changed files with 56 additions and 34 deletions

View File

@ -17,6 +17,7 @@ class Command(object):
def __init__(self):
self._children = collections.OrderedDict()
self._arguments = []
self._defaults = {}
def _build(self):
actions = []
@ -35,6 +36,9 @@ class Command(object):
def add_argument(self, *args, **kwargs):
self._arguments.append((args, kwargs))
def set_defaults(self, **kwargs):
self._defaults.update(kwargs)
def format_usage(self, prog=None):
actions = self._build()[1]
prog = prog or os.path.basename(sys.argv[0])
@ -85,24 +89,27 @@ class Command(object):
for childname, child in self._children.items():
child._subhelp(' '.join((name, childname)), result)
def parse(self, args, namespace=None):
if not namespace:
namespace = argparse.Namespace()
def parse(self, args):
return self._parse(args, argparse.Namespace(), self._defaults.copy())
def _parse(self, args, namespace, defaults):
defaults.update(self._defaults)
parser = self._build()[0]
result, unknown = parser.parse_known_args(args, namespace)
if unknown:
raise CommandError('Unknown command options.')
args = result._args
delattr(result, '_args')
if not args:
if not result._args:
for attr, value in defaults.items():
if not hasattr(result, attr):
setattr(result, attr, value)
delattr(result, '_args')
result.command = self
return result
if args[0] not in self._children:
child = self._children.get(result._args[0])
if not child:
raise CommandError('Invalid sub-command provided.')
return self._children[args[0]].parse(args[1:], result)
return child._parse(result._args[1:], result, defaults)

View File

@ -17,15 +17,6 @@ class CommandParsingTest(unittest.TestCase):
result = cmd.parse([])
self.assertFalse(hasattr(result, '_args'))
def test_sub_command_delegation(self):
mock_cmd = mock.Mock(spec=command.Command)
cmd = command.Command()
cmd.add_child('foo', mock_cmd)
cmd.parse(['foo'])
mock_cmd.parse.assert_called_with([], mock.ANY)
def test_unknown_options_raises_error(self):
cmd = command.Command()
with self.assertRaises(command.CommandError):
@ -55,22 +46,6 @@ class CommandParsingTest(unittest.TestCase):
self.assertEqual(result.bar, 'baz')
self.assertEqual(result.baz, None)
def test_multiple_sub_commands(self):
mock_foo_cmd = mock.Mock(spec=command.Command)
mock_bar_cmd = mock.Mock(spec=command.Command)
mock_baz_cmd = mock.Mock(spec=command.Command)
cmd = command.Command()
cmd.add_child('foo', mock_foo_cmd)
cmd.add_child('bar', mock_bar_cmd)
cmd.add_child('baz', mock_baz_cmd)
cmd.parse(['bar'])
mock_bar_cmd.parse.assert_called_with([], mock.ANY)
cmd.parse(['baz'])
mock_baz_cmd.parse.assert_called_with([], mock.ANY)
def test_subcommand_may_have_positional(self):
child = command.Command()
child.add_argument('bar')
@ -103,6 +78,18 @@ class CommandParsingTest(unittest.TestCase):
result = cmd.parse([])
self.assertEqual(result.command, cmd)
child2 = command.Command()
cmd.add_child('bar', child2)
subchild = command.Command()
child.add_child('baz', subchild)
result = cmd.parse(['bar'])
self.assertEqual(result.command, child2)
result = cmd.parse(['foo', 'baz'])
self.assertEqual(result.command, subchild)
def test_invalid_type(self):
cmd = command.Command()
cmd.add_argument('--bar', type=int)
@ -143,6 +130,34 @@ class CommandParsingTest(unittest.TestCase):
self.assertEqual(cm.exception.message, 'too few arguments')
def test_set_defaults(self):
cmd = command.Command()
cmd.set_defaults(foo='bar')
result = cmd.parse([])
self.assertEqual(result.foo, 'bar')
def test_defaults_propegate(self):
child = command.Command()
cmd = command.Command()
cmd.set_defaults(foo='bar')
cmd.add_child('command', child)
result = cmd.parse(['command'])
self.assertEqual(result.foo, 'bar')
def test_innermost_defaults_wins(self):
child = command.Command()
child.set_defaults(foo='bar')
cmd = command.Command()
cmd.set_defaults(foo='baz')
cmd.add_child('command', child)
result = cmd.parse(['command'])
self.assertEqual(result.foo, 'bar')
class UsageTest(unittest.TestCase):
@mock.patch('sys.argv')