mirror of https://github.com/xemu-project/xemu.git
python/aqmp: add runstate state machine to AsyncProtocol
This serves a few purposes: 1. Protect interfaces when it's not safe to call them (via @require) 2. Add an interface by which an async client can determine if the state has changed, for the purposes of connection management. Signed-off-by: John Snow <jsnow@redhat.com> Reviewed-by: Eric Blake <eblake@redhat.com> Message-id: 20210915162955.333025-7-jsnow@redhat.com Signed-off-by: John Snow <jsnow@redhat.com>
This commit is contained in:
parent
4ccaab0377
commit
c58b42e095
|
@ -22,12 +22,16 @@ managing QMP events.
|
|||
# the COPYING file in the top-level directory.
|
||||
|
||||
from .error import AQMPError
|
||||
from .protocol import ConnectError
|
||||
from .protocol import ConnectError, Runstate, StateError
|
||||
|
||||
|
||||
# The order of these fields impact the Sphinx documentation order.
|
||||
__all__ = (
|
||||
# Classes
|
||||
'Runstate',
|
||||
|
||||
# Exceptions, most generic to most explicit
|
||||
'AQMPError',
|
||||
'StateError',
|
||||
'ConnectError',
|
||||
)
|
||||
|
|
|
@ -12,11 +12,10 @@ class.
|
|||
|
||||
import asyncio
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from ssl import SSLContext
|
||||
# import exceptions will be removed in a forthcoming commit.
|
||||
# The problem stems from pylint/flake8 believing that 'Any'
|
||||
# is unused because of its only use in a string-quoted type.
|
||||
from typing import ( # pylint: disable=unused-import # noqa
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
|
@ -26,6 +25,7 @@ from typing import ( # pylint: disable=unused-import # noqa
|
|||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from .error import AQMPError
|
||||
|
@ -44,6 +44,20 @@ _TaskFN = Callable[[], Awaitable[None]] # aka ``async def func() -> None``
|
|||
_FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]'])
|
||||
|
||||
|
||||
class Runstate(Enum):
|
||||
"""Protocol session runstate."""
|
||||
|
||||
#: Fully quiesced and disconnected.
|
||||
IDLE = 0
|
||||
#: In the process of connecting or establishing a session.
|
||||
CONNECTING = 1
|
||||
#: Fully connected and active session.
|
||||
RUNNING = 2
|
||||
#: In the process of disconnecting.
|
||||
#: Runstate may be returned to `IDLE` by calling `disconnect()`.
|
||||
DISCONNECTING = 3
|
||||
|
||||
|
||||
class ConnectError(AQMPError):
|
||||
"""
|
||||
Raised when the initial connection process has failed.
|
||||
|
@ -65,6 +79,76 @@ class ConnectError(AQMPError):
|
|||
return f"{self.error_message}: {self.exc!s}"
|
||||
|
||||
|
||||
class StateError(AQMPError):
|
||||
"""
|
||||
An API command (connect, execute, etc) was issued at an inappropriate time.
|
||||
|
||||
This error is raised when a command like
|
||||
:py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
|
||||
time.
|
||||
|
||||
:param error_message: Human-readable string describing the state violation.
|
||||
:param state: The actual `Runstate` seen at the time of the violation.
|
||||
:param required: The `Runstate` required to process this command.
|
||||
"""
|
||||
def __init__(self, error_message: str,
|
||||
state: Runstate, required: Runstate):
|
||||
super().__init__(error_message)
|
||||
self.error_message = error_message
|
||||
self.state = state
|
||||
self.required = required
|
||||
|
||||
|
||||
F = TypeVar('F', bound=Callable[..., Any]) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Don't Panic.
|
||||
def require(required_state: Runstate) -> Callable[[F], F]:
|
||||
"""
|
||||
Decorator: protect a method so it can only be run in a certain `Runstate`.
|
||||
|
||||
:param required_state: The `Runstate` required to invoke this method.
|
||||
:raise StateError: When the required `Runstate` is not met.
|
||||
"""
|
||||
def _decorator(func: F) -> F:
|
||||
# _decorator is the decorator that is built by calling the
|
||||
# require() decorator factory; e.g.:
|
||||
#
|
||||
# @require(Runstate.IDLE) def foo(): ...
|
||||
# will replace 'foo' with the result of '_decorator(foo)'.
|
||||
|
||||
@wraps(func)
|
||||
def _wrapper(proto: 'AsyncProtocol[Any]',
|
||||
*args: Any, **kwargs: Any) -> Any:
|
||||
# _wrapper is the function that gets executed prior to the
|
||||
# decorated method.
|
||||
|
||||
name = type(proto).__name__
|
||||
|
||||
if proto.runstate != required_state:
|
||||
if proto.runstate == Runstate.CONNECTING:
|
||||
emsg = f"{name} is currently connecting."
|
||||
elif proto.runstate == Runstate.DISCONNECTING:
|
||||
emsg = (f"{name} is disconnecting."
|
||||
" Call disconnect() to return to IDLE state.")
|
||||
elif proto.runstate == Runstate.RUNNING:
|
||||
emsg = f"{name} is already connected and running."
|
||||
elif proto.runstate == Runstate.IDLE:
|
||||
emsg = f"{name} is disconnected and idle."
|
||||
else:
|
||||
assert False
|
||||
raise StateError(emsg, proto.runstate, required_state)
|
||||
# No StateError, so call the wrapped method.
|
||||
return func(proto, *args, **kwargs)
|
||||
|
||||
# Return the decorated method;
|
||||
# Transforming Func to Decorated[Func].
|
||||
return cast(F, _wrapper)
|
||||
|
||||
# Return the decorator instance from the decorator factory. Phew!
|
||||
return _decorator
|
||||
|
||||
|
||||
class AsyncProtocol(Generic[T]):
|
||||
"""
|
||||
AsyncProtocol implements a generic async message-based protocol.
|
||||
|
@ -118,7 +202,24 @@ class AsyncProtocol(Generic[T]):
|
|||
#: exit.
|
||||
self._dc_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
self._runstate = Runstate.IDLE
|
||||
self._runstate_changed: Optional[asyncio.Event] = None
|
||||
|
||||
@property # @upper_half
|
||||
def runstate(self) -> Runstate:
|
||||
"""The current `Runstate` of the connection."""
|
||||
return self._runstate
|
||||
|
||||
@upper_half
|
||||
async def runstate_changed(self) -> Runstate:
|
||||
"""
|
||||
Wait for the `runstate` to change, then return that runstate.
|
||||
"""
|
||||
await self._runstate_event.wait()
|
||||
return self.runstate
|
||||
|
||||
@upper_half
|
||||
@require(Runstate.IDLE)
|
||||
async def connect(self, address: Union[str, Tuple[str, int]],
|
||||
ssl: Optional[SSLContext] = None) -> None:
|
||||
"""
|
||||
|
@ -152,6 +253,30 @@ class AsyncProtocol(Generic[T]):
|
|||
# Section: Session machinery
|
||||
# --------------------------
|
||||
|
||||
@property
|
||||
def _runstate_event(self) -> asyncio.Event:
|
||||
# asyncio.Event() objects should not be created prior to entrance into
|
||||
# an event loop, so we can ensure we create it in the correct context.
|
||||
# Create it on-demand *only* at the behest of an 'async def' method.
|
||||
if not self._runstate_changed:
|
||||
self._runstate_changed = asyncio.Event()
|
||||
return self._runstate_changed
|
||||
|
||||
@upper_half
|
||||
@bottom_half
|
||||
def _set_state(self, state: Runstate) -> None:
|
||||
"""
|
||||
Change the `Runstate` of the protocol connection.
|
||||
|
||||
Signals the `runstate_changed` event.
|
||||
"""
|
||||
if state == self._runstate:
|
||||
return
|
||||
|
||||
self._runstate = state
|
||||
self._runstate_event.set()
|
||||
self._runstate_event.clear()
|
||||
|
||||
@upper_half
|
||||
async def _new_session(self,
|
||||
address: Union[str, Tuple[str, int]],
|
||||
|
@ -176,6 +301,8 @@ class AsyncProtocol(Generic[T]):
|
|||
protocol-level failure occurs while establishing a new
|
||||
session, the wrapped error may also be an `AQMPError`.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
|
||||
try:
|
||||
phase = "connection"
|
||||
await self._establish_connection(address, ssl)
|
||||
|
@ -185,6 +312,7 @@ class AsyncProtocol(Generic[T]):
|
|||
|
||||
except BaseException as err:
|
||||
emsg = f"Failed to establish {phase}"
|
||||
# Reset from CONNECTING back to IDLE.
|
||||
await self.disconnect()
|
||||
|
||||
# NB: CancelledError is not a BaseException before Python 3.8
|
||||
|
@ -197,6 +325,8 @@ class AsyncProtocol(Generic[T]):
|
|||
# Raise BaseExceptions un-wrapped, they're more important.
|
||||
raise
|
||||
|
||||
assert self.runstate == Runstate.RUNNING
|
||||
|
||||
@upper_half
|
||||
async def _establish_connection(
|
||||
self,
|
||||
|
@ -211,6 +341,14 @@ class AsyncProtocol(Generic[T]):
|
|||
UNIX socket path or TCP address/port.
|
||||
:param ssl: SSL context to use, if any.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
|
||||
# Allow runstate watchers to witness 'CONNECTING' state; some
|
||||
# failures in the streaming layer are synchronous and will not
|
||||
# otherwise yield.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await self._do_connect(address, ssl)
|
||||
|
||||
@upper_half
|
||||
|
@ -240,6 +378,8 @@ class AsyncProtocol(Generic[T]):
|
|||
own negotiations here. The Runstate will be RUNNING upon
|
||||
successful conclusion.
|
||||
"""
|
||||
assert self.runstate == Runstate.CONNECTING
|
||||
|
||||
self._outgoing = asyncio.Queue()
|
||||
|
||||
reader_coro = self._bh_loop_forever(self._bh_recv_message)
|
||||
|
@ -253,6 +393,9 @@ class AsyncProtocol(Generic[T]):
|
|||
self._writer_task,
|
||||
)
|
||||
|
||||
self._set_state(Runstate.RUNNING)
|
||||
await asyncio.sleep(0) # Allow runstate_event to process
|
||||
|
||||
@upper_half
|
||||
@bottom_half
|
||||
def _schedule_disconnect(self) -> None:
|
||||
|
@ -266,6 +409,7 @@ class AsyncProtocol(Generic[T]):
|
|||
It can be invoked no matter what the `runstate` is.
|
||||
"""
|
||||
if not self._dc_task:
|
||||
self._set_state(Runstate.DISCONNECTING)
|
||||
self._dc_task = create_task(self._bh_disconnect())
|
||||
|
||||
@upper_half
|
||||
|
@ -281,6 +425,7 @@ class AsyncProtocol(Generic[T]):
|
|||
:raise Exception:
|
||||
Arbitrary exception re-raised on behalf of the reader/writer.
|
||||
"""
|
||||
assert self.runstate == Runstate.DISCONNECTING
|
||||
assert self._dc_task
|
||||
|
||||
aws: List[Awaitable[object]] = [self._dc_task]
|
||||
|
@ -295,6 +440,7 @@ class AsyncProtocol(Generic[T]):
|
|||
await all_defined_tasks # Raise Exceptions from the bottom half.
|
||||
finally:
|
||||
self._cleanup()
|
||||
self._set_state(Runstate.IDLE)
|
||||
|
||||
@upper_half
|
||||
def _cleanup(self) -> None:
|
||||
|
@ -306,6 +452,7 @@ class AsyncProtocol(Generic[T]):
|
|||
assert (task is None) or task.done()
|
||||
return None if (task and task.done()) else task
|
||||
|
||||
assert self.runstate == Runstate.DISCONNECTING
|
||||
self._dc_task = _paranoid_task_erase(self._dc_task)
|
||||
self._reader_task = _paranoid_task_erase(self._reader_task)
|
||||
self._writer_task = _paranoid_task_erase(self._writer_task)
|
||||
|
@ -314,6 +461,9 @@ class AsyncProtocol(Generic[T]):
|
|||
self._reader = None
|
||||
self._writer = None
|
||||
|
||||
# NB: _runstate_changed cannot be cleared because we still need it to
|
||||
# send the final runstate changed event ...!
|
||||
|
||||
# ----------------------------
|
||||
# Section: Bottom Half methods
|
||||
# ----------------------------
|
||||
|
@ -328,6 +478,7 @@ class AsyncProtocol(Generic[T]):
|
|||
it is free to wait on any pending actions that may still need to
|
||||
occur in either the reader or writer tasks.
|
||||
"""
|
||||
assert self.runstate == Runstate.DISCONNECTING
|
||||
|
||||
def _done(task: Optional['asyncio.Future[Any]']) -> bool:
|
||||
return task is not None and task.done()
|
||||
|
|
Loading…
Reference in New Issue