Refactor MpdDispatcher to use a filter model, like Java Servlets. Password authentication handling becomes much cleaner.

This commit is contained in:
Stein Magnus Jodal 2011-06-04 02:21:14 +02:00
parent 3fe276f32a
commit 08f085fd8d
5 changed files with 88 additions and 67 deletions

View File

@ -7,7 +7,7 @@ from pykka.registry import ActorRegistry
from mopidy import settings
from mopidy.backends.base import Backend
from mopidy.frontends.mpd.exceptions import (MpdAckError, MpdArgError,
MpdUnknownCommand, MpdSystemError)
MpdPermissionError, MpdPasswordError, MpdSystemError, MpdUnknownCommand)
from mopidy.frontends.mpd.protocol import mpd_commands, request_handlers
# Do not remove the following import. The protocol modules must be imported to
# get them registered as request handlers.
@ -32,63 +32,82 @@ class MpdDispatcher(object):
self.authenticated = False
self.command_list = False
self.command_list_ok = False
self.command_list_index = None
self.context = MpdContext(self, session=session)
def handle_request(self, request, current_command_list_index=None):
"""Dispatch incoming requests to the correct handler."""
if not self.authenticated:
(self.authenticated, result) = self._check_password(request)
if result:
return result
self.command_list_index = current_command_list_index
response = []
filter_chain = [
self._catch_mpd_ack_errors_filter,
self._authenticate_filter,
self._command_list_filter,
self._add_ok_filter,
self._call_handler_filter,
]
return self._call_next_filter(request, response, filter_chain)
def _catch_mpd_ack_errors_filter(self, request, response, filter_chain):
try:
return self._call_next_filter(request, response, filter_chain)
except MpdAckError as mpd_ack_error:
if self.command_list_index is not None:
mpd_ack_error.index = self.command_list_index
return [mpd_ack_error.get_mpd_ack()]
def _authenticate_filter(self, request, response, filter_chain):
if self.authenticated or settings.MPD_SERVER_PASSWORD is None:
return self._call_next_filter(request, response, filter_chain)
else:
command = request.split(' ')[0]
if command in ('close', 'commands', 'notcommands', 'password', 'ping'):
return self._call_next_filter(request, response, filter_chain)
else:
raise MpdPermissionError(command=command)
def _command_list_filter(self, request, response, filter_chain):
if self._is_receiving_command_list(request):
self.command_list.append(request)
return None
try:
try:
result = self._call_handler(request)
except ActorDeadError as e:
logger.warning(u'Tried to communicate with dead actor.')
raise MpdSystemError(e.message)
except MpdAckError as e:
if current_command_list_index is not None:
e.index = current_command_list_index
return self._format_response(e.get_mpd_ack(), add_ok=False)
if (request in (u'command_list_begin', u'command_list_ok_begin')
or current_command_list_index is not None):
return self._format_response(result, add_ok=False)
return self._format_response(result)
def _check_password(self, request):
"""
Takes any request and tries to authenticate the client using it.
:rtype: a two-tuple containing (is_authenticated, response_message). If
the response_message is :class:`None`, normal processing should
continue, even though the client may not be authenticated.
"""
if settings.MPD_SERVER_PASSWORD is None:
return (True, None)
command = request.split(' ')[0]
if command == 'password':
if request == 'password "%s"' % settings.MPD_SERVER_PASSWORD:
return (True, [u'OK'])
else:
return (False, [u'ACK [3@0] {password} incorrect password'])
if command in ('close', 'commands', 'notcommands', 'ping'):
return (False, None)
return []
else:
return (False,
[u'ACK [4@0] {%(c)s} you don\'t have permission for "%(c)s"' %
{'c': command}])
response = self._call_next_filter(request, response, filter_chain)
if (self._is_receiving_command_list(request) or
self._is_processing_command_list(request)):
if response and response[-1] == u'OK':
response = response[:-1]
return response
def _is_receiving_command_list(self, request):
return (self.command_list is not False
and request != u'command_list_end')
def _is_processing_command_list(self, request):
return (self.command_list_index is not None
and request != u'command_list_end')
def _add_ok_filter(self, request, response, filter_chain):
response = self._call_next_filter(request, response, filter_chain)
if not self._has_error(response):
response.append(u'OK')
return response
def _has_error(self, response):
return response and response[-1].startswith(u'ACK')
def _call_handler_filter(self, request, response, filter_chain):
try:
response = self._format_response(self._call_handler(request))
return self._call_next_filter(request, response, filter_chain)
except ActorDeadError as e:
logger.warning(u'Tried to communicate with dead actor.')
raise MpdSystemError(e.message)
def _call_handler(self, request):
(handler, kwargs) = self._find_handler(request)
return handler(self.context, **kwargs)
@ -103,13 +122,19 @@ class MpdDispatcher(object):
raise MpdArgError(u'incorrect arguments', command=command)
raise MpdUnknownCommand(command=command)
def _format_response(self, result, add_ok=True):
response = []
for element in self._listify_result(result):
response.extend(self._format_lines(element))
if add_ok and (not response or not self._has_error(response)):
response.append(u'OK')
return response
def _call_next_filter(self, request, response, filter_chain):
if filter_chain:
next_filter = filter_chain.pop(0)
return next_filter(request, response, filter_chain)
else:
return response
def _format_response(self, response):
formatted_response = []
for element in self._listify_result(response):
formatted_response.extend(self._format_lines(element))
return formatted_response
def _listify_result(self, result):
if result is None:
@ -128,9 +153,6 @@ class MpdDispatcher(object):
return [u'%s: %s' % (key, value)]
return [line]
def _has_error(self, response):
return bool(response) and response[-1].startswith(u'ACK')
class MpdContext(object):
"""

View File

@ -31,17 +31,16 @@ def command_list_end(context):
context.dispatcher.command_list, False)
(command_list_ok, context.dispatcher.command_list_ok) = (
context.dispatcher.command_list_ok, False)
result = []
command_list_response = []
for index, command in enumerate(command_list):
response = context.dispatcher.handle_request(
command, current_command_list_index=index)
if response is not None:
result.append(response)
if response and response[-1].startswith(u'ACK'):
return result
command_list_response.extend(response)
if command_list_response and command_list_response[-1].startswith(u'ACK'):
return command_list_response
if command_list_ok:
response.append(u'list_OK')
return result
command_list_response.append(u'list_OK')
return command_list_response
@handle_pattern(r'^command_list_ok_begin$')
def command_list_ok_begin(context):

View File

@ -35,10 +35,9 @@ def password_(context, password):
This is used for authentication with the server. ``PASSWORD`` is
simply the plaintext password.
"""
# You will not get to this code without being authenticated. This is for
# when you are already authenticated, and are sending additional 'password'
# requests.
if settings.MPD_SERVER_PASSWORD != password:
if password == settings.MPD_SERVER_PASSWORD:
context.dispatcher.authenticated = True
else:
raise MpdPasswordError(u'incorrect password', command=u'password')
@handle_pattern(r'^ping$')

View File

@ -28,7 +28,7 @@ class AuthenticationTest(unittest.TestCase):
def test_authentication_with_anything_when_password_check_turned_off(self):
settings.MPD_SERVER_PASSWORD = None
response = self.dispatcher.handle_request(u'any request at all')
self.assertTrue(self.dispatcher.authenticated)
self.assertFalse(self.dispatcher.authenticated)
self.assert_('ACK [5@0] {} unknown command "any"' in response)
def test_anything_when_not_authenticated_should_fail(self):

View File

@ -43,6 +43,7 @@ class CommandListsTest(unittest.TestCase):
self.dispatcher.handle_request(u'play') # Known command
self.dispatcher.handle_request(u'paly') # Unknown command
result = self.dispatcher.handle_request(u'command_list_end')
self.assertEqual(len(result), 1, result)
self.assertEqual(result[0], u'ACK [5@1] {} unknown command "paly"')
def test_command_list_ok_begin(self):