Merge pull request #413 from adamcik/feature/binary-config-reading

Treat configs as binary data
This commit is contained in:
Stein Magnus Jodal 2013-04-14 09:57:40 -07:00
commit ceee6630fa
4 changed files with 121 additions and 35 deletions

View File

@ -44,8 +44,12 @@ def main():
extensions = [] # Make sure it is defined before the finally block
# TODO: figure out a way to make the boilerplate in this file reusable in
# scanner and other places we need it.
try:
create_file_structures()
# TODO: run raw logging config through escape code etc., or just validate?
logging_config = config_lib.load(config_files, config_overrides)
log.setup_logging(
logging_config, options.verbosity_level, options.save_debug_log)

View File

@ -1,6 +1,5 @@
from __future__ import unicode_literals
import codecs
import ConfigParser as configparser
import io
import logging
@ -57,22 +56,21 @@ def _load(files, defaults, overrides):
files = [path.expand_path(f) for f in files]
sources = ['builtin-defaults'] + files + ['command-line']
logger.info('Loading config from: %s', ', '.join(sources))
for default in defaults: # TODO: remove decoding
parser.readfp(io.StringIO(default.decode('utf-8')))
# TODO: simply return path to config file for defaults so we can load it
# all in the same way?
for default in defaults:
parser.readfp(io.BytesIO(default))
# Load config from a series of config files
for filename in files:
# TODO: if this is the initial load of logging config we might not have
# a logger at this point, we might want to handle this better.
try:
with codecs.open(filename, encoding='utf-8') as filehandle:
with io.open(filename, 'rb') as filehandle:
parser.readfp(filehandle)
except IOError:
# TODO: if this is the initial load of logging config we might not
# have a logger at this point, we might want to handle this better.
logger.debug('Config file %s not found; skipping', filename)
continue
except UnicodeDecodeError:
logger.error('Config file %s is not UTF-8 encoded', filename)
sys.exit(1)
raw_config = {}
for section in parser.sections():

View File

@ -8,6 +8,21 @@ from mopidy.utils import path
from mopidy.config import validators
def decode(value):
    """Turn a raw config value into unescaped unicode text.

    Values that already are unicode text are returned untouched. Byte
    strings get their backslash escape sequences (newline, tab, backslash,
    and the other sequences the string-escape codec understands) expanded
    and are then decoded as UTF-8.

    Raises ``ValueError`` if an escape sequence is malformed or if the
    bytes are not valid UTF-8 (``UnicodeDecodeError`` is a ``ValueError``
    subclass).
    """
    # Local import: keeps this function self-contained regardless of which
    # modules the surrounding file imports.
    import codecs

    # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3, so
    # this check does not depend on the Python-2-only ``unicode`` builtin.
    if isinstance(value, type(u'')):
        return value

    # TODO: only unescape \n \t and \\?
    # ``codecs.escape_decode`` is the function Python 2's 'string-escape'
    # codec delegates to; calling it directly also works on Python 3, where
    # that codec name no longer exists.
    unescaped, _ = codecs.escape_decode(value)
    return unescaped.decode('utf-8')
def encode(value):
    """Turn unicode text into an escaped, UTF-8 encoded byte string.

    Anything that is not unicode text (for example a value that already is
    a byte string) is returned unchanged. For unicode text, backslashes,
    newlines and tabs are replaced by their two-character backslash escape
    sequences before the text is encoded as UTF-8.
    """
    # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3.
    if not isinstance(value, type(u'')):
        return value

    # The backslash must be handled first; otherwise the backslashes
    # introduced by the newline/tab escapes would be escaped a second time.
    # Replacements are spelled out as text so no implicit bytes-to-text
    # coercion is needed inside ``replace``.
    # TODO: more escapes?
    escapes = (
        (u'\\', u'\\\\'),
        (u'\n', u'\\n'),
        (u'\t', u'\\t'),
    )
    for char, escaped in escapes:
        value = value.replace(char, escaped)
    return value.encode('utf-8')
class ConfigValue(object):
"""Represents a config key's value and how to handle it.
@ -80,7 +95,7 @@ class String(ConfigValue):
Supported kwargs: ``optional``, ``choices``, and ``secret``.
"""
def deserialize(self, value):
value = value.strip()
value = decode(value).strip()
validators.validate_required(value, not self.optional)
validators.validate_choice(value, self.choices)
if not value:
@ -88,7 +103,7 @@ class String(ConfigValue):
return value
def serialize(self, value):
return value.encode('utf-8').encode('string-escape')
return encode(value)
class Integer(ConfigValue):
@ -142,14 +157,15 @@ class List(ConfigValue):
"""
def deserialize(self, value):
validators.validate_required(value, not self.optional)
if '\n' in value:
values = re.split(r'\s*\n\s*', value.strip())
if b'\n' in value:
values = re.split(r'\s*\n\s*', value)
else:
values = re.split(r'\s*,\s*', value.strip())
return tuple([v for v in values if v])
values = re.split(r'\s*,\s*', value)
values = (decode(v).strip() for v in values)
return tuple(v for v in values if v)
def serialize(self, value):
return '\n ' + '\n '.join(v.encode('utf-8') for v in value)
return b'\n ' + b'\n '.join(encode(v) for v in value if v)
class LogLevel(ConfigValue):

View File

@ -1,3 +1,5 @@
# encoding: utf-8
from __future__ import unicode_literals
import logging
@ -8,6 +10,8 @@ from mopidy.config import types
from tests import unittest
# TODO: DecodeTest and EncodeTest
class ConfigValueTest(unittest.TestCase):
def test_init(self):
@ -30,12 +34,12 @@ class ConfigValueTest(unittest.TestCase):
def test_deserialize_passes_through(self):
value = types.ConfigValue()
obj = object()
self.assertEqual(obj, value.deserialize(obj))
sentinel = object()
self.assertEqual(sentinel, value.deserialize(sentinel))
def test_serialize_conversion_to_string(self):
value = types.ConfigValue()
self.assertIsInstance(value.serialize(object()), basestring)
self.assertIsInstance(value.serialize(object()), bytes)
def test_format_uses_serialize(self):
value = types.ConfigValue()
@ -50,26 +54,62 @@ class ConfigValueTest(unittest.TestCase):
class StringTest(unittest.TestCase):
def test_deserialize_conversion_success(self):
value = types.String()
self.assertEqual('foo', value.deserialize(' foo '))
self.assertEqual('foo', value.deserialize(b' foo '))
self.assertIsInstance(value.deserialize(b'foo'), unicode)
def test_deserialize_decodes_utf8(self):
value = types.String()
result = value.deserialize('æøå'.encode('utf-8'))
self.assertEqual('æøå', result)
def test_deserialize_does_not_double_encode_unicode(self):
value = types.String()
result = value.deserialize('æøå')
self.assertEqual('æøå', result)
def test_deserialize_handles_escapes(self):
value = types.String(optional=True)
result = value.deserialize(b'a\\t\\nb')
self.assertEqual('a\t\nb', result)
def test_deserialize_enforces_choices(self):
value = types.String(choices=['foo', 'bar', 'baz'])
self.assertEqual('foo', value.deserialize('foo'))
self.assertRaises(ValueError, value.deserialize, 'foobar')
self.assertEqual('foo', value.deserialize(b'foo'))
self.assertRaises(ValueError, value.deserialize, b'foobar')
def test_deserialize_enforces_required(self):
value = types.String()
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
self.assertRaises(ValueError, value.deserialize, b'')
self.assertRaises(ValueError, value.deserialize, b' ')
def test_deserialize_respects_optional(self):
value = types.String(optional=True)
self.assertIsNone(value.deserialize(''))
self.assertIsNone(value.deserialize(' '))
self.assertIsNone(value.deserialize(b''))
self.assertIsNone(value.deserialize(b' '))
def test_serialize_string_escapes(self):
def test_deserialize_decode_failure(self):
value = types.String()
self.assertEqual(r'\r\n\t', value.serialize('\r\n\t'))
incorrectly_encoded_bytes = u'æøå'.encode('iso-8859-1')
self.assertRaises(
ValueError, value.deserialize, incorrectly_encoded_bytes)
def test_serialize_encodes_utf8(self):
value = types.String()
result = value.serialize('æøå')
self.assertIsInstance(result, bytes)
self.assertEqual('æøå'.encode('utf-8'), result)
def test_serialize_does_not_encode_bytes(self):
value = types.String()
result = value.serialize('æøå'.encode('utf-8'))
self.assertIsInstance(result, bytes)
self.assertEqual('æøå'.encode('utf-8'), result)
def test_serialize_handles_escapes(self):
value = types.String()
result = value.serialize('a\n\tb')
self.assertIsInstance(result, bytes)
self.assertEqual(r'a\n\tb'.encode('utf-8'), result)
def test_format_masks_secrets(self):
value = types.String(secret=True)
@ -139,28 +179,56 @@ class BooleanTest(unittest.TestCase):
class ListTest(unittest.TestCase):
# TODO: add test_deserialize_ignores_blank
# TODO: add test_serialize_ignores_blank
# TODO: add test_deserialize_handles_escapes
def test_deserialize_conversion_success(self):
value = types.List()
expected = ('foo', 'bar', 'baz')
self.assertEqual(expected, value.deserialize('foo, bar ,baz '))
self.assertEqual(expected, value.deserialize(b'foo, bar ,baz '))
expected = ('foo,bar', 'bar', 'baz')
self.assertEqual(expected, value.deserialize(' foo,bar\nbar\nbaz'))
self.assertEqual(expected, value.deserialize(b' foo,bar\nbar\nbaz'))
def test_deserialize_creates_tuples(self):
value = types.List(optional=True)
self.assertIsInstance(value.deserialize(b'foo,bar,baz'), tuple)
self.assertIsInstance(value.deserialize(b''), tuple)
def test_deserialize_decodes_utf8(self):
value = types.List()
result = value.deserialize('æ, ø, å'.encode('utf-8'))
self.assertEqual(('æ', 'ø', 'å'), result)
result = value.deserialize('æ\nø\nå'.encode('utf-8'))
self.assertEqual(('æ', 'ø', 'å'), result)
def test_deserialize_does_not_double_encode_unicode(self):
value = types.List()
result = value.deserialize('æ, ø, å')
self.assertEqual(('æ', 'ø', 'å'), result)
result = value.deserialize('æ\nø\nå')
self.assertEqual(('æ', 'ø', 'å'), result)
def test_deserialize_enforces_required(self):
value = types.List()
self.assertRaises(ValueError, value.deserialize, '')
self.assertRaises(ValueError, value.deserialize, ' ')
self.assertRaises(ValueError, value.deserialize, b'')
self.assertRaises(ValueError, value.deserialize, b' ')
def test_deserialize_respects_optional(self):
value = types.List(optional=True)
self.assertEqual(tuple(), value.deserialize(''))
self.assertEqual(tuple(), value.deserialize(' '))
self.assertEqual(tuple(), value.deserialize(b''))
self.assertEqual(tuple(), value.deserialize(b' '))
def test_serialize(self):
value = types.List()
result = value.serialize(('foo', 'bar', 'baz'))
self.assertIsInstance(result, bytes)
self.assertRegexpMatches(result, r'foo\n\s*bar\n\s*baz')