diff --git a/mopidy/utils/config.py b/mopidy/utils/config.py index 0d6eb928..efc07d10 100644 --- a/mopidy/utils/config.py +++ b/mopidy/utils/config.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import logging import re +import socket def validate_choice(value, choices): @@ -139,3 +140,19 @@ class LogLevel(ConfigValue): def serialize(self, value): return dict((v, k) for k, v in self.levels.items()).get(value) + + +class Hostname(ConfigValue): + def deserialize(self, value): + try: + socket.getaddrinfo(value, None) + except socket.error: + raise ValueError('must be a resolveable hostname or valid IP.') + return value + + +class Port(Integer): + def __init__(self, **kwargs): + super(Port, self).__init__(**kwargs) + self.minimum = 1 + self.maximum = 2**16 - 1 diff --git a/tests/utils/config_test.py b/tests/utils/config_test.py index 289d9df8..b0ccfe78 100644 --- a/tests/utils/config_test.py +++ b/tests/utils/config_test.py @@ -1,6 +1,8 @@ from __future__ import unicode_literals import logging +import mock +import socket from mopidy.utils import config @@ -209,3 +211,33 @@ class BooleanTest(unittest.TestCase): def test_serialize_unknown_level(self): value = config.LogLevel() self.assertIsNone(value.serialize(1337)) + + +class HostnameTest(unittest.TestCase): + @mock.patch('socket.getaddrinfo') + def test_deserialize_checks_addrinfo(self, getaddrinfo_mock): + value = config.Hostname() + value.deserialize('example.com') + getaddrinfo_mock.assert_called_once_with('example.com', None) + + @mock.patch('socket.getaddrinfo') + def test_deserialize_handles_failures(self, getaddrinfo_mock): + value = config.Hostname() + getaddrinfo_mock.side_effect = socket.error + self.assertRaises(ValueError, value.deserialize, 'example.com') + + +class PortTest(unittest.TestCase): + def test_valid_ports(self): + value = config.Port() + self.assertEqual(1, value.deserialize('1')) + self.assertEqual(80, value.deserialize('80')) + self.assertEqual(6600, value.deserialize('6600')) + self.assertEqual(65535, value.deserialize('65535')) + + def test_invalid_ports(self): + value = config.Port() + self.assertRaises(ValueError, value.deserialize, '65536') + self.assertRaises(ValueError, value.deserialize, '100000') + self.assertRaises(ValueError, value.deserialize, '0') + self.assertRaises(ValueError, value.deserialize, '-1')