Source code for stick_ray.stateful_worker

import asyncio
import atexit
import logging
from datetime import timedelta
from typing import Any, Dict, Callable, Coroutine, Tuple, Set

from fair_async_rlock import FairAsyncRLock
from ray import ObjectRef
from ray.util.metrics import Histogram

from stick_ray.eventbus import EventBus
from stick_ray.utils import SerialisableBaseModel, get_or_create_event_loop, current_utc, is_key_after_star

# Names exported by `from stick_ray.stateful_worker import *`.
__all__ = [
    'StatefulWorker',
    'StatefulSessionNotFound'
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Pause used by the worker's background loops between successive heartbeat
# writes and between successive checks of the managed session-id set.
HEARTBEAT_INTERVAL: timedelta = timedelta(seconds=3)


class StatefulSessionNotFound(Exception):
    """
    Raised when a session id is not among the sessions a worker manages.

    The worker raises it with the offending session id as the only argument.
    """
class _dummy:
    """Bare class; `dir(_dummy())` gives the attribute names every plain object has."""


class SessionItem(SerialisableBaseModel):
    # Per-session record kept by the worker, keyed by session id.
    # Lock serialising the calls ferried into this session.
    lock: FairAsyncRLock
    # Arbitrary user state attached to the session (None until set).
    state: Any
class StatefulWorker:
    """
    Represents a stateful worker. Handles sessions for sticky connections, and sends heart beats to router.

    Subclasses implement the `_create_session`, `_close_session`, `_start` and `_shutdown` hooks, plus
    any number of public coroutine methods that take `session_id` as a keyword-only argument; those
    methods are invoked through `ferry`.
    """

    def __init__(self, worker_id: str, **kwargs):
        """
        Set up worker state and schedule the background control loop.

        Args:
            worker_id: identifier used to build the event bus keys this worker writes to.
            **kwargs: accepted for compatibility; not used here.
        """
        # Worker-wide lock guarding session membership and the method cache.
        self._lock = asyncio.Lock()
        self._session_items: Dict[str, SessionItem] = dict()
        self._stop_event = asyncio.Event()
        self._event_bus = EventBus(name='routed_services')
        self._worker_id = worker_id
        self._method_func_cache: Dict[str, Callable[..., Coroutine]] = dict()
        self.latency_ms = Histogram(
            'latency_ms',
            description="Measures how long calls take in ms.",
            boundaries=[2 ** i for i in range(15)],  # 1, 2, 4, ... up to 16384 ms
            tag_keys=('routed_service_name', 'method')
        )
        self.latency_ms.set_default_tags({"routed_service_name": self.__class__.__name__})

        loop = get_or_create_event_loop()
        task = loop.create_task(self._run_control_loop())

        def clean_up(task, loop):
            # Best-effort cancellation of the control loop at interpreter exit.
            if loop.is_closed():
                # Cancelling or running anything on a closed loop raises RuntimeError.
                return
            if loop.is_running():
                # run_until_complete() refuses a running loop; ask the loop to cancel instead.
                loop.call_soon_threadsafe(task.cancel)
                return
            task.cancel()
            loop.run_until_complete(asyncio.gather(task, return_exceptions=True))

        atexit.register(clean_up, task, loop)

    async def _run_control_loop(self):
        """
        Run the heartbeat and session-update loops until the stop event is set.

        If either loop raises, the error is logged, the surviving loop is cancelled, and after one
        heartbeat interval both loops are started again.
        """
        logger.info(f"Starting worker control loop for {self.__class__.__name__}!")
        while not self._stop_event.is_set():
            # Run all individual loops
            loops = [
                asyncio.ensure_future(self._send_heartbeat()),
                asyncio.ensure_future(self._update()),
            ]
            failed = False
            try:
                await asyncio.gather(*loops)
            except Exception as e:
                failed = True
                logger.exception(str(e))
            finally:
                # gather() does not cancel the sibling when one loop fails; without this a
                # restart would leave a duplicate of the surviving loop running.
                for pending in loops:
                    pending.cancel()
                await asyncio.gather(*loops, return_exceptions=True)
            if failed and not self._stop_event.is_set():
                logger.info("Restarting control loop")
                # Back off so a persistent failure does not turn into a busy loop.
                await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())

    async def _send_heartbeat(self):
        """Periodically write the result of `_too_busy()` to the `<worker_id>_backpressure` key."""
        while not self._stop_event.is_set():
            # Heartbeat/Backpressure signal
            await self._event_bus.write(key=f"{self._worker_id}_backpressure", item=await self._too_busy())
            await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())

    async def _update(self):
        """Periodically write the set of managed session ids to `<worker_id>_update` when it changes."""
        last_update = set()
        while not self._stop_event.is_set():
            # Update sessions managed
            update = set(self._session_items)
            if update != last_update:
                await self._event_bus.write(key=f"{self._worker_id}_update", item=update)
                last_update = update
            await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())

    async def _too_busy(self) -> bool:
        """Backpressure signal sent with each heartbeat; this base implementation always returns False."""
        return False

    async def _stop(self):
        """Set the stop event so the background loops exit at their next check."""
        self._stop_event.set()

    def _resolve_ferry_method(self, method: str) -> Callable[..., Coroutine]:
        """
        Return the bound method to call for `method`, validating and caching it on first use.

        Raises:
            AttributeError if the name is private or does not exist on this worker.
            SyntaxError if the method does not declare `session_id` as a keyword-only argument.
        """
        cached = self._method_func_cache.get(method)
        if cached is not None:
            return cached

        func = None if method.startswith('_') else getattr(self, method, None)
        if func is None:
            # Only build the listing when it is needed for the error message.
            # We allow close_session from outside
            ignore_methods = {'ferry', 'get_session_state', 'set_session_state', 'create_session', 'start',
                              'shutdown', 'check_session', 'get_session_ids', 'health_check'}
            available = sorted(
                name for name in set(dir(self)) - set(dir(_dummy())) - ignore_methods
                if not name.startswith('_')
            )
            raise AttributeError(
                f"Invalid method {method}. Available methods are {available}."
            )
        # Double ensure function spec is good.
        if not is_key_after_star(func, 'session_id'):
            raise SyntaxError(
                f"Method definition must have session_id as keyword-only arg, e.g. `def {method}(..., *, session_id)`")
        self._method_func_cache[method] = func
        return func

    async def ferry(self, method: str, data_ref_tuple: "Tuple[ObjectRef]", session_id: str) -> Any:
        """
        Ferries a method to this worker, and returns as a task, i.e. an awaitable result.

        Args:
            method: method name to ferry
            data_ref_tuple: a 1-tuple holding an object ref that resolves to a dict with keys
                'args' and 'kwargs'
            session_id: session id to ferry to

        Returns:
            the result of the operation, i.e. the task is awaited.

        Raises:
            StatefulSessionNotFound if session_id not currently managed, i.e. expired session or it never existed.
            AttributeError if the method name is private or unknown.
            SyntaxError if the method does not take session_id as a keyword-only argument.
            ValueError if session_id is found in the kwargs, as this results in an overwrite.
        """
        start_dt = current_utc()
        async with self._lock:
            func = self._resolve_ferry_method(method)

            # Ensure session id is valid
            if session_id not in self._session_items:  # may have been pruned before getting lock
                raise StatefulSessionNotFound(session_id)

            session_lock = self._session_items[session_id].lock
            acquired_event = asyncio.Event()

            async def eval_task():
                try:
                    async with session_lock:
                        # ferry() holds the worker lock until this point is reached.
                        acquired_event.set()
                        # Get the inputs locally.
                        (data_ref,) = data_ref_tuple
                        data = await data_ref  # retrieve the input locally
                        args = data['args']
                        kwargs = data['kwargs']
                        # Set session id in kwargs
                        if 'session_id' in kwargs:
                            raise ValueError(
                                f'You have a key session_id in your kwargs, which is reserved for session id.')
                        kwargs['session_id'] = session_id
                        result = await func(*args, **kwargs)
                        finish_dt = current_utc()
                        dt = finish_dt - start_dt
                        logger.info(f"Handled {method} in session {session_id} in {dt.total_seconds() * 1e3:0.1f} ms.")
                        self.latency_ms.observe(dt.total_seconds() * 1e3, tags=dict(method=method))
                        return result
                finally:
                    # If the task ends before acquiring the session lock, still release the
                    # waiter below; otherwise the worker lock would be held forever.
                    acquired_event.set()

            # create the task for work to be done inside lock so that session id membership is locked until started.
            task = asyncio.create_task(eval_task())
            try:
                # Wait until acquired before releasing worker lock
                await acquired_event.wait()
            except BaseException:
                # Interrupted (e.g. cancelled) before hand-off: do not leave the call running orphaned.
                task.cancel()
                raise
        return await task

    async def get_session_state(self, session_id: str) -> Any:
        """
        Get the session state for a session id.

        Args:
            session_id: session id to get state for

        Returns:
            the session state

        Raises:
            StatefulSessionNotFound if the session id is not managed by this worker.
        """
        item = self._session_items.get(session_id)
        if item is None:
            raise StatefulSessionNotFound(session_id)
        return item.state

    async def set_session_state(self, session_id: str, session_state: Any):
        """
        Set the session state for the session id.

        Args:
            session_id: session id
            session_state: a state object.

        Raises:
            StatefulSessionNotFound if the session id is not managed by this worker.
        """
        item = self._session_items.get(session_id)
        if item is None:
            raise StatefulSessionNotFound(session_id)
        item.state = session_state

    async def _close_session(self, session_id: str):
        """Subclass hook called by `close_session`, with `session_id` passed by keyword."""
        raise NotImplementedError()

    async def _create_session(self, session_id: str):
        """Subclass hook called by `create_session`, with `session_id` passed by keyword."""
        raise NotImplementedError()

    async def _start(self):
        """Subclass hook called by `start`."""
        raise NotImplementedError()

    async def _shutdown(self):
        """Subclass hook called by `shutdown`."""
        raise NotImplementedError()

    async def close_session(self, *, session_id: str):
        """
        Closes the session for a given session id.

        The session is removed from the managed set even if `_close_session` raises.

        Args:
            session_id: session id to close down.
        """
        async with self._lock:
            try:
                await self._close_session(session_id=session_id)
            finally:
                # it could be missing
                self._session_items.pop(session_id, None)

    async def create_session(self, session_id: str) -> bool:
        """
        Creates a session for the given session id.

        Args:
            session_id: session id, e.g. a user UUID.

        Returns:
            True if the session was successfully created, False if `_create_session` failed.

        Raises:
            RuntimeError if the session id is already managed by this worker.
        """
        async with self._lock:
            if session_id in self._session_items:
                raise RuntimeError(f"Session {session_id} already exists.")
            self._session_items[session_id] = SessionItem(lock=FairAsyncRLock(), state=None)
            # TODO: Add backpressure here with rejection if too busy
            created = False
            try:
                await self._create_session(session_id=session_id)
                if session_id not in self._session_items:
                    raise RuntimeError(f"Session {session_id} disappeared.")
                created = True
                return True
            except Exception as e:
                logger.error(f"Failed to create session {session_id}: {str(e)}")
                return False
            finally:
                # Also runs on cancellation (CancelledError is not an Exception), so a
                # half-created session is never left registered.
                if not created:
                    self._session_items.pop(session_id, None)

    async def start(self):
        """
        Starts the worker by running the `_start` hook under the worker lock.
        """
        async with self._lock:
            await self._start()
        logger.info(f"Successful start up.")

    async def shutdown(self):
        """
        Shuts down the worker.

        Runs the `_shutdown` hook under the worker lock and always signals the background loops to
        stop afterwards, even if the hook raises.
        """
        async with self._lock:
            try:
                await self._shutdown()
            finally:
                await self._stop()
        logger.info(f"Successful shut down.")

    async def check_session(self, session_id: str) -> bool:
        """
        Checks if the given session id is managed by the worker.

        Args:
            session_id: session id to check

        Returns:
            true if managed
        """
        async with self._lock:
            return session_id in self._session_items

    async def get_session_ids(self) -> Set[str]:
        """
        Gets all current managed session ids.

        Returns:
            a set of session ids.
        """
        async with self._lock:
            return set(self._session_items)

    async def health_check(self):
        """
        Simply returns, acting as a health check.
        """
        return