diff --git a/tests/utils/network_test.py b/tests/utils/network_test.py index e7767689..b686e20c 100644 --- a/tests/utils/network_test.py +++ b/tests/utils/network_test.py @@ -60,48 +60,44 @@ class CreateSocketTest(unittest.TestCase): class ServerTest(unittest.TestCase): def setUp(self): - self.protocol = network.LineProtocol - self.addr = (sentinel.host, sentinel.port) - self.host, self.port = self.addr + self.mock = Mock(spec=network.Server) - self.create_server_socket_patchter = patch.object( - network.Server, 'create_server_socket', new=Mock()) - self.register_server_socket_patcher = patch.object( - network.Server, 'register_server_socket', new=Mock()) + def test_init_calls_create_server_socket(self): + network.Server.__init__(self.mock, sentinel.host, + sentinel.port, sentinel.protocol) + self.mock.create_server_socket.assert_called_once_with( + sentinel.host, sentinel.port) - self.create_server_socket_patchter.start() - self.register_server_socket_patcher.start() + def test_init_calls_register_server(self): + sock = Mock(spec=socket.SocketType) + sock.fileno.return_value = sentinel.fileno + self.mock.create_server_socket.return_value = sock - def tearDown(self): - self.create_server_socket_patchter.stop() - self.register_server_socket_patcher.stop() + network.Server.__init__(self.mock, sentinel.host, + sentinel.port, sentinel.protocol) + self.mock.register_server_socket.assert_called_once_with(sentinel.fileno) - def create_server(self): - return network.Server(sentinel.host, sentinel.port, self.protocol) + def test_init_stores_values_in_attributes(self): + sock = Mock(spec=socket.SocketType) + self.mock.create_server_socket.return_value = sock - def test_init_creates_socket_and_registers_it(self): - server = self.create_server() - sock = server.create_server_socket.return_value - fileno = sock.fileno.return_value - - server.create_server_socket.assert_called_once_with( - self.host, self.port) - server.register_server_socket.assert_called_once_with(fileno) + network.Server.__init__(self.mock, sentinel.host, sentinel.port, + sentinel.protocol, max_connections=sentinel.max_connections, + timeout=sentinel.timeout) + self.assertEqual(sentinel.protocol, self.mock.protocol) + self.assertEqual(sentinel.max_connections, self.mock.max_connections) + self.assertEqual(sentinel.timeout, self.mock.timeout) + self.assertEqual(sock, self.mock.server_socket) @patch.object(network, 'create_socket', spec=socket.SocketType) def test_create_server_socket_sets_up_listener(self, create_socket): - self.create_server_socket_patchter.stop() + sock = create_socket.return_value - try: - server = self.create_server() - sock = create_socket.return_value - - sock.setblocking.assert_called_once_with(False) - sock.bind.assert_called_once_with(self.addr) - self.assertEqual(1, sock.listen.call_count) - self.assertEqual(sock, server.server_socket) - finally: - self.create_server_socket_patchter.start() + network.Server.create_server_socket(self.mock, + sentinel.host, sentinel.port) + sock.setblocking.assert_called_once_with(False) + sock.bind.assert_called_once_with((sentinel.host, sentinel.port)) + self.assertEqual(1, sock.listen.call_count) @SkipTest def test_create_server_socket_fails(self): @@ -111,107 +107,88 @@ class ServerTest(unittest.TestCase): @patch.object(gobject, 'io_add_watch', new=Mock()) def test_register_server_socket_sets_up_io_watch(self): - self.register_server_socket_patcher.stop() + network.Server.register_server_socket(self.mock, sentinel.fileno) + gobject.io_add_watch.assert_called_once_with(sentinel.fileno, + gobject.IO_IN, self.mock.handle_connection) - try: - server = self.create_server() - sock = server.create_server_socket.return_value - fileno = sock.fileno.return_value - - gobject.io_add_watch.assert_called_once_with( - fileno, gobject.IO_IN, server.handle_connection) - finally: - self.register_server_socket_patcher.start() - - @patch.object(network.Server, 'accept_connection', new=Mock()) - @patch.object(network.Server, 'maximum_connections_exceeded', new=Mock()) - @patch.object(network.Server, 'reject_connection', new=Mock()) - @patch.object(network.Server, 'init_connection', new=Mock()) def test_handle_connection(self): - server = self.create_server() - server.accept_connection.return_value = (sentinel.sock, self.addr) - server.maximum_connections_exceeded.return_value = False + self.mock.accept_connection.return_value = (sentinel.sock, sentinel.addr) + self.mock.maximum_connections_exceeded.return_value = False - server.handle_connection(sentinel.fileno, gobject.IO_IN) + network.Server.handle_connection(self.mock, sentinel.fileno, gobject.IO_IN) + self.mock.accept_connection.assert_called_once_with() + self.mock.maximum_connections_exceeded.assert_called_once_with() + self.mock.init_connection.assert_called_once_with(sentinel.sock, sentinel.addr) + self.assertEquals(0, self.mock.reject_connection.call_count) - server.accept_connection.assert_called_once_with() - server.maximum_connections_exceeded.assert_called_once_with() - server.init_connection.assert_called_once_with(sentinel.sock, self.addr) - self.assertEquals(0, server.reject_connection.call_count) - - @patch.object(network.Server, 'accept_connection', new=Mock()) - @patch.object(network.Server, 'maximum_connections_exceeded', new=Mock()) - @patch.object(network.Server, 'reject_connection', new=Mock()) - @patch.object(network.Server, 'init_connection', new=Mock()) def test_handle_connection_exceeded_connections(self): - server = self.create_server() - server.accept_connection.return_value = (sentinel.sock, self.addr) - server.maximum_connections_exceeded.return_value = True + self.mock.accept_connection.return_value = (sentinel.sock, sentinel.addr) + self.mock.maximum_connections_exceeded.return_value = True - server.handle_connection(sentinel.fileno, gobject.IO_IN) - - server.accept_connection.assert_called_once_with() - server.maximum_connections_exceeded.assert_called_once_with() - server.reject_connection.assert_called_once_with( - sentinel.sock, self.addr) - self.assertEquals(0, server.init_connection.call_count) + network.Server.handle_connection(self.mock, sentinel.fileno, gobject.IO_IN) + self.mock.accept_connection.assert_called_once_with() + self.mock.maximum_connections_exceeded.assert_called_once_with() + self.mock.reject_connection.assert_called_once_with(sentinel.sock, sentinel.addr) + self.assertEquals(0, self.mock.init_connection.call_count) def test_accept_connection(self): - server = self.create_server() - sock = server.create_server_socket.return_value - sock.accept.return_value = (sentinel.sock, self.addr) + sock = Mock(spec=socket.SocketType) + sock.accept.return_value = (sentinel.sock, sentinel.addr) + self.mock.server_socket = sock - self.assertEquals((sentinel.sock, self.addr), - server.accept_connection()) + sock, addr = network.Server.accept_connection(self.mock) + self.assertEquals(sentinel.sock, sock) + self.assertEquals(sentinel.addr, addr) def test_accept_connection_recoverable_error(self): - server = self.create_server() - sock = server.create_server_socket.return_value + sock = Mock(spec=socket.SocketType) + self.mock.server_socket = sock sock.accept.side_effect = socket.error(errno.EAGAIN, '') self.assertRaises(network.ShouldRetrySocketCall, - server.accept_connection) + network.Server.accept_connection, self.mock) sock.accept.side_effect = socket.error(errno.EINTR, '') self.assertRaises(network.ShouldRetrySocketCall, - server.accept_connection) + network.Server.accept_connection, self.mock) - def test_accept_connection_recoverable_error(self): - server = self.create_server() - sock = server.create_server_socket.return_value + def test_accept_connection_unrecoverable_error(self): + sock = Mock(spec=socket.SocketType) + self.mock.server_socket = sock sock.accept.side_effect = socket.error() - self.assertRaises(socket.error, server.accept_connection) + self.assertRaises(socket.error, + network.Server.accept_connection, self.mock) @patch.object(network.Server, 'number_of_connections', new=Mock()) def test_maximum_connections_exceeded(self): - server = self.create_server() - maximum_connections = server.max_connections + self.mock.max_connections = 10 - server.number_of_connections.return_value = maximum_connections + 1 - self.assertTrue(server.maximum_connections_exceeded()) + self.mock.number_of_connections.return_value = 11 + self.assertTrue(network.Server.maximum_connections_exceeded(self.mock)) - server.number_of_connections.return_value = maximum_connections - self.assertTrue(server.maximum_connections_exceeded()) + self.mock.number_of_connections.return_value = 10 + self.assertTrue(network.Server.maximum_connections_exceeded(self.mock)) - server.number_of_connections.return_value = maximum_connections - 1 - self.assertFalse(server.maximum_connections_exceeded()) + self.mock.number_of_connections.return_value = 9 + self.assertFalse(network.Server.maximum_connections_exceeded(self.mock)) @patch('pykka.registry.ActorRegistry.get_by_class') def test_number_of_connections(self, get_by_class): - server = self.create_server() + self.mock.protocol = sentinel.protocol get_by_class.return_value = [1, 2, 3] - self.assertEqual(3, server.number_of_connections()) + self.assertEqual(3, network.Server.number_of_connections(self.mock)) get_by_class.return_value = [] - self.assertEqual(0, server.number_of_connections()) + self.assertEqual(0, network.Server.number_of_connections(self.mock)) @patch.object(network, 'Connection', new=Mock()) def test_init_connection(self): - server = self.create_server() - server.init_connection(sentinel.sock, self.addr) + self.mock.protocol = sentinel.protocol + self.mock.timeout = sentinel.timeout - network.Connection.assert_called_once_with(server.protocol, - sentinel.sock, self.addr, server.timeout) + network.Server.init_connection(self.mock, sentinel.sock, sentinel.addr) + network.Connection.assert_called_once_with(sentinel.protocol, + sentinel.sock, sentinel.addr, sentinel.timeout)