diff --git a/mopidy/utils/network.py b/mopidy/utils/network.py index 15ddb98e..9280e772 100644 --- a/mopidy/utils/network.py +++ b/mopidy/utils/network.py @@ -12,6 +12,9 @@ from mopidy.utils.log import indent logger = logging.getLogger('mopidy.utils.server') +class ShouldRetrySocketCall(Exception): + """Indicate that attempted socket call should be retried""" + def _try_ipv6_socket(): """Determine if system really supports IPv6""" if not socket.has_ipv6: @@ -51,61 +54,79 @@ class Server(object): self.protocol = protocol self.max_connections = max_connections self.timeout = timeout + self.server_socket = self.create_server_socket(host, port) - self.listener = create_socket() - self.listener.setblocking(False) - self.listener.bind((host, port)) - self.listener.listen(1) - - gobject.io_add_watch( - self.listener.fileno(), gobject.IO_IN, self.handle_accept) + self.register_server_socket(self.server_socket.fileno()) logger.debug(u'Listening on [%s]:%s using %s as protocol handler', - host, port, self.protocol.__name__) + host, port, self.protocol) - def handle_accept(self, fd, flags): + def create_server_socket(self, host, port): + sock = create_socket() + sock.setblocking(False) + sock.bind((host, port)) + sock.listen(1) + return sock + + def register_server_socket(self, fileno): + gobject.io_add_watch(fileno, gobject.IO_IN, self.handle_connection) + + def handle_connection(self, fd, flags): try: - sock, addr = self.listener.accept() - except socket.error as e: - if e.errno in (errno.EAGAIN, errno.EINTR): - return True # i.e. retry - raise - - num_connections = len(ActorRegistry.get_by_class(self.protocol)) - if self.max_connections and num_connections >= self.max_connections: - logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1]) - try: - sock.close() - except socket.error: - pass + sock, addr = self.accept_connection() + except ShouldRetrySocketCall: return True - client = Client(self.protocol, sock, addr, self.timeout) - client.start() - + if self.maximum_connections_exceeded(): + self.reject_connection(sock, addr) + else: + self.init_connection(sock, addr) return True + def accept_connection(self): + try: + return self.server_socket.accept() + except socket.error as e: + if e.errno in (errno.EAGAIN, errno.EINTR): + raise ShouldRetrySocketCall + raise -class Client(object): + def maximum_connections_exceeded(self): + return (self.max_connections is not None and + self.number_of_connections() >= self.max_connections) + + def number_of_connections(self): + return len(ActorRegistry.get_by_class(self.protocol)) + + def reject_connection(self, sock, addr): + logger.warning(u'Rejected connection from [%s]:%s', addr[0], addr[1]) + try: + sock.close() + except socket.error: + pass + + def init_connection(self, sock, addr): + Connection(self.protocol, sock, addr, self.timeout) + +class Connection(object): def __init__(self, protocol, sock, addr, timeout): sock.setblocking(False) - self._sock = sock self.host, self.port = addr[:2] # IPv6 has larger addr + + self._sock = sock self._protocol = protocol self._timeout_time = timeout self._send_lock = threading.Lock() self._send_buffer = '' - self._actor_ref = None - self._recv_id = None self._send_id = None self._timeout_id = None - def start(self): self._actor_ref = self._protocol.start(self) + self._enable_recv() self.enable_timeout() @@ -223,8 +244,8 @@ class LineProtocol(ThreadingActor): #: What encoding to expect incomming data to be in, can be :class:`None`. encoding = 'utf-8' - def __init__(self, client): - self.client = client + def __init__(self, connection): + self.connection = connection self.recv_buffer = '' self.terminator_re = re.compile(self.terminator) @@ -242,7 +263,7 @@ class LineProtocol(ThreadingActor): if 'received' not in message: return - self.client.disable_timeout() + self.connection.disable_timeout() self.log_raw_data(message['received']) for line in self.parse_lines(message['received']): @@ -250,11 +271,11 @@ class LineProtocol(ThreadingActor): self.log_request(line) self.on_line_received(line) - self.client.enable_timeout() + self.connection.enable_timeout() def on_stop(self): """Ensure that cleanup when actor stops.""" - self.client.stop() + self.connection.stop() def parse_lines(self, new_data=None): """Consume new data and yield any lines found.""" @@ -280,7 +301,7 @@ class LineProtocol(ThreadingActor): Can be overridden by subclasses to change logging behaviour. """ - logger.debug(u'Request from %s to %s: %s', self.client, self.actor_urn, + logger.debug(u'Request from %s to %s: %s', self.connection, self.actor_urn, indent(request)) def log_response(self, response): @@ -289,7 +310,7 @@ class LineProtocol(ThreadingActor): Can be overridden by subclasses to change logging behaviour. """ - logger.debug(u'Response to %s from %s: %s', self.client, + logger.debug(u'Response to %s from %s: %s', self.connection, self.actor_urn, indent(response)) def log_error(self, error): @@ -299,7 +320,7 @@ class LineProtocol(ThreadingActor): Can be overridden by subclasses to change logging behaviour. """ logger.warning(u'Problem with connection to %s in %s: %s', - self.client, self.actor_urn, error) + self.connection, self.actor_urn, error) def log_timeout(self): """ @@ -308,7 +329,7 @@ class LineProtocol(ThreadingActor): Can be overridden by subclasses to change logging behaviour. """ logger.debug(u'Closing connection to %s in %s due to timeout.', - self.client, self.actor_urn) + self.connection, self.actor_urn) def encode(self, line): """ @@ -332,7 +353,7 @@ class LineProtocol(ThreadingActor): def send_lines(self, lines): """ - Send array of lines to client. + Send array of lines to client via connection. Join lines using the terminator that is set for this class, encode it and send it to the client. @@ -342,4 +363,4 @@ class LineProtocol(ThreadingActor): data = self.terminator.join(lines) self.log_response(data) - self.client.send(self.encode(data + self.terminator)) + self.connection.send(self.encode(data + self.terminator))