Source code for alab_management.device_view.device

"""Define the base class of devices."""

import datetime
import functools
import threading
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from queue import Empty, PriorityQueue
from traceback import format_exc
from typing import Any
from unittest.mock import Mock

from alab_management.logger import DBLogger
from alab_management.sample_view.sample import SamplePosition
from alab_management.user_input import request_maintenance_input

from .dbattributes import DictInDatabase, ListInDatabase


def _UNSPECIFIED(_):
    return None


def mock(
    return_constant: Any = _UNSPECIFIED,
    object_type: list[Any] | Any = _UNSPECIFIED,
):
    """
    A decorator used for mocking functions during simulation.

    When AlabOS is not in simulation mode the decorated function is called
    unchanged. In simulation mode the function body is skipped and a canned
    value is returned instead.

    Args:
        return_constant (Any, optional): The constant value to be returned by the
            mocked function. A dict or a list is returned as a shallow copy, so the
            caller cannot mutate the value captured by the decorator; any other
            value is returned as is.
        object_type (Union[List[Any], Any], optional): The type, or list of types, to
            mock if the function returns an object. A single type yields
            ``Mock(spec=object_type)``; a list yields one ``Mock`` per entry.

    Returns
    -------
    Decorator function used to mock other functions during simulation.

    Raises
    ------
    ValueError: When the decorated function is called in simulation mode and both
        ``return_constant`` and ``object_type`` were specified.
    ValueError: When the decorated function is called in simulation mode and
        neither ``return_constant`` nor ``object_type`` was specified.

    Examples
    --------
    1. Mocking a function with a constant return value:

    .. code-block:: python

        @mock(return_constant=42)
        def get_data() -> int:
            ...

    2. Mocking a function that returns multiple values in a dictionary:

    .. code-block:: python

        @mock(return_constant={"twotheta": [0.1, 0.2, 0.3], "counts": [100, 200, 300]})
        def run_simulation() -> dict:
            ...

    3. Mocking a function that returns a single object type:

    .. code-block:: python

        from alab_control.ohaus_scale import OhausScale as ScaleDriver

        @mock(object_type=ScaleDriver)
        def get_driver(self):
            self.driver = ScaleDriver(ip=self.ip_address, timeout=self.TIMEOUT)
            self.driver.set_unit_to_mg()
            return self.driver

    4. Mocking a function that returns a list of object types:

    .. code-block:: python

        from alab_control.furnace_2416 import FurnaceController
        from alab_control.door_controller import DoorController

        @mock(object_type=[FurnaceController, DoorController])
        def get_driver(self):
            self.driver = FurnaceController(port=self.com_port)
            return self.driver, self.door_controller
    """

    def decorator(f: Callable[..., Any]):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Imported lazily, as in the rest of this module's function-scope imports.
            from alab_management.config import AlabOSConfig

            if not AlabOSConfig().is_sim_mode():
                return f(*args, **kwargs)

            has_constant = return_constant is not _UNSPECIFIED
            has_object_type = object_type is not _UNSPECIFIED

            if has_constant and has_object_type:
                raise ValueError(
                    "Cannot specify both return_constant and object_type!"
                )

            if has_constant:
                if isinstance(return_constant, dict):
                    return dict(return_constant)
                if isinstance(return_constant, list):
                    # Return the constant list itself (copied), not its indices.
                    return list(return_constant)
                return return_constant

            if has_object_type:
                if isinstance(object_type, list):
                    return [Mock(spec=cls) for cls in object_type]
                return Mock(spec=object_type)

            raise ValueError(
                "Must specify either return_constant or object_type! return_constant should "
                "be of the type str, int, float, bool, list or dict. "
                "object_type should be any other class that you want to mock."
            )

        return wrapper

    return decorator
# Base Device class #
class BaseDevice(ABC):
    """
    Abstract base class that every device must inherit from.

    Attributes
    ----------
    name (str): the name of device, which is the unique identifier of this device
    description (Optional[str]): description of this kind of device, which can include
        the device type, how to set up and so on.
    args: arguments that will be passed to the device class
    kwargs: keyword arguments that will be passed to the device class
    """

    def __init__(self, name: str, description: str | None = None, *args, **kwargs):
        """
        Initialize a device object; a connection to the device may be set up here.

        The device will only be initialized once in the system. If the connection to
        the driver needs to be renewed from time to time, write a custom function that
        connects to the device when needed.

        Args:
            name: the name of device, which is the unique identifier of this device
            description: description of this kind of device

        Raises
        ------
        TypeError: if the resulting ``description`` is not a string.

        Here is an example of how to write a new device

        .. code-block:: python

            def __init__(self, address: str, port: int = 502, *args, **kwargs):
                super(Furnace, self).__init__(*args, **kwargs)
                self.address = address
                self.port = port
                self.driver = FurnaceController(address=address, port=port)
        """
        # A description passed at instantiation overrides the class-level default.
        if description:
            self.description = description  # TODO: change this
        self.name = name

        if not isinstance(self.description, str):
            raise TypeError("description must be a string")

        from alab_management.device_view import DeviceView

        self._device_view = DeviceView()
        # Periodically logs any device methods that are decorated with @log_signal.
        self._signalemitter = DeviceSignalEmitter(device=self)

    @property
    @abstractmethod
    def description(self) -> str:
        """
        A short description of the device.

        This will be stored in the database + displayed in the dashboard. This must be
        declared in subclasses of BaseDevice!
        """
        return self._description

    @description.setter
    def description(self, value: str) -> None:
        if isinstance(value, str):
            self._description = value
        else:
            raise TypeError("description must be a string")

    def set_message(self, message: str):
        """
        Set the device message to be displayed on the dashboard.

        Note: a method is used instead of a python getter/setter because the
        DeviceWrapper can currently only access methods, not properties.
        """
        self._device_view.set_message(device_name=self.name, message=message)
        self.__message = message

    def get_message(self) -> str:
        """
        Return the device message to be displayed on the dashboard.

        The value is fetched from the device view and cached on the instance.

        Note: a method is used instead of a python getter/setter because the
        DeviceWrapper can currently only access methods, not properties.
        """
        latest_message = self._device_view.get_message(device_name=self.name)
        self.__message = latest_message
        return latest_message

    def _connect_wrapper(self):
        """
        Connect to the device, then run the backend actions that need alabos running.

        Note that devices only connect within `alabos launch`, so ExperimentManager,
        DeviceManager, and the API are guaranteed to be running when this is called.
        """
        self.connect()
        # retrieve the most recent message from the database.
        self.get_message()
        # start the signal emitter thread
        self._signalemitter.start()

    @abstractmethod
    def connect(self):
        """
        Connect to any devices here.

        This will be called by alabos to make connections to devices at the
        appropriate time. This method must be defined even if no device connections
        are required! Just return in this case.
        """
        raise NotImplementedError()

    def _disconnect_wrapper(self):
        """Disconnect from the device, then stop the signal emitter."""
        self.disconnect()
        self._signalemitter.stop()

    @abstractmethod
    def disconnect(self):
        """
        Disconnect from devices here.

        This will be called by alabos to release connections to devices at the
        appropriate time. This method must be defined even if no device connections
        are required! Just return in this case.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def sample_positions(self) -> list[SamplePosition]:
        """
        The sample positions describe the positions that can hold a sample.

        The name of a sample position is its unique identifier. It does not store any
        coordinate information about where the position is in the lab. Users need to
        map the sample positions to real lab coordinates manually.

        .. note::
            It doesn't matter in which device class a sample position is defined.
            We use the ``name`` attribute to identify them.

        Here is an example of how to define some sample positions

        .. code-block:: python

            @property
            def sample_positions(self):
                return [
                    SamplePosition(
                        "inside",
                        description="The position inside the furnace, where the samples are heated",
                        number=8,
                    ),
                    SamplePosition(
                        "furnace_table",
                        description="Temporary position to transfer samples",
                        number=16,
                    ),
                ]
        """
        raise NotImplementedError()

    @abstractmethod
    def is_running(self) -> bool:
        """Check whether this device is running."""
        raise NotImplementedError()

    # methods to store Device values inside the database. Lists and dictionaries are supported.

    def list_in_database(
        self, name: str, default_value: list | None | None = None
    ) -> ListInDatabase:
        """
        Create a list attribute that is stored in the database.

        Note: nested dicts/lists are not supported!

        Args:
            name: The name of the attribute
            default_value: The default value of the attribute. if None (default),
                will default to an empty list.

        Returns
        -------
        Class instance to access the attribute. Acts like a normal List, but is
        stored in the database.
        """
        collection = self._device_view._device_collection
        return ListInDatabase(
            device_collection=collection,
            device_name=self.name,
            attribute_name=name,
            default_value=default_value,
        )

    def dict_in_database(
        self, name: str, default_value: dict | None | None = None
    ) -> DictInDatabase:
        """
        Create a dict attribute that is stored in the database.

        Note: nested dicts/lists are not supported!

        Args:
            name: The name of the attribute
            default_value: The default value of the attribute. if None (default),
                will default to an empty dict.

        Returns
        -------
        Class instance to access the attribute. Acts like a normal Dict, but is
        stored in the database.
        """
        collection = self._device_view._device_collection
        return DictInDatabase(
            device_collection=collection,
            device_name=self.name,
            attribute_name=name,
            default_value=default_value,
        )

    def _apply_default_db_values(self):
        """
        Apply default values to attributes that are stored in the database.

        This is called when the device is first added to the database, typically only
        when alabos is setting up a new lab.
        """
        db_backed_types = (ListInDatabase, DictInDatabase)
        for attribute_name in dir(self):
            candidate = getattr(self, attribute_name)
            if isinstance(candidate, db_backed_types):
                candidate.apply_default_value()

    def request_maintenance(self, prompt: str, options: list[Any]):
        """
        Request maintenance input from the user.

        This will display a prompt to the user and wait for them to select an option.
        The selected option will be returned.

        Args:
            prompt: the text to display to the user
            options: the options to display to the user. This should be a list of
                strings.
        """
        return request_maintenance_input(prompt=prompt, options=options)

    def retrieve_signal(
        self, signal_name: str, within: datetime.timedelta | None = None
    ):
        """
        Retrieve a signal from the database (delegates to the signal emitter).

        Args:
            signal_name (str): device signal name. This should match the signal_name
                passed to the ``@log_signal`` decorator
            within (Optional[datetime.timedelta], optional): timedelta defining how far
                back to pull logs from (relative to current time). Defaults to None.

        Returns
        -------
        Dict: Dictionary of signal result. Single value vs lists depends on whether
        ``within`` was None or not, respectively. Form is:

        .. code-block::

            {
                "device_name": "device_name",
                "signal_name": "signal_name",
                "value": "signal_value" or ["signal_value_1", "signal_value_2", ...],
                "timestamp": "timestamp" or ["timestamp_1", "timestamp_2", ...]
            }
        """
        emitter = self._signalemitter
        return emitter.retrieve_signal(signal_name, within)
# DeviceSignalEmitter and related decorator #
def log_signal(signal_name: str, interval_seconds: int):
    """
    Decorator for methods within a `BaseDevice` whose result should be logged.

    Methods decorated with this will be called at the specified interval and the
    result will be logged to the database under the `signal_name` provided. The
    intended use is to track process variables (like a furnace temperature, a
    pressure sensor, etc.) whenever the device is connected to alabos.

    The decorator itself only tags the method with ``logging_interval_seconds`` and
    ``signal_name`` attributes; calling the decorated method directly behaves exactly
    like calling the original. The original method's metadata (``__name__``,
    ``__doc__``, ...) is preserved on the returned function.

    Args:
        signal_name (str): Name to attribute to this signal
        interval_seconds (int): Interval at which to log this signal to the database.

    Returns
    -------
    Decorator that wraps the method and attaches the logging attributes.
    """

    def wrapper(func):
        @functools.wraps(func)
        def wrapper_func(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        # Set after functools.wraps so these always win over any copied attributes.
        wrapper_func.logging_interval_seconds = interval_seconds
        wrapper_func.signal_name = signal_name
        return wrapper_func

    return wrapper
[docs] class DeviceSignalEmitter: """ This class is responsible for periodically logging device signals to the database. It is intended to be used as a singleton, and should be instantiated once per device. """ def __init__(self, device: BaseDevice): from alab_management.device_view import DeviceView self._device_view: DeviceView = DeviceView() self.dblogger = DBLogger(task_id=None) self.device = device self.is_logging = False self.queue: PriorityQueue = PriorityQueue() self._logging_thread: threading.Thread | None = None self._start_time: datetime.datetime | None = None
[docs] def get_methods_to_log(self): """ Log the data from all methods decorated with ``@log_signal`` to the database. Collected all the methods that are decorated with ``@log_signal`` and return a dictionary of the form: .. code-block:: { <method_name>: { "interval": <interval_seconds>, "signal_name": <signal_name> } } """ methods_to_log = {} for method_name, method in self.device.__class__.__dict__.items(): if hasattr(method, "logging_interval_seconds") and callable(method): methods_to_log[method_name] = { "interval": method.logging_interval_seconds, "signal_name": method.signal_name, } return methods_to_log
def _worker(self): """This is the worker thread that will periodically log device signals to the database.""" def wait_with_option_to_kill(time_to_wait: float): """ Waits until the next log is due, or until the logging is stopped. When stopping the logging worker, the worker will complete its current logging task before stopping the worker thread. This may result in long blockages if the logging interval is long. This function will allow the worker to periodically check if it should stop mid-wait to avoid this issue. Args: time_to_wait (float): The total time to wait, assuming the logging worker is not stopped. """ if time_to_wait < 0: # we are behind schedule, don't wait at all! return total_time_to_wait = time_to_wait while total_time_to_wait > 0: if not self.is_logging: return time_to_wait = min( total_time_to_wait, 0.2 ) # we will wait 0.2 second at a time total_time_to_wait -= time_to_wait time.sleep(time_to_wait) if len(self.get_methods_to_log()) == 0: return # no need to run if we aren't logging any methods while True: # we check if logging is active within the loop. This is to allow the logging worker # to be stopped mid-wait if necessary. Prevents us blocking a `.stop()` call if stuck # waiting to log method on a long interval. try: log_at, method_name, signal_name, interval, count = self.queue.get( block=False ) except Empty: # wait for queue to refill. We shouldn't reach this under normal circumstances time.sleep(1) continue time_until_this_log = (log_at - datetime.datetime.now()).total_seconds() wait_with_option_to_kill(time_until_this_log) if not self.is_logging: break self.log_method_to_db(method_name=method_name, signal_name=signal_name) count += 1 next_log_at = self._start_time + datetime.timedelta( seconds=interval * count ) self.queue.put((next_log_at, method_name, signal_name, interval, count))
[docs] def log_method_to_db(self, method_name: str, signal_name: str): """ Logs a method to the database. This is called by the worker thread. Args: method_name: the name of the method to call on the device signal_name: the name of the signal to log to the database Exceptions: Any exceptions raised by the method call will be caught and raised directly. """ method = getattr(self.device, method_name) try: value = method() except Exception: value = ( f"Error reading {method_name} from device {self.device.name}." f"The error message is: " f"{format_exc()}" ) self.dblogger.log_device_signal( device_name=self.device.name, signal_name=signal_name, signal_value=value, )
[docs] def start(self): """ Start the logging worker thread. This will start logging all methods decorated with `@log_signal` to the database. """ self.queue = PriorityQueue() for method_name, logging_properties in self.get_methods_to_log().items(): # queue items are tuples of the form: # tuple( # 0. timestamp of next log, # 1. method name (to get value from device), # 2. signal name (what to call this in the database), # 3. interval (seconds), # 4. count (how many times this has been logged since starting) # ) self.queue.put( ( datetime.datetime.now() + datetime.timedelta(seconds=logging_properties["interval"]), method_name, logging_properties["signal_name"], logging_properties["interval"], 1, ) ) self.is_logging = True self._logging_thread = threading.Thread(target=self._worker) self._start_time = datetime.datetime.now() self._logging_thread.start()
[docs] def stop(self): """Stop the logging worker thread. This will stop logging all.""" self.is_logging = False self._logging_thread.join()
[docs] def retrieve_signal(self, signal_name, within: datetime.timedelta | None = None): """Retrieve a signal from the database. Args: signal_name (str): device signal name. This should match the signal_name passed to the `@log_device_signal` decorator within (Optional[datetime.timedelta]): timedelta defining how far back to pull logs from (relative to current time). Defaults to None. Returns ------- Dict: Dictionary of signal result. Single value vs lists depends on whether ``within`` was None or not, respectively. Form is: .. code-block:: { "device_name": "device_name", "signal_name": "signal_name", "value": "signal_value" or ["signal_value_1", "signal_value_2", ...]], "timestamp": "timestamp" or ["timestamp_1", "timestamp_2", ...] } """ if within is None: return self.dblogger.get_latest_device_signal( device_name=self.device.name, signal_name=signal_name ) else: return self.dblogger.filter_device_signal( device_name=self.device.name, signal_name=signal_name, within=within )
# Module-level registry mapping each device name to its registered device instance.
# Populated by ``add_device`` and read (as a copy) by ``get_all_devices``.
_device_registry: dict[str, BaseDevice] = {}
def add_device(device: BaseDevice):
    """
    Register a device instance in the module-level device registry.

    Args:
        device: the device to register; its ``name`` is used as the registry key.

    Raises
    ------
    KeyError: if a device with the same name has already been registered.
    """
    already_registered = device.name in _device_registry
    if already_registered:
        raise KeyError(f"Duplicated device name {device.name}")

    _device_registry[device.name] = device
def get_all_devices() -> dict[str, BaseDevice]:
    """
    Get all the devices in the device registry.

    The result is a shallow copy of the registry, so adding or removing keys on the
    returned dictionary does not affect the registry itself.

    Returns
    -------
    A dictionary of all the devices in the registry. The keys are the device names,
    and the values are the device instances.
    """
    snapshot = dict(_device_registry)
    return snapshot