diff --git a/nbd/client.c b/nbd/client.c index 29ffc609a4..c89c750467 100644 --- a/nbd/client.c +++ b/nbd/client.c @@ -596,13 +596,31 @@ static int nbd_request_simple_option(QIOChannel *ioc, int opt, bool strict, return 1; } +/* Callback to learn when QIO TLS upgrade is complete */ +struct NBDTLSClientHandshakeData { + bool complete; + Error *error; + GMainLoop *loop; +}; + +static void nbd_client_tls_handshake(QIOTask *task, void *opaque) +{ + struct NBDTLSClientHandshakeData *data = opaque; + + qio_task_propagate_error(task, &data->error); + data->complete = true; + if (data->loop) { + g_main_loop_quit(data->loop); + } +} + static QIOChannel *nbd_receive_starttls(QIOChannel *ioc, QCryptoTLSCreds *tlscreds, const char *hostname, Error **errp) { int ret; QIOChannelTLS *tioc; - struct NBDTLSHandshakeData data = { 0 }; + struct NBDTLSClientHandshakeData data = { 0 }; ret = nbd_request_simple_option(ioc, NBD_OPT_STARTTLS, true, errp); if (ret <= 0) { @@ -619,18 +637,20 @@ static QIOChannel *nbd_receive_starttls(QIOChannel *ioc, return NULL; } qio_channel_set_name(QIO_CHANNEL(tioc), "nbd-client-tls"); - data.loop = g_main_loop_new(g_main_context_default(), FALSE); trace_nbd_receive_starttls_tls_handshake(); qio_channel_tls_handshake(tioc, - nbd_tls_handshake, + nbd_client_tls_handshake, &data, NULL, NULL); if (!data.complete) { + data.loop = g_main_loop_new(g_main_context_default(), FALSE); g_main_loop_run(data.loop); + assert(data.complete); + g_main_loop_unref(data.loop); } - g_main_loop_unref(data.loop); + if (data.error) { error_propagate(errp, data.error); object_unref(OBJECT(tioc)); diff --git a/nbd/common.c b/nbd/common.c index 3247c1d618..589a748cfe 100644 --- a/nbd/common.c +++ b/nbd/common.c @@ -47,17 +47,6 @@ int nbd_drop(QIOChannel *ioc, size_t size, Error **errp) } -void nbd_tls_handshake(QIOTask *task, - void *opaque) -{ - struct NBDTLSHandshakeData *data = opaque; - - qio_task_propagate_error(task, &data->error); - data->complete = true; - g_main_loop_quit(data->loop); -} - - const char *nbd_opt_lookup(uint32_t opt) { switch (opt) { diff --git a/nbd/nbd-internal.h b/nbd/nbd-internal.h index dfa02f77ee..91895106a9 100644 --- a/nbd/nbd-internal.h +++ b/nbd/nbd-internal.h @@ -72,16 +72,6 @@ static inline int nbd_write(QIOChannel *ioc, const void *buffer, size_t size, return qio_channel_write_all(ioc, buffer, size, errp) < 0 ? -EIO : 0; } -struct NBDTLSHandshakeData { - GMainLoop *loop; - bool complete; - Error *error; -}; - - -void nbd_tls_handshake(QIOTask *task, - void *opaque); - int nbd_drop(QIOChannel *ioc, size_t size, Error **errp); #endif diff --git a/nbd/server.c b/nbd/server.c index c3484cc1eb..98ae0e1632 100644 --- a/nbd/server.c +++ b/nbd/server.c @@ -748,6 +748,23 @@ static int nbd_negotiate_handle_info(NBDClient *client, Error **errp) return rc; } +/* Callback to learn when QIO TLS upgrade is complete */ +struct NBDTLSServerHandshakeData { + bool complete; + Error *error; + Coroutine *co; +}; + +static void nbd_server_tls_handshake(QIOTask *task, void *opaque) +{ + struct NBDTLSServerHandshakeData *data = opaque; + + qio_task_propagate_error(task, &data->error); + data->complete = true; + if (!qemu_coroutine_entered(data->co)) { + aio_co_wake(data->co); + } +} /* Handle NBD_OPT_STARTTLS. Return NULL to drop connection, or else the * new channel for all further (now-encrypted) communication. */ @@ -756,7 +773,7 @@ static QIOChannel *nbd_negotiate_handle_starttls(NBDClient *client, { QIOChannel *ioc; QIOChannelTLS *tioc; - struct NBDTLSHandshakeData data = { 0 }; + struct NBDTLSServerHandshakeData data = { 0 }; assert(client->opt == NBD_OPT_STARTTLS); @@ -777,17 +794,18 @@ static QIOChannel *nbd_negotiate_handle_starttls(NBDClient *client, qio_channel_set_name(QIO_CHANNEL(tioc), "nbd-server-tls"); trace_nbd_negotiate_handle_starttls_handshake(); - data.loop = g_main_loop_new(g_main_context_default(), FALSE); + data.co = qemu_coroutine_self(); qio_channel_tls_handshake(tioc, - nbd_tls_handshake, + nbd_server_tls_handshake, &data, NULL, NULL); if (!data.complete) { - g_main_loop_run(data.loop); + qemu_coroutine_yield(); + assert(data.complete); } - g_main_loop_unref(data.loop); + if (data.error) { object_unref(OBJECT(tioc)); error_propagate(errp, data.error);