import asyncio
import atexit
import inspect
import logging
from datetime import datetime, timedelta
from functools import wraps
from time import monotonic_ns
from typing import Dict, Set, Type, Any, Callable, TypeVar, Tuple, Protocol, Union
import ray
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.util.metrics import Gauge
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from stick_ray.eventbus import EventBus
from stick_ray.namespace import NAMESPACE
from stick_ray.stateful_worker import HEARTBEAT_INTERVAL, StatefulWorker
from stick_ray.utils import SerialisableBaseModel, get_or_create_event_loop, deterministic_uuid, current_utc
__all__ = [
'RoutedServiceHandle',
'routed_service',
]
logger = logging.getLogger(__name__)
class AddressNotFound(Exception):
def __init__(self, address: Union[str, None] = None):
super().__init__(f"Address {address} not found.")
class WorkerEntry(SerialisableBaseModel):
worker_actor: ActorHandle
last_add_dt: datetime
backpressure: bool
class AddresBookEntry(SerialisableBaseModel):
worker_id: str
created_dt: datetime
last_query_dt: datetime
class Router:
"""
A class representing a routed service. This is never invoked directly.
"""
def __init__(self, worker_cls: Type[StatefulWorker],
worker_actor_options: Dict[str, Any],
worker_kwargs: Dict[str, Any],
max_concurrent_sessions: int,
expiry_period: timedelta,
min_num_workers: int):
"""
Initialises a router.
Args:
worker_cls: the worker class
worker_actor_options: the worker actor options (see ray.remote)
worker_kwargs: the worker kwargs to pass in a worker creation
max_concurrent_sessions: maximum number of concurrent sessions per worker, before spinning up a new worker
expiry_period: how long a session last since last interaction before being closed down
min_num_workers: the minimum number of persistent workers. The number is maintained, but not the specific
workers. I.e. the first worker may not be the one that lasts forever, if min_num_workers=1.
"""
self.lock = asyncio.Lock() # for any manipulation of address book, or workers this must be gotten.
self.workers: Dict[str, WorkerEntry] = dict()
self.address_book: Dict[str, AddresBookEntry] = dict()
self.max_concurrent_sessions = max_concurrent_sessions
self.expiry_period = expiry_period
self.min_num_workers = min_num_workers
self.worker_cls = worker_cls
self.worker_actor_options = worker_actor_options
self.worker_kwargs = worker_kwargs
self.active_sessions = Gauge(
'active_sessions',
description="Measures how many sessions are live currently across all workers.",
tag_keys=('routed_service_name',)
)
self.active_sessions.set_default_tags({"routed_service_name": worker_cls.__name__})
self._event_bus = EventBus(name='routed_services')
loop = get_or_create_event_loop()
task = loop.create_task(self._run_control_loop())
def clean_up(task, loop):
task.cancel()
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
atexit.register(clean_up, task, loop)
async def health_check(self):
"""
Announce health check.
"""
logger.info(f"{self.__class__.__name__} is up!")
return
def _count_worker_sessions(self, worker_id):
return sum(map(lambda w: 1, filter(lambda item: item.worker_id == worker_id, self.address_book.values())))
async def _run_control_loop(self):
logger.info(f"Starting router control loop for {self.worker_cls.__name__}!")
while True:
# Run all individual loops
try:
await asyncio.gather(
self._check_worker_health(),
self._reconcile_address_book(),
self._prune_expired_sessions(),
self._shutdown_empty_workers(),
self._log_info(),
self._update_metrics(),
return_exceptions=False
)
except Exception as e:
logger.exception(str(e))
logger.info("Restarting control loop")
async def _log_info(self):
def info() -> str:
"""
Get information string of the current routing table.
"""
info = f"# Workers: {len(self.workers):04d}\t# Sessions: {len(self.address_book):04d}"
for worker_id in self.workers:
count = self._count_worker_sessions(worker_id=worker_id)
if self.workers[worker_id].backpressure:
info += f"\n* Worker {worker_id}: {count} sessions"
else:
info += f"\nWorker {worker_id}: {count} sessions"
return info
log_interval = timedelta(minutes=1)
t0 = current_utc()
while True:
async with self.lock:
if (current_utc() > t0 + log_interval) and (len(self.address_book) > 0):
logger.info(info())
t0 = current_utc()
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
async def _check_worker_health(self):
while True:
if len(self.workers) == 0:
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
continue
for worker_id in list(self.workers.keys()):
await asyncio.sleep(0)
non_responsive = True
t0 = current_utc()
while current_utc() < t0 + timedelta(minutes=1): # TODO: could make + 5 sigma of latency stats
try:
backpressure: bool = await self._event_bus.peek(
key=f"{worker_id}_backpressure",
timeout=HEARTBEAT_INTERVAL.total_seconds()
)
non_responsive = False
if backpressure: # do something with backpressure
if not isinstance(backpressure, bool):
raise TypeError(f"Expected bool backpressure, got {type(backpressure)}")
async with self.lock:
if worker_id not in self.workers: # Check in case it disappeared before getting lock, weird
continue
self.workers[worker_id].backpressure = backpressure
break
except asyncio.TimeoutError:
continue
except Exception as e:
logger.exception(f"Problem with event bus: {str(e)}")
continue
if non_responsive: # unhealthy
async with self.lock:
logger.warning(f"Worker {worker_id} of {self.worker_cls.__name__} is not healthy!")
count = 0
for session_id in list(self.address_book):
if self.address_book[session_id].worker_id == worker_id:
del self.address_book[session_id]
logger.info(f"Removed {session_id} from {worker_id}.")
count += 1
logger.info(f"Removed {count} sessions for {worker_id}.")
if worker_id in self.workers: # Check in case it disappeared before getting lock, weird
del self.workers[worker_id]
logger.info(f"Removed worker {worker_id}")
continue
async def _reconcile_address_book(self):
while True:
if len(self.workers) == 0:
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
continue
for worker_id in list(self.workers.keys()):
await asyncio.sleep(0)
try:
update: Set[str] = await self._event_bus.peek(
key=f"{worker_id}_update",
timeout=HEARTBEAT_INTERVAL.total_seconds() * 2
)
except asyncio.TimeoutError:
continue
except Exception as e:
logger.exception(f"Problem getting update from worker {worker_id}:{str(e)}")
continue
# Update address book
async with self.lock:
if worker_id not in self.workers: # pruned before getting lock
continue
for session_id in update: # Ensure consistency
if session_id not in self.address_book: # not found locally, add new entry
self.address_book[session_id] = AddresBookEntry(
worker_id=worker_id,
created_dt=current_utc(),
last_query_dt=current_utc() - timedelta(days=10000)
)
continue
if self.address_book[session_id].worker_id != worker_id: # found, but inconsistent, update
logger.info(
f"Moving session {session_id} from {self.address_book[session_id].worker_id} to {worker_id}"
)
try: # Close other other one if possible
await self.workers[
self.address_book[session_id].worker_id].worker_actor.close_session.remote(
session_id=session_id
)
except Exception as e:
logger.exception(e)
finally:
self.address_book[session_id] = AddresBookEntry(
worker_id=worker_id,
created_dt=current_utc(),
last_query_dt=current_utc() - timedelta(days=10000)
)
continue
async def _prune_expired_sessions(self):
# Should not be able to prune while ferrying, so we need to acquire the lock.
while True:
if len(self.address_book) == 0:
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
continue
for session_id in list(self.address_book.keys()):
await asyncio.sleep(0)
async with self.lock:
if session_id not in self.address_book: # pruned before getting lock
continue
if current_utc() > self.address_book[session_id].last_query_dt + self.expiry_period: # expired
logger.info(f"Session {session_id} expired. Closing down.")
try:
worker_entry = self.workers[self.address_book[session_id].worker_id]
await worker_entry.worker_actor.close_session.remote(session_id=session_id)
except Exception as e:
logger.exception(str(e))
finally:
del self.address_book[session_id]
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
async def _shutdown_empty_workers(self):
while True:
if len(self.workers) == 0:
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
continue
for worker_id in list(self.workers.keys()):
await asyncio.sleep(0)
async with self.lock:
if worker_id not in self.workers: # shutdown before getting lock
continue
if len(self.workers) > self.min_num_workers:
count = self._count_worker_sessions(worker_id=worker_id)
if count == 0:
logger.info(f"Worker {worker_id} empty. Closing down.")
try:
await self.workers[worker_id].worker_actor.shutdown.remote()
except Exception as e:
logger.exception(str(e))
finally:
# Note, any existing references will prevent shutdown of actor until released.
del self.workers[worker_id]
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
async def _update_metrics(self):
while True:
# Adjust max_concurrent
self.active_sessions.set(len(self.address_book))
await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds())
async def _create_new_worker(self) -> str:
# Make sure calling method has lock
worker_id = str(deterministic_uuid(str(monotonic_ns())))
worker_name = f"{self.worker_cls.__name__}.worker.{worker_id[:6]}"
worker_actor_options = self.worker_actor_options.copy()
worker_actor_options['name'] = worker_name
worker_actor_options.pop('lifetime', None) # make sure we don't accidentally make a detached one
new_worker = ray.remote(self.worker_cls).options(**worker_actor_options).remote(
worker_id=worker_id,
**self.worker_kwargs
)
await new_worker.start.remote()
logger.info(f"Started {worker_name}")
self.workers[worker_id] = WorkerEntry(
worker_actor=new_worker,
last_add_dt=current_utc() - timedelta(days=10000),
backpressure=False
)
return worker_id
async def _get_assigned_worker(self, session_id: str) -> WorkerEntry:
async with self.lock:
# Get assigned worker id, creating worker if necessary.
if session_id in self.address_book:
assigned_worker_id = self.address_book[session_id].worker_id
self.address_book[session_id].last_query_dt = current_utc()
else:
# get LRU worker, or new worker if all full
assigned_worker_id: Union[str, None] = None
sorted_workers = reversed(
sorted(self.workers, key=lambda worker_id: self.workers[worker_id].last_add_dt)
)
for worker_id in sorted_workers:
count = self._count_worker_sessions(worker_id=worker_id)
if (count < self.max_concurrent_sessions) and not self.workers[worker_id].backpressure:
assigned_worker_id = worker_id
break
if assigned_worker_id is None: # make worker
assigned_worker_id = await self._create_new_worker()
self.workers[assigned_worker_id].last_add_dt = current_utc()
await self.workers[assigned_worker_id].worker_actor.create_session.remote(session_id=session_id)
self.address_book[session_id] = AddresBookEntry(worker_id=assigned_worker_id,
created_dt=current_utc(),
last_query_dt=current_utc())
return self.workers[assigned_worker_id]
async def ferry(self, method: str, session_id: str, data_ref_tuple: Tuple[ObjectRef]) -> Tuple[ObjectRef]:
"""
Forwards a request to the correct worker with the current session. If no session can be found, creates a new one.
Args:
method: the name of the method to be ferried. Only methods not starting with '_' are permitted.
session_id: the id of the session. No constraints on the shape of this string. Typically, a UUID.
data_ref_tuple: a tuple of an object ref containing the args to the method. These are resolved on the other
side, so that the ferry itself never touches the object.
Returns:
a tuple of object ref representing the results.
"""
worker_entry = await self._get_assigned_worker(session_id=session_id)
obj_ref = worker_entry.worker_actor.ferry.remote(
method=method,
data_ref_tuple=data_ref_tuple,
session_id=session_id
)
return (obj_ref,)
[docs]
class RoutedServiceHandle:
"""
A class representing a handle that can be used to transparently dish out requests to a routed service.
"""
def __init__(self, router: ActorHandle, sync: bool):
"""
Initialised routed service handle.
Args:
router: an actor handle for the router
sync: whether to produce a sync handle, or async handle.
"""
self._router = router
self._sync = sync
[docs]
def __getattr__(self, item):
if item.startswith('_'):
raise AttributeError(f"Only public methods are reachable. {item} invalid public method name.")
if self._sync:
def ferry(*args, **kwargs) -> ObjectRef:
session_id = kwargs.pop('session_id', None)
data = dict(
args=args,
kwargs=kwargs
)
data_ref_tuple = (ray.put(data),)
if session_id is None:
raise ValueError(f"Missing session_id")
(obj_ref,) = ray.get(self._router.ferry.remote(method=item,
session_id=session_id,
data_ref_tuple=data_ref_tuple))
return obj_ref
else:
async def ferry(*args, **kwargs) -> ObjectRef:
session_id = kwargs.pop('session_id', None)
data = dict(
args=args,
kwargs=kwargs
)
data_ref_tuple = (ray.put(data),)
if session_id is None:
raise ValueError(f"Missing session_id")
(obj_ref,) = await self._router.ferry.remote(method=item,
session_id=session_id,
data_ref_tuple=data_ref_tuple)
return obj_ref
return ferry
class NoRoutedServiceFound(Exception):
pass
class BaseRoutedService:
def __init__(self, router: ActorHandle):
self._router = router
def __hash__(self):
# required for using in sets, etc.
return hash(repr(self._router))
def __eq__(self, other: 'BaseRoutedService'):
# required for comparisons
if not isinstance(other, BaseRoutedService):
raise ValueError('Can only compare with `BaseRoutedService`')
return repr(self._router) == repr(other._router)
@staticmethod
def dynamic_cls(name: str) -> Type:
# a dynamic class that will be parsed properly by ray dashboard, so that it has a nice class name.
return type(
f"Router:{name}",
(Router,),
dict(Router.__dict__),
)
@staticmethod
def router_name(name: str) -> str:
router_name = f"{name}.router"
return router_name
def get_handle(self, sync: bool = True) -> RoutedServiceHandle:
"""
Gets a handle for the routed service.
Args:
sync: whether it will be a sync or async handle.
Returns:
a RoutedServiceHandle
Examples:
# sync handles are used like
handle = service.get_handle()
ray.get(handle.some_func(..., session_id='abc'))
# an async handle are used like
handle = service.get_handle(sync=False)
await (await handle.some_func(..., session_id='abc'))
# the first await creates a task (so the work starts being done).
# the second await gets the result
"""
return RoutedServiceHandle(router=self._router, sync=sync)
class FetchedRoutedService(BaseRoutedService):
def __init__(self, name: str):
router_name = self.router_name(name)
try:
router = ray.get_actor(router_name, namespace=NAMESPACE)
logger.info(f"Connected to existing {router_name}")
except ValueError:
raise NoRoutedServiceFound(f"Tried to fetch {router_name}, but it was not found.")
super().__init__(router=router)
class RoutedService(BaseRoutedService):
def __init__(self, name: str,
worker_cls: Type[StatefulWorker],
worker_actor_options: Union[Dict[str, Any], None],
worker_kwargs: Union[Dict[str, Any], None],
max_concurrent_sessions: int,
expiry_period: timedelta,
min_num_workers: int):
if max_concurrent_sessions < 1:
raise ValueError(f"max_concurrent_sessions must be >= 1, got {max_concurrent_sessions}")
if expiry_period.total_seconds() < 0:
raise ValueError(f"expiry_period must be strictly positive, got {expiry_period}")
worker_actor_options = worker_actor_options or dict()
worker_kwargs = worker_kwargs or dict()
head_node_id = ray.get_runtime_context().get_node_id()
router_name = self.router_name(name)
try:
router = ray.get_actor(router_name, namespace=NAMESPACE)
logger.info(f"Connected to existing {router_name}")
except ValueError:
router_actor_options = {
"num_cpus": 0,
"name": router_name,
"lifetime": "detached",
"max_restarts": -1,
"max_task_retries": -1,
# Schedule the controller on the head node with a soft constraint. This
# prefers it to run on the head node in most cases, but allows it to be
# restarted on other nodes in an HA cluster.
"scheduling_strategy": NodeAffinitySchedulingStrategy(head_node_id, soft=True),
"namespace": NAMESPACE,
"max_concurrency": 15000 # Needs to be large, as there should be no limit.
}
router_kwargs = dict(
worker_cls=worker_cls,
worker_actor_options=worker_actor_options,
worker_kwargs=worker_kwargs,
max_concurrent_sessions=max_concurrent_sessions,
expiry_period=expiry_period,
min_num_workers=min_num_workers
)
dynamic_cls = self.dynamic_cls(name)
router = ray.remote(dynamic_cls).options(**router_actor_options).remote(**router_kwargs)
ray.get(router.health_check.remote())
logger.info(f"Created new {router_name}")
super().__init__(router=router)
V = TypeVar('V')
class FProtocol(Protocol):
__call__: Callable[..., RoutedService] # The main function signature, change None to the actual return type of f
fetch: Callable[[], FetchedRoutedService] # The g method signature, change None to the actual return type of g
[docs]
def routed_service(
expiry_period: timedelta,
name: Union[str, None] = None,
worker_actor_options: Union[Dict[str, Any], None] = None,
max_concurrent_sessions: int = 10,
min_num_workers: int = 0
) -> Callable[[Type[V]], FProtocol]:
def decorator(worker_cls: Type[V]) -> FProtocol:
if not issubclass(worker_cls, StatefulWorker):
raise ValueError(f"Only StatefulWorker subclasses can be made into routed services. Got {worker_cls}.")
if name is None:
_name = worker_cls.__name__
else:
_name = name
@wraps(worker_cls)
def wrapped(**kwargs):
return RoutedService(name=_name,
worker_cls=worker_cls,
worker_actor_options=worker_actor_options,
worker_kwargs=kwargs,
max_concurrent_sessions=max_concurrent_sessions,
expiry_period=expiry_period,
min_num_workers=min_num_workers
)
def fetch() -> FetchedRoutedService:
return FetchedRoutedService(name=_name)
wrapped.fetch = fetch
wrapped.__signature__ = inspect.signature(worker_cls)
return wrapped
return decorator