Python patches

Hopefully, fixes the race conditions witnessed through the NetBSD vm tests.
 -----BEGIN PGP SIGNATURE-----
 
 iQIzBAABCAAdFiEE+ber27ys35W+dsvQfe+BBqr8OQ4FAmImg9IACgkQfe+BBqr8
 OQ6pMxAAgilUH8OIJzJfV2C/1qWM2Hzrl/jwTUEuYxmMYacdL9kJvR3NJ4CMv5Nn
 996TyJROK+QDQoVsUuoEjkdrezbI4UDoixM9ku7KWAUMEsxXmRR5kcclSkCWX4HX
 o+My1UR+6LxPgH894JMTcnKzH9gDHkU0Aww/nu5LumJoVB12Gu1iLif/2JneQKFB
 rWaQu+8DHGH7Jv9s0ShrmkDYwtwq5XXGtefR6DEdo5xGGCjzYrYr80Frg7R1OYVU
 xlGV0MbLjTmePM5F4ZxiQGohFSOY6QsraxDMiqVOc+gBjz2J8l+7i8AA3Zirwotz
 V9BYPDRZ9pZV3ERDPqh0L3homsmk2wepkXi6YAz9/DMn0pDHizmvntPCCdhzBXyH
 cA63+QayvCYADDoHkUbMT5jc7X6ayfauj7ZkJPzfr7YtzYKs6k0bDmtgJBMyNRj1
 pHILnv5oGnnVz4kO5W98oV2jijAdqi9or3+4B2woeUmaROoQJA0ObU35ke961KNE
 n66kTOibgMj/TQmDE1veBgNvCxY0cRE+ZB7SYL7ZaqvavEwfeYQRz851sDxTdiFF
 v5b/Ls8IDKPbU8qPLDzTQrAy19CWtOkJTD4b4/6WAv9K0SAxghQEyoCUCZbk+PLt
 xGeCyxImTC7XaqFlops9WzBTK3jz/7m9EvgfJNRKj8QZ49yxCBo=
 =0ieN
 -----END PGP SIGNATURE-----

Merge remote-tracking branch 'remotes/jsnow-gitlab/tags/python-pull-request' into staging

Python patches

Hopefully, fixes the race conditions witnessed through the NetBSD vm tests.

# gpg: Signature made Mon 07 Mar 2022 22:14:42 GMT
# gpg:                using RSA key F9B7ABDBBCACDF95BE76CBD07DEF8106AAFC390E
# gpg: Good signature from "John Snow (John Huston) <jsnow@redhat.com>" [full]
# Primary key fingerprint: FAEB 9711 A12C F475 812F  18F2 88A9 064D 1835 61EB
#      Subkey fingerprint: F9B7 ABDB BCAC DF95 BE76  CBD0 7DEF 8106 AAFC 390E

* remotes/jsnow-gitlab/tags/python-pull-request:
  scripts/qmp-shell-wrap: Fix import path
  python/aqmp: drop _bind_hack()
  python/aqmp: fix race condition in legacy.py
  python/aqmp: add start_server() and accept() methods
  python/aqmp: stop the server during disconnect()
  python/aqmp: refactor _do_accept() into two distinct steps
  python/aqmp: squelch pylint warning for too many lines
  python/aqmp: split _client_connected_cb() out as _incoming()
  python/aqmp: remove _new_session and _establish_connection
  python/aqmp: rename 'accept()' to 'start_server_and_accept()'
  python/aqmp: add _session_guard()

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
This commit is contained in:
Peter Maydell 2022-03-08 19:31:05 +00:00
commit 2ad7624900
4 changed files with 272 additions and 171 deletions

View File

@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._timeout: Optional[float] = None self._timeout: Optional[float] = None
if server: if server:
self._aqmp._bind_hack(address) # pylint: disable=protected-access self._sync(self._aqmp.start_server(self._address))
_T = TypeVar('_T') _T = TypeVar('_T')
@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._aqmp.await_greeting = True self._aqmp.await_greeting = True
self._aqmp.negotiate = True self._aqmp.negotiate = True
self._sync( self._sync(self._aqmp.accept(), timeout)
self._aqmp.accept(self._address),
timeout
)
ret = self._get_greeting() ret = self._get_greeting()
assert ret is not None assert ret is not None

View File

@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient`
class. class.
""" """
# It's all the docstrings ... ! It's long for a good reason ^_^;
# pylint: disable=too-many-lines
import asyncio import asyncio
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
import logging import logging
import socket
from ssl import SSLContext from ssl import SSLContext
from typing import ( from typing import (
Any, Any,
@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]):
self._runstate = Runstate.IDLE self._runstate = Runstate.IDLE
self._runstate_changed: Optional[asyncio.Event] = None self._runstate_changed: Optional[asyncio.Event] = None
# Workaround for bind() # Server state for start_server() and _incoming()
self._sock: Optional[socket.socket] = None self._server: Optional[asyncio.AbstractServer] = None
self._accepted: Optional[asyncio.Event] = None
def __repr__(self) -> str: def __repr__(self) -> str:
cls_name = type(self).__name__ cls_name = type(self).__name__
@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]):
@upper_half @upper_half
@require(Runstate.IDLE) @require(Runstate.IDLE)
async def accept(self, address: SocketAddrT, async def start_server_and_accept(
ssl: Optional[SSLContext] = None) -> None: self, address: SocketAddrT,
ssl: Optional[SSLContext] = None
) -> None:
""" """
Accept a connection and begin processing message queues. Accept a connection and begin processing message queues.
If this call fails, `runstate` is guaranteed to be set back to `IDLE`. If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
This method is precisely equivalent to calling `start_server()`
followed by `accept()`.
:param address: :param address:
Address to listen to; UNIX socket path or TCP address/port. Address to listen on; UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any. :param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`. :raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError: If a connection could not be accepted. :raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `QMPError`.
""" """
await self._new_session(address, ssl, accept=True) await self.start_server(address, ssl)
await self.accept()
assert self.runstate == Runstate.RUNNING
@upper_half
@require(Runstate.IDLE)
async def start_server(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
"""
Start listening for an incoming connection, but do not wait for a peer.
This method starts listening for an incoming connection, but
does not block waiting for a peer. This call will return
immediately after binding and listening on a socket. A later
call to `accept()` must be made in order to finalize the
incoming connection.
:param address:
Address to listen on; UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError:
When the server could not start listening on this address.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError`.
"""
await self._session_guard(
self._do_start_server(address, ssl),
'Failed to establish connection')
assert self.runstate == Runstate.CONNECTING
@upper_half
@require(Runstate.CONNECTING)
async def accept(self) -> None:
"""
Accept an incoming connection and begin processing message queues.
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
:raise StateError: When the `Runstate` is not `CONNECTING`.
:raise QMPError: When `start_server()` was not called yet.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `QMPError`.
"""
if self._accepted is None:
raise QMPError("Cannot call accept() before start_server().")
await self._session_guard(
self._do_accept(),
'Failed to establish connection')
await self._session_guard(
self._establish_session(),
'Failed to establish session')
assert self.runstate == Runstate.RUNNING
@upper_half @upper_half
@require(Runstate.IDLE) @require(Runstate.IDLE)
@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]):
:param ssl: SSL context to use, if any. :param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`. :raise StateError: When the `Runstate` is not `IDLE`.
:raise ConnectError: If a connection cannot be made to the server. :raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `QMPError`.
""" """
await self._new_session(address, ssl) await self._session_guard(
self._do_connect(address, ssl),
'Failed to establish connection')
await self._session_guard(
self._establish_session(),
'Failed to establish session')
assert self.runstate == Runstate.RUNNING
@upper_half @upper_half
async def disconnect(self) -> None: async def disconnect(self) -> None:
@ -317,6 +401,62 @@ class AsyncProtocol(Generic[T]):
# Section: Session machinery # Section: Session machinery
# -------------------------- # --------------------------
async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
"""
Async guard function used to roll back to `IDLE` on any error.
On any Exception, the state machine will be reset back to
`IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
`BaseException` events will be left alone (This includes
asyncio.CancelledError, even prior to Python 3.8).
:param error_message:
Human-readable string describing what connection phase failed.
:raise BaseException:
When `BaseException` occurs in the guarded block.
:raise ConnectError:
When any other error is encountered in the guarded block.
"""
# Note: After Python 3.6 support is removed, this should be an
# @asynccontextmanager instead of accepting a callback.
try:
await coro
except BaseException as err:
self.logger.error("%s: %s", emsg, exception_summary(err))
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
try:
# Reset the runstate back to IDLE.
await self.disconnect()
except:
# We don't expect any Exceptions from the disconnect function
# here, because we failed to connect in the first place.
# The disconnect() function is intended to perform
# only cannot-fail cleanup here, but you never know.
emsg = (
"Unexpected bottom half exception. "
"This is a bug in the QMP library. "
"Please report it to <qemu-devel@nongnu.org> and "
"CC: John Snow <jsnow@redhat.com>."
)
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
raise
# CancelledError is an Exception with special semantic meaning;
# We do NOT want to wrap it up under ConnectError.
# NB: CancelledError is not a BaseException before Python 3.8
if isinstance(err, asyncio.CancelledError):
raise
# Any other kind of error can be treated as some kind of connection
# failure broadly. Inspect the 'exc' field to explore the root
# cause in greater detail.
if isinstance(err, Exception):
raise ConnectError(emsg, err) from err
# Raise BaseExceptions un-wrapped, they're more important.
raise
@property @property
def _runstate_event(self) -> asyncio.Event: def _runstate_event(self) -> asyncio.Event:
# asyncio.Event() objects should not be created prior to entrance into # asyncio.Event() objects should not be created prior to entrance into
@ -343,127 +483,64 @@ class AsyncProtocol(Generic[T]):
self._runstate_event.set() self._runstate_event.set()
self._runstate_event.clear() self._runstate_event.clear()
@upper_half @bottom_half
async def _new_session(self, async def _stop_server(self) -> None:
address: SocketAddrT,
ssl: Optional[SSLContext] = None,
accept: bool = False) -> None:
""" """
Establish a new connection and initialize the session. Stop listening for / accepting new incoming connections.
Connect or accept a new connection, then begin the protocol
session machinery. If this call fails, `runstate` is guaranteed
to be set back to `IDLE`.
:param address:
Address to connect to/listen on;
UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:param accept: Accept a connection instead of connecting when `True`.
:raise ConnectError:
When a connection or session cannot be established.
This exception will wrap a more concrete one. In most cases,
the wrapped exception will be `OSError` or `EOFError`. If a
protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `QMPError`.
""" """
assert self.runstate == Runstate.IDLE if self._server is None:
return
try: try:
phase = "connection" self.logger.debug("Stopping server.")
await self._establish_connection(address, ssl, accept) self._server.close()
await self._server.wait_closed()
self.logger.debug("Server stopped.")
finally:
self._server = None
phase = "session" @bottom_half # However, it does not run from the R/W tasks.
await self._establish_session() async def _incoming(self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter) -> None:
"""
Accept an incoming connection and signal the upper_half.
except BaseException as err: This method does the minimum necessary to accept a single
emsg = f"Failed to establish {phase}" incoming connection. It signals back to the upper_half ASAP so
self.logger.error("%s: %s", emsg, exception_summary(err)) that any errors during session initialization can occur
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback()) naturally in the caller's stack.
try:
# Reset from CONNECTING back to IDLE.
await self.disconnect()
except:
emsg = "Unexpected bottom half exception"
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
raise
# NB: CancelledError is not a BaseException before Python 3.8 :param reader: Incoming `asyncio.StreamReader`
if isinstance(err, asyncio.CancelledError): :param writer: Incoming `asyncio.StreamWriter`
raise """
peer = writer.get_extra_info('peername', 'Unknown peer')
self.logger.debug("Incoming connection from %s", peer)
if isinstance(err, Exception): if self._reader or self._writer:
raise ConnectError(emsg, err) from err # Sadly, we can have more than one pending connection
# because of https://bugs.python.org/issue46715
# Close any extra connections we don't actually want.
self.logger.warning("Extraneous connection inadvertently accepted")
writer.close()
return
# Raise BaseExceptions un-wrapped, they're more important. # A connection has been accepted; stop listening for new ones.
raise assert self._accepted is not None
await self._stop_server()
assert self.runstate == Runstate.RUNNING self._reader, self._writer = (reader, writer)
self._accepted.set()
@upper_half @upper_half
async def _establish_connection( async def _do_start_server(self, address: SocketAddrT,
self, ssl: Optional[SSLContext] = None) -> None:
address: SocketAddrT,
ssl: Optional[SSLContext] = None,
accept: bool = False
) -> None:
""" """
Establish a new connection. Start listening for an incoming connection, but do not wait for a peer.
:param address: This method starts listening for an incoming connection, but does not
Address to connect to/listen on; block waiting for a peer. This call will return immediately after
UNIX socket path or TCP address/port. binding and listening to a socket. A later call to accept() must be
:param ssl: SSL context to use, if any. made in order to finalize the incoming connection.
:param accept: Accept a connection instead of connecting when `True`.
"""
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
if accept:
await self._do_accept(address, ssl)
else:
await self._do_connect(address, ssl)
def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
"""
Used to create a socket in advance of accept().
This is a workaround to ensure that we can guarantee timing of
precisely when a socket exists to avoid a connection attempt
bouncing off of nothing.
Python 3.7+ adds a feature to separate the server creation and
listening phases instead, and should be used instead of this
hack.
"""
if isinstance(address, tuple):
family = socket.AF_INET
else:
family = socket.AF_UNIX
sock = socket.socket(family, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind(address)
except:
sock.close()
raise
self._sock = sock
@upper_half
async def _do_accept(self, address: SocketAddrT,
ssl: Optional[SSLContext] = None) -> None:
"""
Acting as the transport server, accept a single connection.
:param address: :param address:
Address to listen on; UNIX socket path or TCP address/port. Address to listen on; UNIX socket path or TCP address/port.
@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors. :raise OSError: For stream-related errors.
""" """
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
self.logger.debug("Awaiting connection on %s ...", address) self.logger.debug("Awaiting connection on %s ...", address)
connected = asyncio.Event() self._accepted = asyncio.Event()
server: Optional[asyncio.AbstractServer] = None
async def _client_connected_cb(reader: asyncio.StreamReader,
writer: asyncio.StreamWriter) -> None:
"""Used to accept a single incoming connection, see below."""
nonlocal server
nonlocal connected
# A connection has been accepted; stop listening for new ones.
assert server is not None
server.close()
await server.wait_closed()
server = None
# Register this client as being connected
self._reader, self._writer = (reader, writer)
# Signal back: We've accepted a client!
connected.set()
if isinstance(address, tuple): if isinstance(address, tuple):
coro = asyncio.start_server( coro = asyncio.start_server(
_client_connected_cb, self._incoming,
host=None if self._sock else address[0], host=address[0],
port=None if self._sock else address[1], port=address[1],
ssl=ssl, ssl=ssl,
backlog=1, backlog=1,
limit=self._limit, limit=self._limit,
sock=self._sock,
) )
else: else:
coro = asyncio.start_unix_server( coro = asyncio.start_unix_server(
_client_connected_cb, self._incoming,
path=None if self._sock else address, path=address,
ssl=ssl, ssl=ssl,
backlog=1, backlog=1,
limit=self._limit, limit=self._limit,
sock=self._sock,
) )
server = await coro # Starts listening # Allow runstate watchers to witness 'CONNECTING' state; some
await connected.wait() # Waits for the callback to fire (and finish) # failures in the streaming layer are synchronous and will not
assert server is None # otherwise yield.
self._sock = None await asyncio.sleep(0)
# This will start the server (bind(2), listen(2)). It will also
# call accept(2) if we yield, but we don't block on that here.
self._server = await coro
self.logger.debug("Server listening on %s", address)
@upper_half
async def _do_accept(self) -> None:
"""
Wait for and accept an incoming connection.
Requires that we have not yet accepted an incoming connection
from the upper_half, but it's OK if the server is no longer
running because the bottom_half has already accepted the
connection.
"""
assert self._accepted is not None
await self._accepted.wait()
assert self._server is None
self._accepted = None
self.logger.debug("Connection accepted.") self.logger.debug("Connection accepted.")
@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors. :raise OSError: For stream-related errors.
""" """
assert self.runstate == Runstate.IDLE
self._set_state(Runstate.CONNECTING)
# Allow runstate watchers to witness 'CONNECTING' state; some
# failures in the streaming layer are synchronous and will not
# otherwise yield.
await asyncio.sleep(0)
self.logger.debug("Connecting to %s ...", address) self.logger.debug("Connecting to %s ...", address)
if isinstance(address, tuple): if isinstance(address, tuple):
@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]):
self._reader = None self._reader = None
self._writer = None self._writer = None
self._accepted = None
# NB: _runstate_changed cannot be cleared because we still need it to # NB: _runstate_changed cannot be cleared because we still need it to
# send the final runstate changed event ...! # send the final runstate changed event ...!
@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]):
def _done(task: Optional['asyncio.Future[Any]']) -> bool: def _done(task: Optional['asyncio.Future[Any]']) -> bool:
return task is not None and task.done() return task is not None and task.done()
# If the server is running, stop it.
await self._stop_server()
# Are we already in an error pathway? If either of the tasks are # Are we already in an error pathway? If either of the tasks are
# already done, or if we have no tasks but a reader/writer; we # already done, or if we have no tasks but a reader/writer; we
# must be. # must be.

View File

@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]):
self.trigger_input = asyncio.Event() self.trigger_input = asyncio.Event()
await super()._establish_session() await super()._establish_session()
async def _do_accept(self, address, ssl=None): async def _do_start_server(self, address, ssl=None):
if not self.fake_session: if self.fake_session:
await super()._do_accept(address, ssl) self._accepted = asyncio.Event()
self._set_state(Runstate.CONNECTING)
await asyncio.sleep(0)
else:
await super()._do_start_server(address, ssl)
async def _do_accept(self):
if self.fake_session:
self._accepted = None
else:
await super()._do_accept()
async def _do_connect(self, address, ssl=None): async def _do_connect(self, address, ssl=None):
if not self.fake_session: if self.fake_session:
self._set_state(Runstate.CONNECTING)
await asyncio.sleep(0)
else:
await super()._do_connect(address, ssl) await super()._do_connect(address, ssl)
async def _do_recv(self) -> None: async def _do_recv(self) -> None:
@ -413,14 +426,14 @@ class Accept(Connect):
assert family in ('INET', 'UNIX') assert family in ('INET', 'UNIX')
if family == 'INET': if family == 'INET':
await self.proto.accept(('example.com', 1)) await self.proto.start_server_and_accept(('example.com', 1))
elif family == 'UNIX': elif family == 'UNIX':
await self.proto.accept('/dev/null') await self.proto.start_server_and_accept('/dev/null')
async def _hanging_connection(self): async def _hanging_connection(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir: with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
await self.proto.accept(sock) await self.proto.start_server_and_accept(sock)
class FakeSession(TestBase): class FakeSession(TestBase):
@ -449,13 +462,13 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testFakeAccept(self): async def testFakeAccept(self):
"""Test the full state lifecycle (via accept) with a no-op session.""" """Test the full state lifecycle (via accept) with a no-op session."""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
self.assertEqual(self.proto.runstate, Runstate.RUNNING) self.assertEqual(self.proto.runstate, Runstate.RUNNING)
@TestBase.async_test @TestBase.async_test
async def testFakeRecv(self): async def testFakeRecv(self):
"""Test receiving a fake/null message.""" """Test receiving a fake/null message."""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context: with self.assertLogs(logname, level='DEBUG') as context:
@ -471,7 +484,7 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testFakeSend(self): async def testFakeSend(self):
"""Test sending a fake/null message.""" """Test sending a fake/null message."""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context: with self.assertLogs(logname, level='DEBUG') as context:
@ -493,7 +506,7 @@ class FakeSession(TestBase):
): ):
with self.assertRaises(StateError) as context: with self.assertRaises(StateError) as context:
if accept: if accept:
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
else: else:
await self.proto.connect('/not/a/real/path') await self.proto.connect('/not/a/real/path')
@ -504,7 +517,7 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testAcceptRequireRunning(self): async def testAcceptRequireRunning(self):
"""Test that accept() cannot be called when Runstate=RUNNING""" """Test that accept() cannot be called when Runstate=RUNNING"""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api( await self._prod_session_api(
Runstate.RUNNING, Runstate.RUNNING,
@ -515,7 +528,7 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testConnectRequireRunning(self): async def testConnectRequireRunning(self):
"""Test that connect() cannot be called when Runstate=RUNNING""" """Test that connect() cannot be called when Runstate=RUNNING"""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api( await self._prod_session_api(
Runstate.RUNNING, Runstate.RUNNING,
@ -526,7 +539,7 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testAcceptRequireDisconnecting(self): async def testAcceptRequireDisconnecting(self):
"""Test that accept() cannot be called when Runstate=DISCONNECTING""" """Test that accept() cannot be called when Runstate=DISCONNECTING"""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect. # Cheat: force a disconnect.
await self.proto.simulate_disconnect() await self.proto.simulate_disconnect()
@ -541,7 +554,7 @@ class FakeSession(TestBase):
@TestBase.async_test @TestBase.async_test
async def testConnectRequireDisconnecting(self): async def testConnectRequireDisconnecting(self):
"""Test that connect() cannot be called when Runstate=DISCONNECTING""" """Test that connect() cannot be called when Runstate=DISCONNECTING"""
await self.proto.accept('/not/a/real/path') await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect. # Cheat: force a disconnect.
await self.proto.simulate_disconnect() await self.proto.simulate_disconnect()
@ -576,7 +589,7 @@ class SimpleSession(TestBase):
async def testSmoke(self): async def testSmoke(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir: with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
server_task = create_task(self.server.accept(sock)) server_task = create_task(self.server.start_server_and_accept(sock))
# give the server a chance to start listening [...] # give the server a chance to start listening [...]
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -4,7 +4,7 @@ import os
import sys import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python')) sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python'))
from qemu.qmp import qmp_shell from qemu.aqmp import qmp_shell
if __name__ == '__main__': if __name__ == '__main__':