From b4c98ec4a54fab4dc58182abf32b020d7295d743 Mon Sep 17 00:00:00 2001 From: Jarryd Tilbrook Date: Tue, 13 Feb 2018 20:58:39 +0800 Subject: [PATCH] Fix/1531 add unix domain socket (#1629) mpd: add functionality for unix domain socket (Fixes #1531) The Hostname config type now supports a Unix socket path prefixed with `unix:` --- docs/changelog.rst | 3 + docs/ext/mpd.rst | 6 +- mopidy/config/types.py | 3 + mopidy/core/actor.py | 1 - mopidy/internal/network.py | 82 +++++++++++--- mopidy/internal/path.py | 8 ++ mopidy/local/json.py | 1 - mopidy/mpd/__init__.py | 2 +- mopidy/mpd/actor.py | 14 ++- mopidy/mpd/session.py | 7 +- tests/config/test_types.py | 6 ++ tests/core/test_actor.py | 2 +- tests/core/test_mixer.py | 1 + tests/internal/network/test_connection.py | 124 +++++++++++----------- tests/internal/network/test_server.py | 109 ++++++++++++++++--- tests/internal/network/test_utils.py | 25 ++++- tests/internal/test_path.py | 12 +++ 17 files changed, 299 insertions(+), 107 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 068c79a9..f4439a7c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,9 @@ Bug fix release. - MPD: Added ``idle`` to the list of available commands. (Fixes: :issue:`1593`, PR: :issue:`1597`) +- MPD: Added Unix domain sockets for binding MPD to. + (Fixes: :issue:`1531`, PR: :issue:`1629`) + v2.1.0 (2017-01-02) =================== diff --git a/docs/ext/mpd.rst b/docs/ext/mpd.rst index 7f02facc..db56187f 100644 --- a/docs/ext/mpd.rst +++ b/docs/ext/mpd.rst @@ -63,7 +63,8 @@ See :ref:`config` for general help on configuring Mopidy. .. confval:: mpd/hostname - Which address the MPD server should bind to. + Which address the MPD server should bind to. This can be a network address + or the path to a Unix socket. ``127.0.0.1`` Listens only on the IPv4 loopback interface @@ -73,6 +74,9 @@ See :ref:`config` for general help on configuring Mopidy. Listens on all IPv4 interfaces ``::`` Listens on all interfaces, both IPv4 and IPv6 + ``unix:/path/to/unix/socket.sock`` + Listen on the Unix socket at the specified path. Must be prefixed with + ``unix:`` .. confval:: mpd/port diff --git a/mopidy/config/types.py b/mopidy/config/types.py index 9d673c43..f31ba7d8 100644 --- a/mopidy/config/types.py +++ b/mopidy/config/types.py @@ -258,6 +258,9 @@ class Hostname(ConfigValue): validators.validate_required(value, self._required) if not value.strip(): return None + socket_path = path.get_unix_socket_path(value) + if socket_path is not None: + return 'unix:' + Path(not self._required).deserialize(socket_path) try: socket.getaddrinfo(value, None) except socket.error: diff --git a/mopidy/core/actor.py b/mopidy/core/actor.py index 03efd6a8..ad9acba9 100644 --- a/mopidy/core/actor.py +++ b/mopidy/core/actor.py @@ -8,7 +8,6 @@ import os import pykka import mopidy - from mopidy import audio, backend, mixer from mopidy.audio import PlaybackState from mopidy.core.history import HistoryController diff --git a/mopidy/internal/network.py b/mopidy/internal/network.py index cefdf8ea..0c5a9109 100644 --- a/mopidy/internal/network.py +++ b/mopidy/internal/network.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, unicode_literals import errno import logging +import os import re import socket import sys @@ -9,13 +10,20 @@ import threading import pykka -from mopidy.internal import encoding +from mopidy.internal import encoding, path, validation from mopidy.internal.gi import GObject logger = logging.getLogger(__name__) +def is_unix_socket(sock): + """Check if the provided socket is a Unix domain socket""" + if hasattr(socket, 'AF_UNIX'): + return sock.family == socket.AF_UNIX + return False + + class ShouldRetrySocketCall(Exception): """Indicate that attempted socket call should be retried""" @@ -40,7 +48,7 @@ def try_ipv6_socket(): has_ipv6 = try_ipv6_socket() -def create_socket(): +def create_tcp_socket(): """Create a TCP socket with or without IPv6 depending on system support""" if has_ipv6: sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) @@ -57,6 +65,19 @@ def create_socket(): return sock +def create_unix_socket(): + """Create a Unix domain socket""" + return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + +def format_socket_name(sock): + """Format the connection string for the given socket""" + if is_unix_socket(sock): + return '%s' % sock.getsockname() + else: + return '[%s]:%s' % sock.getsockname()[:2] + + def format_hostname(hostname): """Format hostname for display.""" if (has_ipv6 and re.match(r'\d+.\d+.\d+.\d+', hostname) is not None): @@ -76,17 +97,42 @@ class Server(object): self.timeout = timeout self.server_socket = self.create_server_socket(host, port) - self.register_server_socket(self.server_socket.fileno()) + self.watcher = self.register_server_socket(self.server_socket.fileno()) def create_server_socket(self, host, port): - sock = create_socket() + socket_path = path.get_unix_socket_path(host) + if socket_path is not None: # host is a path so use unix socket + sock = create_unix_socket() + sock.bind(socket_path) + else: + # ensure the port is supplied + validation.check_integer(port) + sock = create_tcp_socket() + sock.bind((host, port)) + sock.setblocking(False) - sock.bind((host, port)) sock.listen(1) return sock + def stop(self): + GObject.source_remove(self.watcher) + if is_unix_socket(self.server_socket): + unix_socket_path = self.server_socket.getsockname() + else: + unix_socket_path = None + + self.server_socket.shutdown(socket.SHUT_RDWR) + self.server_socket.close() + + # clean up the socket file + if unix_socket_path is not None: + os.unlink(unix_socket_path) + def register_server_socket(self, fileno): - GObject.io_add_watch(fileno, GObject.IO_IN, self.handle_connection) + return GObject.io_add_watch( + fileno, + GObject.IO_IN, + self.handle_connection) def handle_connection(self, fd, flags): try: @@ -102,7 +148,10 @@ class Server(object): def accept_connection(self): try: - return self.server_socket.accept() + sock, addr = self.server_socket.accept() + if is_unix_socket(sock): + addr = (sock.getsockname(), None) + return sock, addr except socket.error as e: if e.errno in (errno.EAGAIN, errno.EINTR): raise ShouldRetrySocketCall @@ -117,7 +166,9 @@ class Server(object): def reject_connection(self, sock, addr): # FIXME provide more context in logging? - logger.warning('Rejected connection from [%s]:%s', addr[0], addr[1]) + logger.warning( + 'Rejected connection from %s', + format_socket_name(sock)) try: sock.close() except socket.error: @@ -142,7 +193,7 @@ class Connection(object): self.host, self.port = addr[:2] # IPv6 has larger addr - self.sock = sock + self._sock = sock self.protocol = protocol self.protocol_kwargs = protocol_kwargs self.timeout = timeout @@ -180,7 +231,7 @@ class Connection(object): self.disable_send() try: - self.sock.close() + self._sock.close() except socket.error: pass @@ -195,7 +246,7 @@ class Connection(object): def send(self, data): """Send data to client, return any unsent data.""" try: - sent = self.sock.send(data) + sent = self._sock.send(data) return data[sent:] except socket.error as e: if e.errno in (errno.EWOULDBLOCK, errno.EINTR): @@ -226,7 +277,7 @@ class Connection(object): try: self.recv_id = GObject.io_add_watch( - self.sock.fileno(), + self._sock.fileno(), GObject.IO_IN | GObject.IO_ERR | GObject.IO_HUP, self.recv_callback) except socket.error as e: @@ -244,7 +295,7 @@ class Connection(object): try: self.send_id = GObject.io_add_watch( - self.sock.fileno(), + self._sock.fileno(), GObject.IO_OUT | GObject.IO_ERR | GObject.IO_HUP, self.send_callback) except socket.error as e: @@ -263,7 +314,7 @@ class Connection(object): return True try: - data = self.sock.recv(4096) + data = self._sock.recv(4096) except socket.error as e: if e.errno not in (errno.EWOULDBLOCK, errno.EINTR): self.stop('Unexpected client error: %s' % e) @@ -304,6 +355,9 @@ class Connection(object): self.stop('Client inactive for %ds; closing connection' % self.timeout) return False + def __str__(self): + return format_socket_name(self._sock) + class LineProtocol(pykka.ThreadingActor): diff --git a/mopidy/internal/path.py b/mopidy/internal/path.py index 307d733c..c06c6d66 100644 --- a/mopidy/internal/path.py +++ b/mopidy/internal/path.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, unicode_literals import logging import os +import re import stat import string import threading @@ -47,6 +48,13 @@ def get_or_create_file(file_path, mkdir=True, content=None): return file_path +def get_unix_socket_path(socket_path): + match = re.search('^unix:(.*)', socket_path) + if not match: + return None + return match.group(1) + + def path_to_uri(path): """ Convert OS specific path to file:// URI. diff --git a/mopidy/local/json.py b/mopidy/local/json.py index 2e39b68b..7e26b6db 100644 --- a/mopidy/local/json.py +++ b/mopidy/local/json.py @@ -7,7 +7,6 @@ import re import sys import mopidy - from mopidy import compat, local, models from mopidy.internal import storage as internal_storage from mopidy.internal import timer diff --git a/mopidy/mpd/__init__.py b/mopidy/mpd/__init__.py index 84cf47cb..4cc5a988 100644 --- a/mopidy/mpd/__init__.py +++ b/mopidy/mpd/__init__.py @@ -19,7 +19,7 @@ class Extension(ext.Extension): def get_config_schema(self): schema = super(Extension, self).get_config_schema() schema['hostname'] = config.Hostname() - schema['port'] = config.Port() + schema['port'] = config.Port(optional=True) schema['password'] = config.Secret(optional=True) schema['max_connections'] = config.Integer(minimum=1) schema['connection_timeout'] = config.Integer(minimum=1) diff --git a/mopidy/mpd/actor.py b/mopidy/mpd/actor.py index 067d20c5..00041bf3 100644 --- a/mopidy/mpd/actor.py +++ b/mopidy/mpd/actor.py @@ -41,11 +41,11 @@ class MpdFrontend(pykka.ThreadingActor, CoreListener): self.zeroconf_name = config['mpd']['zeroconf'] self.zeroconf_service = None - self._setup_server(config, core) + self.server = self._setup_server(config, core) def _setup_server(self, config, core): try: - network.Server( + server = network.Server( self.hostname, self.port, protocol=session.MpdSession, protocol_kwargs={ @@ -60,10 +60,15 @@ class MpdFrontend(pykka.ThreadingActor, CoreListener): 'MPD server startup failed: %s' % encoding.locale_decode(error)) - logger.info('MPD server running at [%s]:%s', self.hostname, self.port) + logger.info( + 'MPD server running at %s', + network.format_socket_name(server.server_socket)) + + return server def on_start(self): - if self.zeroconf_name: + if (self.zeroconf_name and not + network.is_unix_socket(self.server.server_socket)): self.zeroconf_service = zeroconf.Zeroconf( name=self.zeroconf_name, stype='_mpd._tcp', @@ -75,6 +80,7 @@ class MpdFrontend(pykka.ThreadingActor, CoreListener): self.zeroconf_service.unpublish() process.stop_actors_by_class(session.MpdSession) + self.server.stop() def on_event(self, event, **kwargs): if event not in _CORE_EVENTS_TO_IDLE_SUBSYSTEMS: diff --git a/mopidy/mpd/session.py b/mopidy/mpd/session.py index d484d986..8d4b39d1 100644 --- a/mopidy/mpd/session.py +++ b/mopidy/mpd/session.py @@ -25,18 +25,19 @@ class MpdSession(network.LineProtocol): session=self, config=config, core=core, uri_map=uri_map) def on_start(self): - logger.info('New MPD connection from [%s]:%s', self.host, self.port) + logger.info('New MPD connection from %s', self.connection) self.send_lines(['OK MPD %s' % protocol.VERSION]) def on_line_received(self, line): - logger.debug('Request from [%s]:%s: %s', self.host, self.port, line) + logger.debug('Request from [%s]: %s', self.connection, line) response = self.dispatcher.handle_request(line) if not response: return logger.debug( - 'Response to [%s]:%s: %s', self.host, self.port, + 'Response to [%s]: %s', + self.connection, formatting.indent(self.terminator.join(response))) self.send_lines(response) diff --git a/tests/config/test_types.py b/tests/config/test_types.py index 40226c51..5a7f5b5c 100644 --- a/tests/config/test_types.py +++ b/tests/config/test_types.py @@ -344,6 +344,12 @@ class HostnameTest(unittest.TestCase): self.assertIsNone(value.deserialize(' ')) self.assertEqual(0, getaddrinfo_mock.call_count) + @mock.patch('mopidy.internal.path.expand_path') + def test_deserialize_with_unix_socket(self, expand_path_mock): + value = types.Hostname() + self.assertIsNotNone(value.deserialize('unix:/tmp/mopidy.socket')) + expand_path_mock.assert_called_once_with('/tmp/mopidy.socket') + class PortTest(unittest.TestCase): diff --git a/tests/core/test_actor.py b/tests/core/test_actor.py index c5da74d1..66e7ee9c 100644 --- a/tests/core/test_actor.py +++ b/tests/core/test_actor.py @@ -10,10 +10,10 @@ import mock import pykka import mopidy - from mopidy.core import Core from mopidy.internal import models, storage, versioning from mopidy.models import Track + from tests import dummy_mixer diff --git a/tests/core/test_mixer.py b/tests/core/test_mixer.py index 996b7c23..38ff091d 100644 --- a/tests/core/test_mixer.py +++ b/tests/core/test_mixer.py @@ -8,6 +8,7 @@ import pykka from mopidy import core, mixer from mopidy.internal.models import MixerState + from tests import dummy_mixer diff --git a/tests/internal/network/test_connection.py b/tests/internal/network/test_connection.py index 9ee0aaf3..899cfffd 100644 --- a/tests/internal/network/test_connection.py +++ b/tests/internal/network/test_connection.py @@ -51,7 +51,7 @@ class ConnectionTest(unittest.TestCase): network.Connection.__init__( self.mock, protocol, protocol_kwargs, sock, addr, sentinel.timeout) - self.assertEqual(sock, self.mock.sock) + self.assertEqual(sock, self.mock._sock) self.assertEqual(protocol, self.mock.protocol) self.assertEqual(protocol_kwargs, self.mock.protocol_kwargs) self.assertEqual(sentinel.timeout, self.mock.timeout) @@ -73,7 +73,7 @@ class ConnectionTest(unittest.TestCase): def test_stop_disables_recv_send_and_timeout(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) self.mock.disable_timeout.assert_called_once_with() @@ -83,24 +83,24 @@ class ConnectionTest(unittest.TestCase): def test_stop_closes_socket(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) - self.mock.sock.close.assert_called_once_with() + self.mock._sock.close.assert_called_once_with() def test_stop_closes_socket_error(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.close.side_effect = socket.error + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.close.side_effect = socket.error network.Connection.stop(self.mock, sentinel.reason) - self.mock.sock.close.assert_called_once_with() + self.mock._sock.close.assert_called_once_with() def test_stop_stops_actor(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) self.mock.actor_ref.stop.assert_called_once_with(block=False) @@ -109,7 +109,7 @@ class ConnectionTest(unittest.TestCase): self.mock.stopping = False self.mock.actor_ref = Mock() self.mock.actor_ref.stop.side_effect = pykka.ActorDeadError() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) self.mock.actor_ref.stop.assert_called_once_with(block=False) @@ -117,7 +117,7 @@ class ConnectionTest(unittest.TestCase): def test_stop_sets_stopping_to_true(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) self.assertEqual(True, self.mock.stopping) @@ -125,17 +125,17 @@ class ConnectionTest(unittest.TestCase): def test_stop_does_not_proceed_when_already_stopping(self): self.mock.stopping = True self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) self.assertEqual(0, self.mock.actor_ref.stop.call_count) - self.assertEqual(0, self.mock.sock.close.call_count) + self.assertEqual(0, self.mock._sock.close.call_count) @patch.object(network.logger, 'log', new=Mock()) def test_stop_logs_reason(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) network.logger.log.assert_called_once_with( @@ -145,7 +145,7 @@ class ConnectionTest(unittest.TestCase): def test_stop_logs_reason_with_level(self): self.mock.stopping = False self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop( self.mock, sentinel.reason, level=sentinel.level) @@ -156,7 +156,7 @@ class ConnectionTest(unittest.TestCase): def test_stop_logs_that_it_is_calling_itself(self): self.mock.stopping = True self.mock.actor_ref = Mock() - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.stop(self.mock, sentinel.reason) network.logger.log(any_int, any_unicode) @@ -164,8 +164,8 @@ class ConnectionTest(unittest.TestCase): @patch.object(GObject, 'io_add_watch', new=Mock()) def test_enable_recv_registers_with_gobject(self): self.mock.recv_id = None - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.fileno.return_value = sentinel.fileno + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.fileno.return_value = sentinel.fileno GObject.io_add_watch.return_value = sentinel.tag network.Connection.enable_recv(self.mock) @@ -177,7 +177,7 @@ class ConnectionTest(unittest.TestCase): @patch.object(GObject, 'io_add_watch', new=Mock()) def test_enable_recv_already_registered(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) self.mock.recv_id = sentinel.tag network.Connection.enable_recv(self.mock) @@ -185,7 +185,7 @@ class ConnectionTest(unittest.TestCase): def test_enable_recv_does_not_change_tag(self): self.mock.recv_id = sentinel.tag - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.enable_recv(self.mock) self.assertEqual(sentinel.tag, self.mock.recv_id) @@ -208,8 +208,8 @@ class ConnectionTest(unittest.TestCase): def test_enable_recv_on_closed_socket(self): self.mock.recv_id = None - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.fileno.side_effect = socket.error(errno.EBADF, '') + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.fileno.side_effect = socket.error(errno.EBADF, '') network.Connection.enable_recv(self.mock) self.mock.stop.assert_called_once_with(any_unicode) @@ -218,8 +218,8 @@ class ConnectionTest(unittest.TestCase): @patch.object(GObject, 'io_add_watch', new=Mock()) def test_enable_send_registers_with_gobject(self): self.mock.send_id = None - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.fileno.return_value = sentinel.fileno + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.fileno.return_value = sentinel.fileno GObject.io_add_watch.return_value = sentinel.tag network.Connection.enable_send(self.mock) @@ -231,7 +231,7 @@ class ConnectionTest(unittest.TestCase): @patch.object(GObject, 'io_add_watch', new=Mock()) def test_enable_send_already_registered(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) self.mock.send_id = sentinel.tag network.Connection.enable_send(self.mock) @@ -239,7 +239,7 @@ class ConnectionTest(unittest.TestCase): def test_enable_send_does_not_change_tag(self): self.mock.send_id = sentinel.tag - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) network.Connection.enable_send(self.mock) self.assertEqual(sentinel.tag, self.mock.send_id) @@ -262,8 +262,8 @@ class ConnectionTest(unittest.TestCase): def test_enable_send_on_closed_socket(self): self.mock.send_id = None - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.fileno.side_effect = socket.error(errno.EBADF, '') + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.fileno.side_effect = socket.error(errno.EBADF, '') network.Connection.enable_send(self.mock) self.assertEqual(None, self.mock.send_id) @@ -367,7 +367,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual('', self.mock.send_buffer) def test_recv_callback_respects_io_err(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) self.mock.actor_ref = Mock() self.assertTrue(network.Connection.recv_callback( @@ -375,7 +375,7 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_recv_callback_respects_io_hup(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) self.mock.actor_ref = Mock() self.assertTrue(network.Connection.recv_callback( @@ -383,7 +383,7 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_recv_callback_respects_io_hup_and_io_err(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) self.mock.actor_ref = Mock() self.assertTrue(network.Connection.recv_callback( @@ -392,8 +392,8 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_recv_callback_sends_data_to_actor(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.recv.return_value = 'data' + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.recv.return_value = 'data' self.mock.actor_ref = Mock() self.assertTrue(network.Connection.recv_callback( @@ -402,8 +402,8 @@ class ConnectionTest(unittest.TestCase): {'received': 'data'}) def test_recv_callback_handles_dead_actors(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.recv.return_value = 'data' + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.recv.return_value = 'data' self.mock.actor_ref = Mock() self.mock.actor_ref.tell.side_effect = pykka.ActorDeadError() @@ -412,38 +412,38 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_recv_callback_gets_no_data(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.recv.return_value = '' + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.recv.return_value = '' self.mock.actor_ref = Mock() self.assertTrue(network.Connection.recv_callback( self.mock, sentinel.fd, GObject.IO_IN)) self.assertEqual(self.mock.mock_calls, [ - call.sock.recv(any_int), + call._sock.recv(any_int), call.disable_recv(), call.actor_ref.tell({'close': True}), ]) def test_recv_callback_recoverable_error(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) for error in (errno.EWOULDBLOCK, errno.EINTR): - self.mock.sock.recv.side_effect = socket.error(error, '') + self.mock._sock.recv.side_effect = socket.error(error, '') self.assertTrue(network.Connection.recv_callback( self.mock, sentinel.fd, GObject.IO_IN)) self.assertEqual(0, self.mock.stop.call_count) def test_recv_callback_unrecoverable_error(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.recv.side_effect = socket.error + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.recv.side_effect = socket.error self.assertTrue(network.Connection.recv_callback( self.mock, sentinel.fd, GObject.IO_IN)) self.mock.stop.assert_called_once_with(any_unicode) def test_send_callback_respects_io_err(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 1 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 1 self.mock.send_lock = Mock() self.mock.actor_ref = Mock() self.mock.send_buffer = '' @@ -453,8 +453,8 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_send_callback_respects_io_hup(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 1 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 1 self.mock.send_lock = Mock() self.mock.actor_ref = Mock() self.mock.send_buffer = '' @@ -464,8 +464,8 @@ class ConnectionTest(unittest.TestCase): self.mock.stop.assert_called_once_with(any_unicode) def test_send_callback_respects_io_hup_and_io_err(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 1 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 1 self.mock.send_lock = Mock() self.mock.actor_ref = Mock() self.mock.send_buffer = '' @@ -479,8 +479,8 @@ class ConnectionTest(unittest.TestCase): self.mock.send_lock = Mock() self.mock.send_lock.acquire.return_value = True self.mock.send_buffer = '' - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 0 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 0 self.assertTrue(network.Connection.send_callback( self.mock, sentinel.fd, GObject.IO_IN)) @@ -491,13 +491,13 @@ class ConnectionTest(unittest.TestCase): self.mock.send_lock = Mock() self.mock.send_lock.acquire.return_value = False self.mock.send_buffer = '' - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 0 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 0 self.assertTrue(network.Connection.send_callback( self.mock, sentinel.fd, GObject.IO_IN)) self.mock.send_lock.acquire.assert_called_once_with(False) - self.assertEqual(0, self.mock.sock.send.call_count) + self.assertEqual(0, self.mock._sock.send.call_count) def test_send_callback_sends_all_data(self): self.mock.send_lock = Mock() @@ -523,31 +523,31 @@ class ConnectionTest(unittest.TestCase): self.assertEqual('ta', self.mock.send_buffer) def test_send_recoverable_error(self): - self.mock.sock = Mock(spec=socket.SocketType) + self.mock._sock = Mock(spec=socket.SocketType) for error in (errno.EWOULDBLOCK, errno.EINTR): - self.mock.sock.send.side_effect = socket.error(error, '') + self.mock._sock.send.side_effect = socket.error(error, '') network.Connection.send(self.mock, 'data') self.assertEqual(0, self.mock.stop.call_count) def test_send_calls_socket_send(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 4 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 4 self.assertEqual('', network.Connection.send(self.mock, 'data')) - self.mock.sock.send.assert_called_once_with('data') + self.mock._sock.send.assert_called_once_with('data') def test_send_calls_socket_send_partial_send(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.return_value = 2 + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.return_value = 2 self.assertEqual('ta', network.Connection.send(self.mock, 'data')) - self.mock.sock.send.assert_called_once_with('data') + self.mock._sock.send.assert_called_once_with('data') def test_send_unrecoverable_error(self): - self.mock.sock = Mock(spec=socket.SocketType) - self.mock.sock.send.side_effect = socket.error + self.mock._sock = Mock(spec=socket.SocketType) + self.mock._sock.send.side_effect = socket.error self.assertEqual('', network.Connection.send(self.mock, 'data')) self.mock.stop.assert_called_once_with(any_unicode) diff --git a/tests/internal/network/test_server.py b/tests/internal/network/test_server.py index 072e24de..88e1fc29 100644 --- a/tests/internal/network/test_server.py +++ b/tests/internal/network/test_server.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, unicode_literals import errno +import os import socket import unittest from mock import Mock, patch, sentinel +from mopidy import exceptions from mopidy.internal import network from mopidy.internal.gi import GObject @@ -22,6 +24,7 @@ class ServerTest(unittest.TestCase): self.mock, sentinel.host, sentinel.port, sentinel.protocol) self.mock.create_server_socket.assert_called_once_with( sentinel.host, sentinel.port) + self.mock.stop() def test_init_calls_register_server(self): sock = Mock(spec=socket.SocketType) @@ -55,40 +58,99 @@ class ServerTest(unittest.TestCase): self.assertEqual(sentinel.timeout, self.mock.timeout) self.assertEqual(sock, self.mock.server_socket) - @patch.object(network, 'create_socket', spec=socket.SocketType) - def test_create_server_socket_sets_up_listener(self, create_socket): - sock = create_socket.return_value + def test_create_server_socket_no_port(self): + with self.assertRaises(exceptions.ValidationError): + network.Server.create_server_socket( + self.mock, str(sentinel.host), None) + + def test_create_server_socket_invalid_port(self): + with self.assertRaises(exceptions.ValidationError): + network.Server.create_server_socket( + self.mock, str(sentinel.host), str(sentinel.port)) + + @patch.object(network, 'create_tcp_socket', spec=socket.SocketType) + def test_create_server_socket_sets_up_listener(self, create_tcp_socket): + sock = create_tcp_socket.return_value network.Server.create_server_socket( - self.mock, sentinel.host, sentinel.port) + self.mock, str(sentinel.host), 1234) sock.setblocking.assert_called_once_with(False) - sock.bind.assert_called_once_with((sentinel.host, sentinel.port)) + sock.bind.assert_called_once_with((str(sentinel.host), 1234)) sock.listen.assert_called_once_with(any_int) + create_tcp_socket.assert_called_once() - @patch.object(network, 'create_socket', new=Mock()) + @patch.object(network, 'create_unix_socket', spec=socket.SocketType) + def test_create_server_socket_sets_up_listener_unix( + self, + create_unix_socket): + sock = create_unix_socket.return_value + + network.Server.create_server_socket( + self.mock, 'unix:' + str(sentinel.host), sentinel.port) + sock.setblocking.assert_called_once_with(False) + sock.bind.assert_called_once_with(str(sentinel.host)) + sock.listen.assert_called_once_with(any_int) + create_unix_socket.assert_called_once() + + @patch.object(network, 'create_tcp_socket', new=Mock()) def test_create_server_socket_fails(self): - network.create_socket.side_effect = socket.error + network.create_tcp_socket.side_effect = socket.error with self.assertRaises(socket.error): network.Server.create_server_socket( - self.mock, sentinel.host, sentinel.port) + self.mock, str(sentinel.host), 1234) - @patch.object(network, 'create_socket', new=Mock()) + @patch.object(network, 'create_unix_socket', new=Mock()) + def test_create_server_socket_fails_unix(self): + network.create_unix_socket.side_effect = socket.error + with self.assertRaises(socket.error): + network.Server.create_server_socket( + self.mock, 'unix:' + str(sentinel.host), sentinel.port) + + @patch.object(network, 'create_tcp_socket', new=Mock()) def test_create_server_bind_fails(self): - sock = network.create_socket.return_value + sock = network.create_tcp_socket.return_value sock.bind.side_effect = socket.error with self.assertRaises(socket.error): network.Server.create_server_socket( - self.mock, sentinel.host, sentinel.port) + self.mock, str(sentinel.host), 1234) - @patch.object(network, 'create_socket', new=Mock()) + @patch.object(network, 'create_unix_socket', new=Mock()) + def test_create_server_bind_fails_unix(self): + sock = network.create_unix_socket.return_value + sock.bind.side_effect = socket.error + + with self.assertRaises(socket.error): + network.Server.create_server_socket( + self.mock, 'unix:' + str(sentinel.host), sentinel.port) + + @patch.object(network, 'create_tcp_socket', new=Mock()) def test_create_server_listen_fails(self): - sock = network.create_socket.return_value + sock = network.create_tcp_socket.return_value sock.listen.side_effect = socket.error with self.assertRaises(socket.error): network.Server.create_server_socket( - self.mock, sentinel.host, sentinel.port) + self.mock, str(sentinel.host), 1234) + + @patch.object(network, 'create_unix_socket', new=Mock()) + def test_create_server_listen_fails_unix(self): + sock = network.create_unix_socket.return_value + sock.listen.side_effect = socket.error + + with self.assertRaises(socket.error): + network.Server.create_server_socket( + self.mock, 'unix:' + str(sentinel.host), sentinel.port) + + @patch.object(os, 'unlink', new=Mock()) + @patch.object(GObject, 'source_remove', new=Mock()) + def test_stop_server_cleans_unix_socket(self): + self.mock.watcher = Mock() + sock = Mock() + sock.family = socket.AF_UNIX + self.mock.server_socket = sock + network.Server.stop(self.mock) + os.unlink.assert_called_once_with(sock.getsockname()) @patch.object(GObject, 'io_add_watch', new=Mock()) def test_register_server_socket_sets_up_io_watch(self): @@ -124,13 +186,26 @@ class ServerTest(unittest.TestCase): def test_accept_connection(self): sock = Mock(spec=socket.SocketType) - sock.accept.return_value = (sentinel.sock, sentinel.addr) + connected_sock = Mock(spec=socket.SocketType) + sock.accept.return_value = (connected_sock, sentinel.addr) self.mock.server_socket = sock sock, addr = network.Server.accept_connection(self.mock) - self.assertEqual(sentinel.sock, sock) + self.assertEqual(connected_sock, sock) self.assertEqual(sentinel.addr, addr) + def test_accept_connection_unix(self): + sock = Mock(spec=socket.SocketType) + connected_sock = Mock(spec=socket.SocketType) + connected_sock.family = socket.AF_UNIX + connected_sock.getsockname.return_value = sentinel.sockname + sock.accept.return_value = (connected_sock, sentinel.addr) + self.mock.server_socket = sock + + sock, addr = network.Server.accept_connection(self.mock) + self.assertEqual(connected_sock, sock) + self.assertEqual((sentinel.sockname, None), addr) + def test_accept_connection_recoverable_error(self): sock = Mock(spec=socket.SocketType) self.mock.server_socket = sock @@ -182,6 +257,7 @@ class ServerTest(unittest.TestCase): sentinel.protocol, {}, sentinel.sock, sentinel.addr, sentinel.timeout) + @patch.object(network, 'format_socket_name', new=Mock()) def test_reject_connection(self): sock = Mock(spec=socket.SocketType) @@ -189,6 +265,7 @@ class ServerTest(unittest.TestCase): self.mock, sock, (sentinel.host, sentinel.port)) sock.close.assert_called_once_with() + @patch.object(network, 'format_socket_name', new=Mock()) def test_reject_connection_error(self): sock = Mock(spec=socket.SocketType) sock.close.side_effect = socket.error diff --git a/tests/internal/network/test_utils.py b/tests/internal/network/test_utils.py index a769ff93..e1c2e4be 100644 --- a/tests/internal/network/test_utils.py +++ b/tests/internal/network/test_utils.py @@ -3,7 +3,7 @@ from __future__ import absolute_import, unicode_literals import socket import unittest -from mock import Mock, patch +from mock import Mock, patch, sentinel from mopidy.internal import network @@ -22,6 +22,25 @@ class FormatHostnameTest(unittest.TestCase): self.assertEqual(network.format_hostname('0.0.0.0'), '0.0.0.0') +class FormatSocketConnectionTest(unittest.TestCase): + + def test_format_socket_name(self): + sock = Mock(spec=socket.SocketType) + sock.family = socket.AF_INET + sock.getsockname.return_value = (sentinel.ip, sentinel.port) + self.assertEqual( + network.format_socket_name(sock), + '[%s]:%s' % (sentinel.ip, sentinel.port)) + + def test_format_socket_name_unix(self): + sock = Mock(spec=socket.SocketType) + sock.family = socket.AF_UNIX + sock.getsockname.return_value = sentinel.sockname + self.assertEqual( + network.format_socket_name(sock), + str(sentinel.sockname)) + + class TryIPv6SocketTest(unittest.TestCase): @patch('socket.has_ipv6', False) @@ -46,14 +65,14 @@ class CreateSocketTest(unittest.TestCase): @patch('mopidy.internal.network.has_ipv6', False) @patch('socket.socket') def test_ipv4_socket(self, socket_mock): - network.create_socket() + network.create_tcp_socket() self.assertEqual( socket_mock.call_args[0], (socket.AF_INET, socket.SOCK_STREAM)) @patch('mopidy.internal.network.has_ipv6', True) @patch('socket.socket') def test_ipv6_socket(self, socket_mock): - network.create_socket() + network.create_tcp_socket() self.assertEqual( socket_mock.call_args[0], (socket.AF_INET6, socket.SOCK_STREAM)) diff --git a/tests/internal/test_path.py b/tests/internal/test_path.py index 6eebaaa3..dec80f1a 100644 --- a/tests/internal/test_path.py +++ b/tests/internal/test_path.py @@ -137,6 +137,18 @@ class GetOrCreateFileTest(unittest.TestCase): self.assertEqual(fh.read(), b'foobar\xc3\xa6\xc3\xb8\xc3\xa5') +class GetUnixSocketPathTest(unittest.TestCase): + + def test_correctly_matched_socket_path(self): + self.assertEqual( + path.get_unix_socket_path('unix:/tmp/mopidy.socket'), + '/tmp/mopidy.socket' + ) + + def test_correctly_no_match_socket_path(self): + self.assertIsNone(path.get_unix_socket_path('127.0.0.1')) + + class PathToFileURITest(unittest.TestCase): def test_simple_path(self):