Source code for alab_management.user_input
import time
from datetime import datetime
from enum import Enum
from typing import Any, cast
from bson import ObjectId
from alab_management.alarm import Alarm
from alab_management.experiment_view.experiment_view import ExperimentView
from alab_management.task_view import TaskView
from alab_management.utils.data_objects import get_collection
from .config import AlabOSConfig
[docs]
class UserRequestStatus(Enum):
    """Lifecycle status of a user input request (stored by value in the database)."""

    FULLFILLED = "fulfilled"
    # Correctly spelled alias: shares the value "fulfilled", so Enum makes it the
    # very same member as FULLFILLED. The misspelled name is kept for backward
    # compatibility with existing callers.
    FULFILLED = "fulfilled"
    PENDING = "pending"
    ERROR = "error"
[docs]
class UserInputView:
    """
    View over the ``user_input`` collection, which stores requests for user input.

    A request is inserted as ``pending`` by :meth:`insert_request`, answered via
    :meth:`update_request_status` (which stores ``response`` and ``note`` and marks
    it ``fulfilled``), and can be waited on with :meth:`retrieve_user_input` or
    :meth:`retrieve_user_input_with_note`.
    """

    # Seconds to wait between database polls while blocking for a user response.
    _POLL_INTERVAL_S = 0.5

    def __init__(self):
        self._input_collection = get_collection("user_input")
        self._task_view = TaskView()
        self._experiment_view = ExperimentView()
        # The alarm is configured from the optional "alarm" section of the config.
        alarm_config = AlabOSConfig().get("alarm", {})
        self._alarm = Alarm(**alarm_config)

    def insert_request(
        self,
        prompt: str,
        options: list[str],
        task_id: ObjectId | None = None,
        maintenance: bool = False,
        category: str = "Unknown Category",
    ) -> ObjectId:
        """
        Insert a pending user input request into the database and raise an alert.

        Args:
            prompt: prompt shown to the user (also used in the alert message).
            options: allowed responses; each is stored as ``str(option)``.
            task_id: id of the task requesting input. Required unless
                ``maintenance`` is truthy.
            maintenance: if truthy, mark the request as a system-maintenance
                request; the alert is then sent under the "Maintenance" category.
            category: alert category used for non-maintenance requests.

        Returns
        -------
            The ObjectId of the newly inserted request.

        Raises
        ------
            ValueError: if ``task_id`` is None for a non-maintenance request.
        """
        if task_id is None and not maintenance:
            raise ValueError("task_id is required for non-maintenance requests!")

        context: dict[str, Any] = {"maintenance": maintenance}
        if task_id is not None:
            # get_task raises if the task id does not exist.
            self._task_view.get_task(task_id=task_id)
            experiment_id = self._experiment_view.get_experiment_by_task_id(task_id)[
                "_id"
            ]
            context.update(
                {
                    "experiment_id": experiment_id,
                    "task_id": task_id,
                }
            )

        request_id = ObjectId()
        self._input_collection.insert_one(
            {
                "_id": request_id,
                "prompt": prompt,
                "options": [str(opt) for opt in options],
                "status": UserRequestStatus.PENDING.value,
                "request_context": context,
                "last_updated": datetime.now(),
            }
        )
        if maintenance:
            category = "Maintenance"
        self._alarm.alert(f"User input requested: {prompt}", category)
        return request_id

    def get_request(self, request_id: ObjectId) -> dict[str, Any]:
        """
        Get a request document by id.

        Raises
        ------
            ValueError: if no request with ``request_id`` exists.
        """
        request = self._input_collection.find_one({"_id": request_id})
        if request is None:
            raise ValueError(f"User input request id {request_id} does not exist!")
        return cast(dict[str, Any], request)

    def update_request_status(self, request_id: ObjectId, response: str, note: str):
        """
        Record the user's response and note, and mark the request as fulfilled.

        Raises
        ------
            ValueError: if no request with ``request_id`` exists.
        """
        self.get_request(request_id)  # will error if request does not exist
        self._input_collection.update_one(
            {"_id": request_id},
            {
                "$set": {
                    "response": response,
                    "note": note,
                    "status": UserRequestStatus.FULLFILLED.value,
                    "last_updated": datetime.now(),
                }
            },
        )

    def _wait_for_response(self, request_id: ObjectId) -> dict[str, Any]:
        """
        Poll the database until the request leaves the pending state; return its document.

        If anything interrupts the wait (any BaseException, including interrupts),
        the request is marked with the ERROR status before the exception is re-raised.

        Raises
        ------
            ValueError: if the request does not exist or is in the error state.
        """
        try:
            while True:
                request = self._input_collection.find_one({"_id": request_id})
                if request is None:
                    raise ValueError(
                        f"User input request id {request_id} does not exist!"
                    )
                status = UserRequestStatus(request["status"])
                if status != UserRequestStatus.PENDING:
                    # Return immediately instead of sleeping one more interval.
                    break
                time.sleep(self._POLL_INTERVAL_S)
        except BaseException:
            # Deliberately broad: interrupts/aborts must not leave the request pending.
            # Store the enum *value* so the status can be parsed back by
            # UserRequestStatus(...) like every other status in this collection.
            self._input_collection.update_one(
                {"_id": request_id},
                {"$set": {"status": UserRequestStatus.ERROR.value}},
            )
            raise
        if status == UserRequestStatus.ERROR:
            raise ValueError(
                f"User input request id {request_id} is in the error state!"
            )
        return cast(dict[str, Any], request)

    def retrieve_user_input(self, request_id: ObjectId) -> str:
        """
        Retrieve the user's response for a request, blocking until it is answered.

        Returns the user response, which is one of the request's options.
        """
        return self._wait_for_response(request_id)["response"]

    def clean_up_user_input_collection(self):
        """Drop the user input collection."""
        self._input_collection.drop()

    def get_all_pending_requests(self) -> list[dict[str, Any]]:
        """
        Get all pending requests.

        Returns a list of pending request documents (fully materialized, not a
        lazy database cursor).
        """
        return list(
            self._input_collection.find({"status": UserRequestStatus.PENDING.value})
        )

    def retrieve_user_input_with_note(self, request_id: ObjectId) -> tuple[str, str]:
        """
        Retrieve the user's response and note for a request, blocking until answered.

        Returns a ``(response, note)`` tuple; the response is one of the
        request's options.
        """
        request = self._wait_for_response(request_id)
        return request["response"], request["note"]
[docs]
def request_user_input(
    task_id: ObjectId | None,
    prompt: str,
    options: list[str],
    maintenance: bool = False,
    category: str = "Unknown Category",
) -> str:
    """
    Ask the user a question through the dashboard and wait for the answer.

    Blocks until a response is given.

    Args:
        task_id (ObjectId | None): id of the task asking for input; may be None
            only for maintenance requests
        prompt (str): prompt to show the user
        options (List[str]): response options offered to the user
        maintenance (bool): if true, mark this as a request for overall system maintenance
        category (str): alert category for the request (maintenance requests
            are alerted under "Maintenance")

    Returns
    -------
    response (str): the user's response as a string
    """
    view = UserInputView()
    new_request_id = view.insert_request(
        prompt=prompt,
        options=options,
        task_id=task_id,
        maintenance=maintenance,
        category=category,
    )
    answer = view.retrieve_user_input(request_id=new_request_id)
    return answer
[docs]
def request_maintenance_input(prompt: str, options: list[str]) -> str:
    """
    Request user input for system maintenance through the dashboard.

    Blocks until a response is given. The request is not tied to any task and
    is alerted under the "Maintenance" category.

    Args:
        prompt (str): prompt to give user
        options (List[str]): response options to give user

    Returns
    -------
    response (str): user response as string
    """
    return request_user_input(
        task_id=None,
        prompt=prompt,
        options=options,
        maintenance=True,
        category="Maintenance",
    )
[docs]
def request_user_input_with_note(
    task_id: ObjectId | None,
    prompt: str,
    options: list[str],
    maintenance: bool = False,
    category: str = "Unknown Category",
) -> tuple[str, str]:
    """
    Ask the user a question through the dashboard and wait for the answer plus a note.

    Blocks until a response is given.

    Args:
        task_id (ObjectId | None): id of the task asking for input; may be None
            only for maintenance requests
        prompt (str): prompt to show the user
        options (List[str]): response options offered to the user
        maintenance (bool): if true, mark this as a request for overall system maintenance
        category (str): alert category for the request (maintenance requests
            are alerted under "Maintenance")

    Returns
    -------
    response (str): the user's response as a string
    note (str): free-text note from the user
    """
    view = UserInputView()
    new_request_id = view.insert_request(
        prompt=prompt,
        options=options,
        task_id=task_id,
        maintenance=maintenance,
        category=category,
    )
    response_and_note = view.retrieve_user_input_with_note(request_id=new_request_id)
    return response_and_note