diff --git a/blockdev-nbd.c b/blockdev-nbd.c index 24ba5382db..f73409ae49 100644 --- a/blockdev-nbd.c +++ b/blockdev-nbd.c @@ -21,12 +21,18 @@ #include "io/channel-socket.h" #include "io/net-listener.h" +typedef struct NBDConn { + QIOChannelSocket *cioc; + QLIST_ENTRY(NBDConn) next; +} NBDConn; + typedef struct NBDServerData { QIONetListener *listener; QCryptoTLSCreds *tlscreds; char *tlsauthz; uint32_t max_connections; uint32_t connections; + QLIST_HEAD(, NBDConn) conns; } NBDServerData; static NBDServerData *nbd_server; @@ -51,6 +57,14 @@ int nbd_server_max_connections(void) static void nbd_blockdev_client_closed(NBDClient *client, bool ignored) { + NBDConn *conn = nbd_client_owner(client); + + assert(qemu_in_main_thread() && nbd_server); + + object_unref(OBJECT(conn->cioc)); + QLIST_REMOVE(conn, next); + g_free(conn); + nbd_client_put(client); assert(nbd_server->connections > 0); nbd_server->connections--; @@ -60,14 +74,20 @@ static void nbd_blockdev_client_closed(NBDClient *client, bool ignored) static void nbd_accept(QIONetListener *listener, QIOChannelSocket *cioc, gpointer opaque) { + NBDConn *conn = g_new0(NBDConn, 1); + + assert(qemu_in_main_thread() && nbd_server); nbd_server->connections++; + object_ref(OBJECT(cioc)); + conn->cioc = cioc; + QLIST_INSERT_HEAD(&nbd_server->conns, conn, next); nbd_update_server_watch(nbd_server); qio_channel_set_name(QIO_CHANNEL(cioc), "nbd-server"); /* TODO - expose handshake timeout as QMP option */ nbd_client_new(cioc, NBD_DEFAULT_HANDSHAKE_MAX_SECS, nbd_server->tlscreds, nbd_server->tlsauthz, - nbd_blockdev_client_closed, NULL); + nbd_blockdev_client_closed, conn); } static void nbd_update_server_watch(NBDServerData *s) @@ -81,12 +101,25 @@ static void nbd_update_server_watch(NBDServerData *s) static void nbd_server_free(NBDServerData *server) { + NBDConn *conn, *tmp; + if (!server) { return; } + /* + * Forcefully close the listener socket, and any clients that have + * not yet disconnected on their own. + */ qio_net_listener_disconnect(server->listener); object_unref(OBJECT(server->listener)); + QLIST_FOREACH_SAFE(conn, &server->conns, next, tmp) { + qio_channel_shutdown(QIO_CHANNEL(conn->cioc), QIO_CHANNEL_SHUTDOWN_BOTH, + NULL); + } + + AIO_WAIT_WHILE_UNLOCKED(NULL, server->connections > 0); + if (server->tlscreds) { object_unref(OBJECT(server->tlscreds)); }