"""
This module is adapted from https://github.com/Mause/rpc.
The task process only has access to a wrapper over the real device object. The wrapper
redirects all method calls to the real device object via RabbitMQ. The real device object
lives in the DeviceManager class, which handles all requests to run methods on the real device.
"""
import time
from collections.abc import Callable
from concurrent.futures import Future
from enum import Enum, auto
from functools import partial
from threading import Thread
from typing import Any, NoReturn, cast
from uuid import uuid4
import dill
import pika
from bson import ObjectId
from pika import BasicProperties
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic
from .config import AlabOSConfig
from .device_view.device_view import DeviceTaskStatus, DeviceView
from .utils.data_objects import get_rabbitmq_connection
from .utils.module_ops import load_definition
# Suffix appended to the lab name to form the device-manager (server) RPC queue name.
DEFAULT_SERVER_QUEUE_SUFFIX = ".device_rpc"
# Suffix for client-side reply queues, derived from the server suffix.
DEFAULT_CLIENT_QUEUE_SUFFIX = f"{DEFAULT_SERVER_QUEUE_SUFFIX}.reply_to"
class MethodCallStatus(Enum):
    """Lifecycle state of a remote device method call."""

    # Explicit values; identical to what ``auto()`` assigns on a plain Enum.
    PENDING = 1
    IN_PROGRESS = 2
    SUCCESS = 3
    FAILURE = 4
class DeviceMethodCallState:
    """Holds the status of a pending method call to a device.

    Only annotations are declared here; no defaults and no ``__init__`` are
    provided, so callers must set each attribute themselves.
    """

    # Current lifecycle state of the call.
    status: MethodCallStatus
    # Future for the call's outcome.
    future: Future
    # Time of the last status change (presumably a ``time.time()`` value - confirm with callers).
    last_updated: float
class DeviceWrapper:
    """A client-side proxy for a device.

    Any attribute that is not found by normal lookup is turned into a
    :class:`DeviceWrapper.DeviceMethodWrapper` whose call is forwarded to
    ``devices_client.call(device_name, method, *args, **kwargs)``.
    """

    class DeviceMethodWrapper:
        """A callable proxy for a single method of a remote device."""

        def __init__(self, device_name: str, method: str, method_handler: Callable):
            self._device_name = device_name
            self._method: str = method
            self._method_handler = method_handler

        @property
        def method(self) -> str:
            """The name of the method."""
            return self._method

        def __call__(self, *args, **kwargs):
            """Call the method by delegating to the method handler."""
            return self._method_handler(*args, **kwargs)

        def __repr__(self) -> str:
            """Return the representation of the method."""
            return f"<method {self._device_name}.{self._method}>"

        def _raise(self, *args, **kwargs) -> NoReturn:  # pylint: disable=no-self-use
            """Raise an error: a method proxy only supports being called."""
            raise AttributeError(
                "This is a class method, you cannot use it as an attribute."
            )

        __str__ = __repr__
        # Attribute access, indexing, arithmetic and comparisons all raise, presumably
        # so that using the proxy without calling it fails loudly. Because ``__eq__``
        # is defined without ``__hash__``, instances are also unhashable.
        __len__ = __getattr__ = __getitem__ = __add__ = __sub__ = __eq__ = __lt__ = (
            __gt__
        ) = _raise

    def __init__(self, name: str, devices_client: "DevicesClient"):
        self._name = name
        self._devices_client = devices_client

    @property
    def name(self) -> str:
        """The name of the device."""
        return self._name

    def __getattr__(self, method: str):
        """Return a proxy that forwards calls of ``method`` to the device.

        Raises:
            AttributeError: for dunder names and for this wrapper's own private
                attributes, which are never remote device methods.
        """
        # ``__getattr__`` only runs when normal lookup fails. Dunder names are probed
        # by copy/pickle/hasattr (e.g. ``__setstate__``, ``__deepcopy__``); proxying
        # them would fake attributes (or send bogus RPCs). On a half-constructed
        # instance (no ``_name`` yet), it would also recurse forever through ``self.name``.
        if method in ("_name", "_devices_client") or (
            method.startswith("__") and method.endswith("__")
        ):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {method!r}"
            )
        return self.DeviceMethodWrapper(
            device_name=self.name,
            method=method,
            method_handler=partial(self._devices_client.call, self.name, method),
        )
class DeviceManager:
    """
    Device manager is basically a rabbitmq-backed RPC server, which receives and
    executes commands on the device drivers, as requested by the task process.
    """

    def __init__(self, _check_status: bool = True):
        """
        Args:
            _check_status: Check if the task currently occupied this device when
                running commands. (disable it only for test purpose).
        """
        load_definition()
        self.sim_mode_flag = AlabOSConfig().is_sim_mode()
        if self.sim_mode_flag:
            self._rpc_queue_name = (
                AlabOSConfig()["general"]["name"] + "_sim" + DEFAULT_SERVER_QUEUE_SUFFIX
            )
        else:
            self._rpc_queue_name = (
                AlabOSConfig()["general"]["name"] + DEFAULT_SERVER_QUEUE_SUFFIX
            )
        self._device_view = DeviceView(connect_to_devices=True)
        self._check_status = _check_status
        # Worker threads started by ``on_message``. Finished threads are pruned each
        # time a new one starts, so the list does not grow without bound.
        self.threads: list[Thread] = []

    def run(self):
        """Consume requests from the device_rpc queue until consuming stops.

        Each request is executed on its own worker thread (see ``on_message``),
        so commands for different messages may run concurrently.
        """
        self.connection = get_rabbitmq_connection()
        with self.connection.channel() as channel:
            channel.queue_declare(
                queue=self._rpc_queue_name,
                auto_delete=True,
                exclusive=False,
            )
            channel.basic_consume(
                queue=self._rpc_queue_name,
                on_message_callback=self.on_message,
                auto_ack=False,
                consumer_tag=self._rpc_queue_name,
            )
            channel.start_consuming()

    @staticmethod
    def _serialize_response(response: dict[str, Any]) -> bytes:
        """Serialize a response with dill, substituting a picklable failure if that fails.

        Device results or exceptions may hold unpicklable objects. Without this
        fallback the reply could never be sent, and the caller would wait forever.
        """
        try:
            return dill.dumps(response)
        except Exception as e:  # pylint: disable=broad-except
            return dill.dumps(
                {
                    "status": "failure",
                    "result": RuntimeError(
                        f"Could not serialize the {response['status']} response "
                        f"(result of type {type(response['result']).__name__}): {e!r}"
                    ),
                }
            )

    @staticmethod
    def _reply(channel, delivery_tag, props, body: bytes):
        """Publish ``body`` to the caller's reply queue (if any) and ack the request.

        pika channels are not thread-safe, so this is only invoked on the
        connection's thread (in a consumer callback or via ``add_callback_threadsafe``).
        """
        if props.reply_to is not None:
            channel.basic_publish(
                exchange="",
                routing_key=props.reply_to,
                properties=pika.BasicProperties(
                    correlation_id=props.correlation_id,
                    content_type="application/python-dill",
                ),
                body=body,
            )
        channel.basic_ack(delivery_tag=cast(int, delivery_tag))

    def _check_device_access(self, device: str, task_id: str) -> None:
        """Raise ``PermissionError`` unless ``device`` is OCCUPIED by ``task_id``.

        If the device exists but is not OCCUPIED yet, its status is re-read once
        per second for up to 5 seconds before giving up.
        """
        device_entry: dict[str, Any] | None = self._device_view.get_device(device)
        if device_entry is None:
            raise PermissionError("There is no such device in the device view.")
        if device_entry["status"] != DeviceTaskStatus.OCCUPIED.name:
            # Wait a few seconds for the device to be OCCUPIED.
            for _ in range(5):
                time.sleep(1)
                device_entry = self._device_view.get_device(device)
                if device_entry is None:
                    raise PermissionError("There is no such device in the device view.")
                if device_entry["status"] == DeviceTaskStatus.OCCUPIED.name:
                    break
            if device_entry["status"] != DeviceTaskStatus.OCCUPIED.name:
                raise PermissionError(
                    f"Currently the device ({device}) is NOT OCCUPIED, it is currently in status {device_entry['status']}"
                )
        if device_entry["task_id"] != ObjectId(task_id):
            device_task_id = str(device_entry["task_id"])
            raise PermissionError(
                f"Currently the task ({task_id}) "
                f"does not occupy this device: {device}, which is currently occupied by task {device_task_id}"
            )

    def _execute_command(
        self, channel, delivery_tag, props, device, method, task_id, args, kwargs
    ) -> None:
        """Run one device command and schedule the reply. Runs on a worker thread.

        Any exception, including a failed ownership check, is returned to the caller
        as a ``{"status": "failure", "result": exc}`` response. The reply is
        serialized here, then published and acked on the connection's thread.
        """
        try:
            if self._check_status:
                self._check_device_access(device, task_id)
            result = self._device_view.execute_command(device, method, *args, **kwargs)
            response = {"status": "success", "result": result}
        except Exception as e:  # pylint: disable=broad-except
            response = {"status": "failure", "result": e}
        body = self._serialize_response(response)
        self.connection.add_callback_threadsafe(
            partial(self._reply, channel, delivery_tag, props, body)
        )

    def _execute_command_wrapper(
        self,
        channel,
        delivery_tag,
        props,
        device,
        method,
        task_id,
        *args,
        **kwargs,
    ):
        """Execute a command on the device. Acknowledges completion on rabbitmq channel.

        Kept for backward compatibility. ``on_message`` calls ``_execute_command``
        directly, so that device kwargs named e.g. ``device`` or ``method`` cannot
        collide with this method's own parameters.
        """
        self._execute_command(
            channel, delivery_tag, props, device, method, task_id, args, kwargs
        )

    def on_message(
        self,
        channel: BlockingChannel,
        method: Basic.Deliver,
        props: BasicProperties,
        _body: bytes,
    ):
        """
        Handle one command message by executing it on a new worker thread.

        The structure of ``_body`` (dill-serialized):

        .. code-block::

            {
                "task_id": str,
                "device": str,
                "method": str,
                "args": List,
                "kwargs": Dict,
            }

        A message that cannot be decoded is answered with a failure response and
        acknowledged immediately. The exception is not allowed to escape
        ``start_consuming``, which would stop the server.
        """
        # NOTE(security): ``dill.loads`` can execute arbitrary code; only trusted
        # producers must be able to publish to this queue.
        try:
            body: dict[str, Any] = dill.loads(_body)
            device = body["device"]
            device_method = body["method"]
            task_id = body["task_id"]
            args = tuple(body["args"])
            kwargs = body["kwargs"] or {}
        except Exception as e:  # pylint: disable=broad-except
            # We are on the connection's thread here, so replying directly is safe.
            self._reply(
                channel,
                method.delivery_tag,
                props,
                self._serialize_response({"status": "failure", "result": e}),
            )
            return

        # Drop finished workers in place (keeps any outside reference to the list valid).
        self.threads[:] = [t for t in self.threads if t.is_alive()]
        thread = Thread(
            target=self._execute_command,
            args=(
                channel,
                method.delivery_tag,
                props,
                device,
                device_method,
                task_id,
                args,
                kwargs,
            ),
        )
        self.threads.append(thread)
        thread.start()
class DevicesClient:  # pylint: disable=too-many-instance-attributes
    """
    A rabbitmq-backed RPC client for sending device requests to the Device Manager (server).

    Use ``create_device_wrapper`` to create Device Wrapper instance.
    """

    def __init__(self, task_id: ObjectId, timeout: float | None = None):
        """
        Args:
            task_id: the task id of current task process
            timeout: the max time (in seconds) to wait for the server to respond to a
                single :py:meth:`call`. If the time exceeds the max time, a
                :py:class:`TimeoutError <concurrent.futures.TimeoutError>` is raised.
                ``None`` (the default) waits indefinitely.
        """
        assert task_id is not None, "task_id cannot be None!"
        self.sim_mode_flag = AlabOSConfig().is_sim_mode()
        if self.sim_mode_flag:
            self._rpc_queue_name = (
                AlabOSConfig()["general"]["name"] + "_sim" + DEFAULT_SERVER_QUEUE_SUFFIX
            )
        else:
            self._rpc_queue_name = (
                AlabOSConfig()["general"]["name"] + DEFAULT_SERVER_QUEUE_SUFFIX
            )
        # Each client gets its own randomly named reply queue. It is declared with
        # auto_delete=True, so the broker removes it once its last consumer is gone.
        self._rpc_reply_queue_name = str(uuid4()) + DEFAULT_CLIENT_QUEUE_SUFFIX
        self._task_id = task_id
        self._timeout = timeout
        # Futures of in-flight calls, keyed by the correlation id sent with the request.
        self._waiting: dict[ObjectId, Future] = {}
        self._conn = get_rabbitmq_connection()
        self._channel = self._conn.channel()
        self._channel.queue_declare(
            self._rpc_reply_queue_name, exclusive=False, auto_delete=True
        )
        self._channel.basic_consume(
            queue=self._rpc_reply_queue_name,
            on_message_callback=self.on_message,
            auto_ack=True,
        )
        # Replies are consumed on a daemon thread; callers block on their futures.
        self._thread: Thread | None = Thread(
            target=self._channel.start_consuming, daemon=True
        )
        self._thread.start()

    def __getitem__(self, device_name: str):
        """Get the device wrapper."""
        return self.create_device_wrapper(device_name=device_name)

    def create_device_wrapper(
        self, device_name: str
    ) -> "DeviceWrapper":  # pylint: disable=no-self-use
        """
        Create a wrapper over a device with ``device_name``.

        Args:
            device_name: the name of device to be wrapped

        Returns
        -------
        A device wrapper that will send every call to class method to remote server.
        """
        return DeviceWrapper(name=device_name, devices_client=self)

    def call(self, device_name: str, method: str, *args, **kwargs) -> Any:
        """
        Call a method inside the device with name ``device_name``. args, kwargs will be fed into
        the method directly.

        Args:
            device_name: the name of device, which is defined by administer.
            method: the class method to call
            args: positional arguments to feed into the method function
            kwargs: keyword arguments to feed into the method function

        Returns
        -------
        the result of function

        Raises
        ------
        TimeoutError: if no reply arrives within the client's ``timeout``.
        Any exception raised by the remote method (re-raised locally).
        """
        assert self._conn and self._channel
        # Serialize in the caller's thread. An unpicklable argument then raises here,
        # instead of inside the connection's I/O thread, where it could stop the
        # reply consumer and leave this call waiting forever.
        body = dill.dumps(
            {
                "device": device_name,
                "method": method,
                "args": args,
                "kwargs": kwargs,
                "task_id": str(self._task_id),
            }
        )
        correlation_id = ObjectId()
        properties = BasicProperties(
            reply_to=self._rpc_reply_queue_name,
            content_type="application/python-dill",
            correlation_id=str(correlation_id),
        )
        f: Future = Future()
        self._waiting[correlation_id] = f
        # pika connections are not thread-safe: publish from the connection's thread.
        self._conn.add_callback_threadsafe(
            partial(
                self._channel.basic_publish,
                exchange="",
                routing_key=self._rpc_queue_name,
                body=body,
                properties=properties,
            )
        )
        try:
            return f.result(timeout=self._timeout)
        finally:
            # On success ``on_message`` already removed the entry. After a timeout,
            # removing it makes ``on_message`` ignore a late reply.
            self._waiting.pop(correlation_id, None)

    def on_message(
        self,
        channel: BlockingChannel,
        method_frame: Basic.Deliver,  # pylint: disable=unused-argument
        properties: BasicProperties,
        _body: bytes,
    ):
        """Callback function to handle a returned message from Device Manager.

        Replies whose correlation id is malformed or does not match an in-flight
        call (e.g. one that already timed out) are ignored. Raising here would stop
        the consumer thread and break every later call.
        """
        try:
            correlation_id = ObjectId(properties.correlation_id)
        except Exception:  # pylint: disable=broad-except
            return
        f = self._waiting.pop(correlation_id, None)
        if f is None:
            return
        try:
            body = dill.loads(_body)
            if body["status"] == "success":
                f.set_result(body["result"])
            else:
                f.set_exception(body["result"])
        except Exception as e:  # pylint: disable=broad-except
            f.set_exception(e)