commands: Replace set_defaults with set method

For our use case we only need to be able to override a value, not set defaults
so this simplifies everything.
This commit is contained in:
Thomas Adamcik 2013-11-14 23:00:37 +01:00
parent 39f7fd5955
commit c0f1a1352b
2 changed files with 18 additions and 18 deletions

View File

@ -34,7 +34,7 @@ class Command(object):
def __init__(self):
self._children = collections.OrderedDict()
self._arguments = []
self._defaults = {}
self._overrides = {}
def _build(self):
actions = []
@ -54,11 +54,11 @@ class Command(object):
def add_argument(self, *args, **kwargs):
self._arguments.append((args, kwargs))
def set_defaults(self, **kwargs):
self._defaults.update(kwargs)
def set(self, **kwargs):
self._overrides.update(kwargs)
def exit(self, status_code=0, message=None, usage=None):
print '\n\n'.join(m for m in (usage, message) if m.strip)
print '\n\n'.join(m for m in (usage, message) if m)
sys.exit(status_code)
def format_usage(self, prog=None):
@ -118,12 +118,12 @@ class Command(object):
prog = prog or os.path.basename(sys.argv[0])
try:
return self._parse(
args, argparse.Namespace(), self._defaults.copy(), prog)
args, argparse.Namespace(), self._overrides.copy(), prog)
except _HelpError:
self.exit(0, self.format_help(prog))
def _parse(self, args, namespace, defaults, prog):
defaults.update(self._defaults)
def _parse(self, args, namespace, overrides, prog):
overrides.update(self._overrides)
parser, actions = self._build()
try:
@ -132,9 +132,8 @@ class Command(object):
self.exit(1, e.message, self._usage(actions, prog))
if not result._args:
for attr, value in defaults.items():
if not hasattr(result, attr):
setattr(result, attr, value)
for attr, value in overrides.items():
setattr(result, attr, value)
delattr(result, '_args')
result.command = self
return result
@ -145,7 +144,7 @@ class Command(object):
self.exit(1, 'unrecognized command: %s' % child, usage)
return self._children[child]._parse(
result._args, result, defaults, ' '.join([prog, child]))
result._args, result, overrides, ' '.join([prog, child]))
def run(self, *args, **kwargs):
raise NotImplementedError

View File

@ -180,33 +180,34 @@ class CommandParsingTest(unittest.TestCase):
self.exit_mock.assert_called_once_with(
1, 'unrecognized command: bar', 'usage: foo')
def test_set_defaults(self):
def test_set(self):
cmd = command.Command()
cmd.set_defaults(foo='bar')
cmd.set(foo='bar')
result = cmd.parse([])
self.assertEqual(result.foo, 'bar')
def test_defaults_propegate(self):
def test_set_propegate(self):
child = command.Command()
cmd = command.Command()
cmd.set_defaults(foo='bar')
cmd.set(foo='bar')
cmd.add_child('command', child)
result = cmd.parse(['command'])
self.assertEqual(result.foo, 'bar')
def test_innermost_defaults_wins(self):
def test_innermost_set_wins(self):
child = command.Command()
child.set_defaults(foo='bar')
child.set(foo='bar', baz=1)
cmd = command.Command()
cmd.set_defaults(foo='baz')
cmd.set(foo='baz', baz=None)
cmd.add_child('command', child)
result = cmd.parse(['command'])
self.assertEqual(result.foo, 'bar')
self.assertEqual(result.baz, 1)
def test_help_action_works(self):
cmd = command.Command()