mirror of https://github.com/xemu-project/xemu.git
python/aqmp: add AsyncProtocol.accept() method
It's a little messier than connect, because it wasn't designed to accept *precisely one* connection. Such is life. Signed-off-by: John Snow <jsnow@redhat.com> Reviewed-by: Eric Blake <eblake@redhat.com> Message-id: 20210915162955.333025-10-jsnow@redhat.com Signed-off-by: John Snow <jsnow@redhat.com>
This commit is contained in:
parent
50e533061f
commit
774c64a58d
|
@ -243,6 +243,24 @@ class AsyncProtocol(Generic[T]):
|
||||||
await self._runstate_event.wait()
|
await self._runstate_event.wait()
|
||||||
return self.runstate
|
return self.runstate
|
||||||
|
|
||||||
|
@upper_half
|
||||||
|
@require(Runstate.IDLE)
|
||||||
|
async def accept(self, address: Union[str, Tuple[str, int]],
|
||||||
|
ssl: Optional[SSLContext] = None) -> None:
|
||||||
|
"""
|
||||||
|
Accept a connection and begin processing message queues.
|
||||||
|
|
||||||
|
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
|
||||||
|
|
||||||
|
:param address:
|
||||||
|
Address to listen to; UNIX socket path or TCP address/port.
|
||||||
|
:param ssl: SSL context to use, if any.
|
||||||
|
|
||||||
|
:raise StateError: When the `Runstate` is not `IDLE`.
|
||||||
|
:raise ConnectError: If a connection could not be accepted.
|
||||||
|
"""
|
||||||
|
await self._new_session(address, ssl, accept=True)
|
||||||
|
|
||||||
@upper_half
|
@upper_half
|
||||||
@require(Runstate.IDLE)
|
@require(Runstate.IDLE)
|
||||||
async def connect(self, address: Union[str, Tuple[str, int]],
|
async def connect(self, address: Union[str, Tuple[str, int]],
|
||||||
|
@ -308,7 +326,8 @@ class AsyncProtocol(Generic[T]):
|
||||||
@upper_half
|
@upper_half
|
||||||
async def _new_session(self,
|
async def _new_session(self,
|
||||||
address: Union[str, Tuple[str, int]],
|
address: Union[str, Tuple[str, int]],
|
||||||
ssl: Optional[SSLContext] = None) -> None:
|
ssl: Optional[SSLContext] = None,
|
||||||
|
accept: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Establish a new connection and initialize the session.
|
Establish a new connection and initialize the session.
|
||||||
|
|
||||||
|
@ -317,9 +336,10 @@ class AsyncProtocol(Generic[T]):
|
||||||
to be set back to `IDLE`.
|
to be set back to `IDLE`.
|
||||||
|
|
||||||
:param address:
|
:param address:
|
||||||
Address to connect to;
|
Address to connect to/listen on;
|
||||||
UNIX socket path or TCP address/port.
|
UNIX socket path or TCP address/port.
|
||||||
:param ssl: SSL context to use, if any.
|
:param ssl: SSL context to use, if any.
|
||||||
|
:param accept: Accept a connection instead of connecting when `True`.
|
||||||
|
|
||||||
:raise ConnectError:
|
:raise ConnectError:
|
||||||
When a connection or session cannot be established.
|
When a connection or session cannot be established.
|
||||||
|
@ -333,7 +353,7 @@ class AsyncProtocol(Generic[T]):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
phase = "connection"
|
phase = "connection"
|
||||||
await self._establish_connection(address, ssl)
|
await self._establish_connection(address, ssl, accept)
|
||||||
|
|
||||||
phase = "session"
|
phase = "session"
|
||||||
await self._establish_session()
|
await self._establish_session()
|
||||||
|
@ -367,6 +387,7 @@ class AsyncProtocol(Generic[T]):
|
||||||
self,
|
self,
|
||||||
address: Union[str, Tuple[str, int]],
|
address: Union[str, Tuple[str, int]],
|
||||||
ssl: Optional[SSLContext] = None,
|
ssl: Optional[SSLContext] = None,
|
||||||
|
accept: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Establish a new connection.
|
Establish a new connection.
|
||||||
|
@ -375,6 +396,7 @@ class AsyncProtocol(Generic[T]):
|
||||||
Address to connect to/listen on;
|
Address to connect to/listen on;
|
||||||
UNIX socket path or TCP address/port.
|
UNIX socket path or TCP address/port.
|
||||||
:param ssl: SSL context to use, if any.
|
:param ssl: SSL context to use, if any.
|
||||||
|
:param accept: Accept a connection instead of connecting when `True`.
|
||||||
"""
|
"""
|
||||||
assert self.runstate == Runstate.IDLE
|
assert self.runstate == Runstate.IDLE
|
||||||
self._set_state(Runstate.CONNECTING)
|
self._set_state(Runstate.CONNECTING)
|
||||||
|
@ -384,7 +406,66 @@ class AsyncProtocol(Generic[T]):
|
||||||
# otherwise yield.
|
# otherwise yield.
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
await self._do_connect(address, ssl)
|
if accept:
|
||||||
|
await self._do_accept(address, ssl)
|
||||||
|
else:
|
||||||
|
await self._do_connect(address, ssl)
|
||||||
|
|
||||||
|
@upper_half
|
||||||
|
async def _do_accept(self, address: Union[str, Tuple[str, int]],
|
||||||
|
ssl: Optional[SSLContext] = None) -> None:
|
||||||
|
"""
|
||||||
|
Acting as the transport server, accept a single connection.
|
||||||
|
|
||||||
|
:param address:
|
||||||
|
Address to listen on; UNIX socket path or TCP address/port.
|
||||||
|
:param ssl: SSL context to use, if any.
|
||||||
|
|
||||||
|
:raise OSError: For stream-related errors.
|
||||||
|
"""
|
||||||
|
self.logger.debug("Awaiting connection on %s ...", address)
|
||||||
|
connected = asyncio.Event()
|
||||||
|
server: Optional[asyncio.AbstractServer] = None
|
||||||
|
|
||||||
|
async def _client_connected_cb(reader: asyncio.StreamReader,
|
||||||
|
writer: asyncio.StreamWriter) -> None:
|
||||||
|
"""Used to accept a single incoming connection, see below."""
|
||||||
|
nonlocal server
|
||||||
|
nonlocal connected
|
||||||
|
|
||||||
|
# A connection has been accepted; stop listening for new ones.
|
||||||
|
assert server is not None
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
server = None
|
||||||
|
|
||||||
|
# Register this client as being connected
|
||||||
|
self._reader, self._writer = (reader, writer)
|
||||||
|
|
||||||
|
# Signal back: We've accepted a client!
|
||||||
|
connected.set()
|
||||||
|
|
||||||
|
if isinstance(address, tuple):
|
||||||
|
coro = asyncio.start_server(
|
||||||
|
_client_connected_cb,
|
||||||
|
host=address[0],
|
||||||
|
port=address[1],
|
||||||
|
ssl=ssl,
|
||||||
|
backlog=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
coro = asyncio.start_unix_server(
|
||||||
|
_client_connected_cb,
|
||||||
|
path=address,
|
||||||
|
ssl=ssl,
|
||||||
|
backlog=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
server = await coro # Starts listening
|
||||||
|
await connected.wait() # Waits for the callback to fire (and finish)
|
||||||
|
assert server is None
|
||||||
|
|
||||||
|
self.logger.debug("Connection accepted.")
|
||||||
|
|
||||||
@upper_half
|
@upper_half
|
||||||
async def _do_connect(self, address: Union[str, Tuple[str, int]],
|
async def _do_connect(self, address: Union[str, Tuple[str, int]],
|
||||||
|
|
Loading…
Reference in New Issue