Refactor network.Server to improve testability
This commit is contained in:
parent
34cd3008d9
commit
91270ef535
@ -12,6 +12,9 @@ from mopidy.utils.log import indent
|
||||
|
||||
logger = logging.getLogger('mopidy.utils.server')
|
||||
|
||||
class ShouldRetrySocketCall(Exception):
|
||||
"""Indicate that attempted socket call should be retried"""
|
||||
|
||||
def _try_ipv6_socket():
|
||||
"""Determine if system really supports IPv6"""
|
||||
if not socket.has_ipv6:
|
||||
@ -51,61 +54,79 @@ class Server(object):
|
||||
self.protocol = protocol
|
||||
self.max_connections = max_connections
|
||||
self.timeout = timeout
|
||||
self.server_socket = self.create_server_socket(host, port)
|
||||
|
||||
self.listener = create_socket()
|
||||
self.listener.setblocking(False)
|
||||
self.listener.bind((host, port))
|
||||
self.listener.listen(1)
|
||||
|
||||
gobject.io_add_watch(
|
||||
self.listener.fileno(), gobject.IO_IN, self.handle_accept)
|
||||
self.register_server_socket(self.server_socket.fileno())
|
||||
|
||||
logger.debug(u'Listening on [%s]:%s using %s as protocol handler',
|
||||
host, port, self.protocol.__name__)
|
||||
host, port, self.protocol)
|
||||
|
||||
def handle_accept(self, fd, flags):
|
||||
def create_server_socket(self, host, port):
|
||||
sock = create_socket()
|
||||
sock.setblocking(False)
|
||||
sock.bind((host, port))
|
||||
sock.listen(1)
|
||||
return sock
|
||||
|
||||
def register_server_socket(self, fileno):
|
||||
gobject.io_add_watch(fileno, gobject.IO_IN, self.handle_connection)
|
||||
|
||||
def handle_connection(self, fd, flags):
|
||||
try:
|
||||
sock, addr = self.listener.accept()
|
||||
except socket.error as e:
|
||||
if e.errno in (errno.EAGAIN, errno.EINTR):
|
||||
return True # i.e. retry
|
||||
raise
|
||||
|
||||
num_connections = len(ActorRegistry.get_by_class(self.protocol))
|
||||
if self.max_connections and num_connections >= self.max_connections:
|
||||
logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1])
|
||||
try:
|
||||
sock.close()
|
||||
except socket.error:
|
||||
pass
|
||||
sock, addr = self.accept_connection()
|
||||
except ShouldRetrySocketCall:
|
||||
return True
|
||||
|
||||
client = Client(self.protocol, sock, addr, self.timeout)
|
||||
client.start()
|
||||
|
||||
if self.maximum_connections_exceeded():
|
||||
self.reject_connection(sock, addr)
|
||||
else:
|
||||
self.init_connection(sock, addr)
|
||||
return True
|
||||
|
||||
def accept_connection(self):
|
||||
try:
|
||||
return self.server_socket.accept()
|
||||
except socket.error as e:
|
||||
if e.errno in (errno.EAGAIN, errno.EINTR):
|
||||
raise ShouldRetrySocketCall
|
||||
raise
|
||||
|
||||
class Client(object):
|
||||
def maximum_connections_exceeded(self):
|
||||
return (self.max_connections is not None and
|
||||
self.number_of_connections() >= self.max_connections)
|
||||
|
||||
def number_of_connections(self):
|
||||
return len(ActorRegistry.get_by_class(self.protocol))
|
||||
|
||||
def reject_connection(self, sock, addr):
|
||||
logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1])
|
||||
try:
|
||||
sock.close()
|
||||
except socket.error:
|
||||
pass
|
||||
|
||||
def init_connection(self, sock, addr):
|
||||
Connection(self.protocol, sock, addr, self.timeout)
|
||||
|
||||
class Connection(object):
|
||||
def __init__(self, protocol, sock, addr, timeout):
|
||||
sock.setblocking(False)
|
||||
|
||||
self._sock = sock
|
||||
self.host, self.port = addr[:2] # IPv6 has larger addr
|
||||
|
||||
self._sock = sock
|
||||
self._protocol = protocol
|
||||
self._timeout_time = timeout
|
||||
|
||||
self._send_lock = threading.Lock()
|
||||
self._send_buffer = ''
|
||||
|
||||
self._actor_ref = None
|
||||
|
||||
self._recv_id = None
|
||||
self._send_id = None
|
||||
self._timeout_id = None
|
||||
|
||||
def start(self):
|
||||
self._actor_ref = self._protocol.start(self)
|
||||
|
||||
self._enable_recv()
|
||||
self.enable_timeout()
|
||||
|
||||
@ -223,8 +244,8 @@ class LineProtocol(ThreadingActor):
|
||||
#: What encoding to expect incomming data to be in, can be :class:`None`.
|
||||
encoding = 'utf-8'
|
||||
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
self.recv_buffer = ''
|
||||
self.terminator_re = re.compile(self.terminator)
|
||||
@ -242,7 +263,7 @@ class LineProtocol(ThreadingActor):
|
||||
if 'received' not in message:
|
||||
return
|
||||
|
||||
self.client.disable_timeout()
|
||||
self.connection.disable_timeout()
|
||||
self.log_raw_data(message['received'])
|
||||
|
||||
for line in self.parse_lines(message['received']):
|
||||
@ -250,11 +271,11 @@ class LineProtocol(ThreadingActor):
|
||||
self.log_request(line)
|
||||
self.on_line_received(line)
|
||||
|
||||
self.client.enable_timeout()
|
||||
self.connection.enable_timeout()
|
||||
|
||||
def on_stop(self):
|
||||
"""Ensure that cleanup when actor stops."""
|
||||
self.client.stop()
|
||||
self.connection.stop()
|
||||
|
||||
def parse_lines(self, new_data=None):
|
||||
"""Consume new data and yield any lines found."""
|
||||
@ -280,7 +301,7 @@ class LineProtocol(ThreadingActor):
|
||||
|
||||
Can be overridden by subclasses to change logging behaviour.
|
||||
"""
|
||||
logger.debug(u'Request from %s to %s: %s', self.client, self.actor_urn,
|
||||
logger.debug(u'Request from %s to %s: %s', self.connection, self.actor_urn,
|
||||
indent(request))
|
||||
|
||||
def log_response(self, response):
|
||||
@ -289,7 +310,7 @@ class LineProtocol(ThreadingActor):
|
||||
|
||||
Can be overridden by subclasses to change logging behaviour.
|
||||
"""
|
||||
logger.debug(u'Response to %s from %s: %s', self.client,
|
||||
logger.debug(u'Response to %s from %s: %s', self.connection,
|
||||
self.actor_urn, indent(response))
|
||||
|
||||
def log_error(self, error):
|
||||
@ -299,7 +320,7 @@ class LineProtocol(ThreadingActor):
|
||||
Can be overridden by subclasses to change logging behaviour.
|
||||
"""
|
||||
logger.warning(u'Problem with connection to %s in %s: %s',
|
||||
self.client, self.actor_urn, error)
|
||||
self.connection, self.actor_urn, error)
|
||||
|
||||
def log_timeout(self):
|
||||
"""
|
||||
@ -308,7 +329,7 @@ class LineProtocol(ThreadingActor):
|
||||
Can be overridden by subclasses to change logging behaviour.
|
||||
"""
|
||||
logger.debug(u'Closing connection to %s in %s due to timeout.',
|
||||
self.client, self.actor_urn)
|
||||
self.connection, self.actor_urn)
|
||||
|
||||
def encode(self, line):
|
||||
"""
|
||||
@ -332,7 +353,7 @@ class LineProtocol(ThreadingActor):
|
||||
|
||||
def send_lines(self, lines):
|
||||
"""
|
||||
Send array of lines to client.
|
||||
Send array of lines to client via connection.
|
||||
|
||||
Join lines using the terminator that is set for this class, encode it
|
||||
and send it to the client.
|
||||
@ -342,4 +363,4 @@ class LineProtocol(ThreadingActor):
|
||||
|
||||
data = self.terminator.join(lines)
|
||||
self.log_response(data)
|
||||
self.client.send(self.encode(data + self.terminator))
|
||||
self.connection.send(self.encode(data + self.terminator))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user