Refactor network.Server to improve testability

This commit is contained in:
Thomas Adamcik 2011-07-11 18:53:49 +02:00
parent 34cd3008d9
commit 91270ef535

View File

@ -12,6 +12,9 @@ from mopidy.utils.log import indent
logger = logging.getLogger('mopidy.utils.server')
class ShouldRetrySocketCall(Exception):
"""Indicate that attempted socket call should be retried"""
def _try_ipv6_socket():
"""Determine if system really supports IPv6"""
if not socket.has_ipv6:
@ -51,61 +54,79 @@ class Server(object):
self.protocol = protocol
self.max_connections = max_connections
self.timeout = timeout
self.server_socket = self.create_server_socket(host, port)
self.listener = create_socket()
self.listener.setblocking(False)
self.listener.bind((host, port))
self.listener.listen(1)
gobject.io_add_watch(
self.listener.fileno(), gobject.IO_IN, self.handle_accept)
self.register_server_socket(self.server_socket.fileno())
logger.debug(u'Listening on [%s]:%s using %s as protocol handler',
host, port, self.protocol.__name__)
host, port, self.protocol)
def handle_accept(self, fd, flags):
def create_server_socket(self, host, port):
sock = create_socket()
sock.setblocking(False)
sock.bind((host, port))
sock.listen(1)
return sock
def register_server_socket(self, fileno):
gobject.io_add_watch(fileno, gobject.IO_IN, self.handle_connection)
def handle_connection(self, fd, flags):
try:
sock, addr = self.listener.accept()
except socket.error as e:
if e.errno in (errno.EAGAIN, errno.EINTR):
return True # i.e. retry
raise
num_connections = len(ActorRegistry.get_by_class(self.protocol))
if self.max_connections and num_connections >= self.max_connections:
logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1])
try:
sock.close()
except socket.error:
pass
sock, addr = self.accept_connection()
except ShouldRetrySocketCall:
return True
client = Client(self.protocol, sock, addr, self.timeout)
client.start()
if self.maximum_connections_exceeded():
self.reject_connection(sock, addr)
else:
self.init_connection(sock, addr)
return True
def accept_connection(self):
try:
return self.server_socket.accept()
except socket.error as e:
if e.errno in (errno.EAGAIN, errno.EINTR):
raise ShouldRetrySocketCall
raise
class Client(object):
def maximum_connections_exceeded(self):
return (self.max_connections is not None and
self.number_of_connections() >= self.max_connections)
def number_of_connections(self):
return len(ActorRegistry.get_by_class(self.protocol))
def reject_connection(self, sock, addr):
logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1])
try:
sock.close()
except socket.error:
pass
def init_connection(self, sock, addr):
Connection(self.protocol, sock, addr, self.timeout)
class Connection(object):
def __init__(self, protocol, sock, addr, timeout):
sock.setblocking(False)
self._sock = sock
self.host, self.port = addr[:2] # IPv6 has larger addr
self._sock = sock
self._protocol = protocol
self._timeout_time = timeout
self._send_lock = threading.Lock()
self._send_buffer = ''
self._actor_ref = None
self._recv_id = None
self._send_id = None
self._timeout_id = None
def start(self):
self._actor_ref = self._protocol.start(self)
self._enable_recv()
self.enable_timeout()
@ -223,8 +244,8 @@ class LineProtocol(ThreadingActor):
#: What encoding to expect incomming data to be in, can be :class:`None`.
encoding = 'utf-8'
def __init__(self, client):
self.client = client
def __init__(self, connection):
self.connection = connection
self.recv_buffer = ''
self.terminator_re = re.compile(self.terminator)
@ -242,7 +263,7 @@ class LineProtocol(ThreadingActor):
if 'received' not in message:
return
self.client.disable_timeout()
self.connection.disable_timeout()
self.log_raw_data(message['received'])
for line in self.parse_lines(message['received']):
@ -250,11 +271,11 @@ class LineProtocol(ThreadingActor):
self.log_request(line)
self.on_line_received(line)
self.client.enable_timeout()
self.connection.enable_timeout()
def on_stop(self):
"""Ensure that cleanup when actor stops."""
self.client.stop()
self.connection.stop()
def parse_lines(self, new_data=None):
"""Consume new data and yield any lines found."""
@ -280,7 +301,7 @@ class LineProtocol(ThreadingActor):
Can be overridden by subclasses to change logging behaviour.
"""
logger.debug(u'Request from %s to %s: %s', self.client, self.actor_urn,
logger.debug(u'Request from %s to %s: %s', self.connection, self.actor_urn,
indent(request))
def log_response(self, response):
@ -289,7 +310,7 @@ class LineProtocol(ThreadingActor):
Can be overridden by subclasses to change logging behaviour.
"""
logger.debug(u'Response to %s from %s: %s', self.client,
logger.debug(u'Response to %s from %s: %s', self.connection,
self.actor_urn, indent(response))
def log_error(self, error):
@ -299,7 +320,7 @@ class LineProtocol(ThreadingActor):
Can be overridden by subclasses to change logging behaviour.
"""
logger.warning(u'Problem with connection to %s in %s: %s',
self.client, self.actor_urn, error)
self.connection, self.actor_urn, error)
def log_timeout(self):
"""
@ -308,7 +329,7 @@ class LineProtocol(ThreadingActor):
Can be overridden by subclasses to change logging behaviour.
"""
logger.debug(u'Closing connection to %s in %s due to timeout.',
self.client, self.actor_urn)
self.connection, self.actor_urn)
def encode(self, line):
"""
@ -332,7 +353,7 @@ class LineProtocol(ThreadingActor):
def send_lines(self, lines):
"""
Send array of lines to client.
Send array of lines to client via connection.
Join lines using the terminator that is set for this class, encode it
and send it to the client.
@ -342,4 +363,4 @@ class LineProtocol(ThreadingActor):
data = self.terminator.join(lines)
self.log_response(data)
self.client.send(self.encode(data + self.terminator))
self.connection.send(self.encode(data + self.terminator))