Refactor MpdDispatcher to use a filter model, like Java Servlets. Password authentication handling becomes much cleaner.

This commit is contained in:
Stein Magnus Jodal 2011-06-04 02:21:14 +02:00
parent 3fe276f32a
commit 08f085fd8d
5 changed files with 88 additions and 67 deletions

View File

@ -7,7 +7,7 @@ from pykka.registry import ActorRegistry
from mopidy import settings from mopidy import settings
from mopidy.backends.base import Backend from mopidy.backends.base import Backend
from mopidy.frontends.mpd.exceptions import (MpdAckError, MpdArgError, from mopidy.frontends.mpd.exceptions import (MpdAckError, MpdArgError,
MpdUnknownCommand, MpdSystemError) MpdPermissionError, MpdPasswordError, MpdSystemError, MpdUnknownCommand)
from mopidy.frontends.mpd.protocol import mpd_commands, request_handlers from mopidy.frontends.mpd.protocol import mpd_commands, request_handlers
# Do not remove the following import. The protocol modules must be imported to # Do not remove the following import. The protocol modules must be imported to
# get them registered as request handlers. # get them registered as request handlers.
@ -32,63 +32,82 @@ class MpdDispatcher(object):
self.authenticated = False self.authenticated = False
self.command_list = False self.command_list = False
self.command_list_ok = False self.command_list_ok = False
self.command_list_index = None
self.context = MpdContext(self, session=session) self.context = MpdContext(self, session=session)
def handle_request(self, request, current_command_list_index=None): def handle_request(self, request, current_command_list_index=None):
"""Dispatch incoming requests to the correct handler.""" """Dispatch incoming requests to the correct handler."""
if not self.authenticated: self.command_list_index = current_command_list_index
(self.authenticated, result) = self._check_password(request) response = []
if result: filter_chain = [
return result self._catch_mpd_ack_errors_filter,
self._authenticate_filter,
self._command_list_filter,
self._add_ok_filter,
self._call_handler_filter,
]
return self._call_next_filter(request, response, filter_chain)
def _catch_mpd_ack_errors_filter(self, request, response, filter_chain):
try:
return self._call_next_filter(request, response, filter_chain)
except MpdAckError as mpd_ack_error:
if self.command_list_index is not None:
mpd_ack_error.index = self.command_list_index
return [mpd_ack_error.get_mpd_ack()]
def _authenticate_filter(self, request, response, filter_chain):
if self.authenticated or settings.MPD_SERVER_PASSWORD is None:
return self._call_next_filter(request, response, filter_chain)
else:
command = request.split(' ')[0]
if command in ('close', 'commands', 'notcommands', 'password', 'ping'):
return self._call_next_filter(request, response, filter_chain)
else:
raise MpdPermissionError(command=command)
def _command_list_filter(self, request, response, filter_chain):
if self._is_receiving_command_list(request): if self._is_receiving_command_list(request):
self.command_list.append(request) self.command_list.append(request)
return None return []
try:
try:
result = self._call_handler(request)
except ActorDeadError as e:
logger.warning(u'Tried to communicate with dead actor.')
raise MpdSystemError(e.message)
except MpdAckError as e:
if current_command_list_index is not None:
e.index = current_command_list_index
return self._format_response(e.get_mpd_ack(), add_ok=False)
if (request in (u'command_list_begin', u'command_list_ok_begin')
or current_command_list_index is not None):
return self._format_response(result, add_ok=False)
return self._format_response(result)
def _check_password(self, request):
"""
Takes any request and tries to authenticate the client using it.
:rtype: a two-tuple containing (is_authenticated, response_message). If
the response_message is :class:`None`, normal processing should
continue, even though the client may not be authenticated.
"""
if settings.MPD_SERVER_PASSWORD is None:
return (True, None)
command = request.split(' ')[0]
if command == 'password':
if request == 'password "%s"' % settings.MPD_SERVER_PASSWORD:
return (True, [u'OK'])
else:
return (False, [u'ACK [3@0] {password} incorrect password'])
if command in ('close', 'commands', 'notcommands', 'ping'):
return (False, None)
else: else:
return (False, response = self._call_next_filter(request, response, filter_chain)
[u'ACK [4@0] {%(c)s} you don\'t have permission for "%(c)s"' % if (self._is_receiving_command_list(request) or
{'c': command}]) self._is_processing_command_list(request)):
if response and response[-1] == u'OK':
response = response[:-1]
return response
def _is_receiving_command_list(self, request): def _is_receiving_command_list(self, request):
return (self.command_list is not False return (self.command_list is not False
and request != u'command_list_end') and request != u'command_list_end')
def _is_processing_command_list(self, request):
return (self.command_list_index is not None
and request != u'command_list_end')
def _add_ok_filter(self, request, response, filter_chain):
response = self._call_next_filter(request, response, filter_chain)
if not self._has_error(response):
response.append(u'OK')
return response
def _has_error(self, response):
return response and response[-1].startswith(u'ACK')
def _call_handler_filter(self, request, response, filter_chain):
try:
response = self._format_response(self._call_handler(request))
return self._call_next_filter(request, response, filter_chain)
except ActorDeadError as e:
logger.warning(u'Tried to communicate with dead actor.')
raise MpdSystemError(e.message)
def _call_handler(self, request): def _call_handler(self, request):
(handler, kwargs) = self._find_handler(request) (handler, kwargs) = self._find_handler(request)
return handler(self.context, **kwargs) return handler(self.context, **kwargs)
@ -103,13 +122,19 @@ class MpdDispatcher(object):
raise MpdArgError(u'incorrect arguments', command=command) raise MpdArgError(u'incorrect arguments', command=command)
raise MpdUnknownCommand(command=command) raise MpdUnknownCommand(command=command)
def _format_response(self, result, add_ok=True):
response = [] def _call_next_filter(self, request, response, filter_chain):
for element in self._listify_result(result): if filter_chain:
response.extend(self._format_lines(element)) next_filter = filter_chain.pop(0)
if add_ok and (not response or not self._has_error(response)): return next_filter(request, response, filter_chain)
response.append(u'OK') else:
return response return response
def _format_response(self, response):
formatted_response = []
for element in self._listify_result(response):
formatted_response.extend(self._format_lines(element))
return formatted_response
def _listify_result(self, result): def _listify_result(self, result):
if result is None: if result is None:
@ -128,9 +153,6 @@ class MpdDispatcher(object):
return [u'%s: %s' % (key, value)] return [u'%s: %s' % (key, value)]
return [line] return [line]
def _has_error(self, response):
return bool(response) and response[-1].startswith(u'ACK')
class MpdContext(object): class MpdContext(object):
""" """

View File

@ -31,17 +31,16 @@ def command_list_end(context):
context.dispatcher.command_list, False) context.dispatcher.command_list, False)
(command_list_ok, context.dispatcher.command_list_ok) = ( (command_list_ok, context.dispatcher.command_list_ok) = (
context.dispatcher.command_list_ok, False) context.dispatcher.command_list_ok, False)
result = [] command_list_response = []
for index, command in enumerate(command_list): for index, command in enumerate(command_list):
response = context.dispatcher.handle_request( response = context.dispatcher.handle_request(
command, current_command_list_index=index) command, current_command_list_index=index)
if response is not None: command_list_response.extend(response)
result.append(response) if command_list_response and command_list_response[-1].startswith(u'ACK'):
if response and response[-1].startswith(u'ACK'): return command_list_response
return result
if command_list_ok: if command_list_ok:
response.append(u'list_OK') command_list_response.append(u'list_OK')
return result return command_list_response
@handle_pattern(r'^command_list_ok_begin$') @handle_pattern(r'^command_list_ok_begin$')
def command_list_ok_begin(context): def command_list_ok_begin(context):

View File

@ -35,10 +35,9 @@ def password_(context, password):
This is used for authentication with the server. ``PASSWORD`` is This is used for authentication with the server. ``PASSWORD`` is
simply the plaintext password. simply the plaintext password.
""" """
# You will not get to this code without being authenticated. This is for if password == settings.MPD_SERVER_PASSWORD:
# when you are already authenticated, and are sending additional 'password' context.dispatcher.authenticated = True
# requests. else:
if settings.MPD_SERVER_PASSWORD != password:
raise MpdPasswordError(u'incorrect password', command=u'password') raise MpdPasswordError(u'incorrect password', command=u'password')
@handle_pattern(r'^ping$') @handle_pattern(r'^ping$')

View File

@ -28,7 +28,7 @@ class AuthenticationTest(unittest.TestCase):
def test_authentication_with_anything_when_password_check_turned_off(self): def test_authentication_with_anything_when_password_check_turned_off(self):
settings.MPD_SERVER_PASSWORD = None settings.MPD_SERVER_PASSWORD = None
response = self.dispatcher.handle_request(u'any request at all') response = self.dispatcher.handle_request(u'any request at all')
self.assertTrue(self.dispatcher.authenticated) self.assertFalse(self.dispatcher.authenticated)
self.assert_('ACK [5@0] {} unknown command "any"' in response) self.assert_('ACK [5@0] {} unknown command "any"' in response)
def test_anything_when_not_authenticated_should_fail(self): def test_anything_when_not_authenticated_should_fail(self):

View File

@ -43,6 +43,7 @@ class CommandListsTest(unittest.TestCase):
self.dispatcher.handle_request(u'play') # Known command self.dispatcher.handle_request(u'play') # Known command
self.dispatcher.handle_request(u'paly') # Unknown command self.dispatcher.handle_request(u'paly') # Unknown command
result = self.dispatcher.handle_request(u'command_list_end') result = self.dispatcher.handle_request(u'command_list_end')
self.assertEqual(len(result), 1, result)
self.assertEqual(result[0], u'ACK [5@1] {} unknown command "paly"') self.assertEqual(result[0], u'ACK [5@1] {} unknown command "paly"')
def test_command_list_ok_begin(self): def test_command_list_ok_begin(self):