inferential-0.1.0-py3-none-any.whl
- inferential/__init__.py +14 -0
- inferential/_version.py +1 -0
- inferential/client.py +156 -0
- inferential/config/__init__.py +31 -0
- inferential/config/schema.py +96 -0
- inferential/config/watcher.py +84 -0
- inferential/dispatch/__init__.py +8 -0
- inferential/dispatch/dispatcher.py +166 -0
- inferential/dispatch/health.py +23 -0
- inferential/metrics/__init__.py +13 -0
- inferential/metrics/callbacks.py +20 -0
- inferential/metrics/collector.py +71 -0
- inferential/metrics/store.py +75 -0
- inferential/observation/__init__.py +10 -0
- inferential/observation/assembler.py +100 -0
- inferential/observation/slots.py +30 -0
- inferential/proto/__init__.py +124 -0
- inferential/proto/inferential.proto +57 -0
- inferential/proto/inferential_pb2.py +54 -0
- inferential/proto/inferential_pb2.pyi +116 -0
- inferential/scheduler/__init__.py +25 -0
- inferential/scheduler/base.py +90 -0
- inferential/scheduler/batch_optimized.py +126 -0
- inferential/scheduler/deadline_aware.py +107 -0
- inferential/scheduler/priority_tiered.py +93 -0
- inferential/scheduler/request.py +28 -0
- inferential/scheduler/round_robin.py +106 -0
- inferential/server.py +242 -0
- inferential/tracking/__init__.py +12 -0
- inferential/tracking/cadence.py +50 -0
- inferential/tracking/response.py +62 -0
- inferential/tracking/robots.py +41 -0
- inferential/transport/__init__.py +8 -0
- inferential/transport/messages.py +18 -0
- inferential/transport/zmq_transport.py +50 -0
- inferential-0.1.0.dist-info/METADATA +281 -0
- inferential-0.1.0.dist-info/RECORD +39 -0
- inferential-0.1.0.dist-info/WHEEL +4 -0
- inferential-0.1.0.dist-info/licenses/LICENSE +201 -0
inferential/__init__.py
ADDED
@@ -0,0 +1,14 @@
+from inferential._version import __version__
+from inferential.client import Connection, Model
+from inferential.scheduler.base import register_policy
+from inferential.scheduler.request import InferenceRequest
+from inferential.server import Server
+
+__all__ = [
+    "__version__",
+    "Connection",
+    "InferenceRequest",
+    "Model",
+    "Server",
+    "register_policy",
+]
inferential/_version.py
ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.0"
inferential/client.py
ADDED
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import time
+from typing import Any
+
+import numpy as np
+import zmq
+import zmq.asyncio
+
+from inferential.proto import (
+    RAW,
+    Client,
+    ModelResponse,
+    Observation,
+    Tensor,
+    dtype_from_numpy,
+    dtype_to_numpy,
+)
+
+
+class Connection:
+    def __init__(
+        self,
+        server: str = "localhost:5555",
+        client_id: str | int = "",
+        client_type: str = "",
+        reconnect_ivl_ms: int = 100,
+        reconnect_max_ms: int = 5000,
+    ) -> None:
+        self._server = server
+        self._client_id = str(client_id)
+        self._client_type = client_type
+        self._ctx = zmq.Context()
+        self._socket = self._ctx.socket(zmq.DEALER)
+        self._socket.identity = str(client_id).encode()
+        self._socket.setsockopt(zmq.RECONNECT_IVL, reconnect_ivl_ms)
+        self._socket.setsockopt(zmq.RECONNECT_IVL_MAX, reconnect_max_ms)
+        self._socket.setsockopt(zmq.RCVHWM, 100)
+        self._socket.setsockopt(zmq.SNDHWM, 100)
+        self._socket.setsockopt(zmq.LINGER, 0)
+
+        if not server.startswith("tcp://"):
+            server = f"tcp://{server}"
+        self._socket.connect(server)
+
+    def model(
+        self,
+        model_id: str,
+        latency_budget_ms: float = 50.0,
+        priority: int = 1,
+    ) -> Model:
+        return Model(
+            connection=self,
+            model_id=model_id,
+            latency_budget_ms=latency_budget_ms,
+            priority=priority,
+        )
+
+    def _send(self, envelope: bytes, payload: bytes) -> None:
+        self._socket.send_multipart([b"", envelope, payload])
+
+    def _recv(self, timeout_ms: int = 0) -> tuple[bytes, bytes] | None:
+        if timeout_ms > 0:
+            if self._socket.poll(timeout_ms, zmq.POLLIN) == 0:
+                return None
+        elif not self._socket.poll(0, zmq.POLLIN):
+            return None
+
+        frames = self._socket.recv_multipart(zmq.NOBLOCK)
+        # frames = [b'', envelope, payload]
+        envelope = frames[1] if len(frames) > 1 else b""
+        payload = frames[2] if len(frames) > 2 else b""
+        return envelope, payload
+
+    def close(self) -> None:
+        self._socket.close()
+        self._ctx.term()
+
+
+class Model:
+    def __init__(
+        self,
+        connection: Connection,
+        model_id: str,
+        latency_budget_ms: float = 50.0,
+        priority: int = 1,
+    ) -> None:
+        self._conn = connection
+        self._model_id = model_id
+        self._latency_budget_ms = latency_budget_ms
+        self._priority = priority
+
+    def observe(
+        self,
+        urgency: float = 0.0,
+        steps_remaining: int | None = None,
+        **kwargs: Any,
+    ) -> None:
+        obs = Observation()
+        obs.client.CopyFrom(Client(id=self._conn._client_id, type=self._conn._client_type))
+        obs.model_id = self._model_id
+        obs.timestamp_ns = int(time.time() * 1_000_000_000)
+        obs.urgency = urgency
+        if steps_remaining is not None:
+            obs.steps_remaining = steps_remaining
+
+        payload_parts: list[bytes] = []
+        offset = 0
+
+        for key, value in kwargs.items():
+            if isinstance(value, np.ndarray):
+                data = value.tobytes()
+                tensor = Tensor()
+                tensor.key = key
+                tensor.dtype = dtype_from_numpy(value.dtype)
+                tensor.shape.extend(value.shape)
+                tensor.byte_offset = offset
+                tensor.byte_length = len(data)
+                tensor.timestamp_ns = obs.timestamp_ns
+                tensor.encoding = RAW
+                obs.tensors.append(tensor)
+                payload_parts.append(data)
+                offset += len(data)
+            elif isinstance(value, str):
+                obs.metadata[key] = value
+
+        payload = b"".join(payload_parts)
+        envelope = obs.SerializeToString()
+        self._conn._send(envelope, payload)
+
+    def get_result(self, timeout_ms: int = 100) -> dict | None:
+        result = self._conn._recv(timeout_ms)
+        if result is None:
+            return None
+
+        envelope, payload = result
+        resp = ModelResponse()
+        resp.ParseFromString(envelope)
+
+        output: dict[str, Any] = {
+            "response_id": resp.response_id,
+            "model_id": resp.model_id,
+            "inference_latency_ms": resp.inference_latency_ms,
+        }
+
+        for tensor in resp.tensors:
+            if tensor.byte_length > 0 and tensor.dtype:
+                np_dtype = dtype_to_numpy(tensor.dtype)
+                shape = tuple(tensor.shape) if tensor.shape else ()
+                data = payload[tensor.byte_offset : tensor.byte_offset + tensor.byte_length]
+                output[tensor.key] = np.frombuffer(data, dtype=np_dtype).reshape(shape)
+
+        for k, v in resp.metadata.items():
+            output[k] = v
+
+        return output
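Connection wraps a ZeroMQ DEALER socket and Model layers a protobuf envelope plus one contiguous binary payload on top of it: observe() serializes ndarray keyword arguments into tensor frames and string keyword arguments into metadata, while get_result() polls for the matching ModelResponse and rebuilds the arrays. A minimal usage sketch, assuming a server is already listening on localhost:5555 and a deployment named "policy" exists (both hypothetical):

import numpy as np

from inferential import Connection

# Give the socket a stable identity so the server can route replies back.
conn = Connection(server="localhost:5555", client_id="robot-1", client_type="sim")
model = conn.model("policy", latency_budget_ms=20.0, priority=2)

# ndarray kwargs become tensor frames in one payload; str kwargs become metadata.
model.observe(
    urgency=0.8,
    steps_remaining=120,
    joints=np.zeros(7, dtype=np.float32),
    note="warmup",
)

# Poll up to 50 ms for the response; None means nothing arrived in time.
result = model.get_result(timeout_ms=50)
if result is not None:
    print(result["model_id"], result["inference_latency_ms"])

conn.close()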
inferential/config/__init__.py
ADDED
@@ -0,0 +1,31 @@
+from inferential.config.schema import (
+    BatchConfig,
+    ClientDefaults,
+    ClientEntry,
+    ClientsConfig,
+    DeadlineAwareConfig,
+    InferentialConfig,
+    MetricsConfig,
+    ObservationConfig,
+    ResponseTrackingConfig,
+    SchedulingConfig,
+    TieredConfig,
+    TransportConfig,
+)
+from inferential.config.watcher import ConfigWatcher
+
+__all__ = [
+    "BatchConfig",
+    "ClientDefaults",
+    "ClientEntry",
+    "ClientsConfig",
+    "ConfigWatcher",
+    "DeadlineAwareConfig",
+    "InferentialConfig",
+    "MetricsConfig",
+    "ObservationConfig",
+    "ResponseTrackingConfig",
+    "SchedulingConfig",
+    "TieredConfig",
+    "TransportConfig",
+]
inferential/config/schema.py
ADDED
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class TransportConfig(BaseModel):
+    type: Literal["zmq"] = "zmq"
+    bind: str = "tcp://*:5555"
+    recv_hwm: int = Field(default=1000, ge=1)
+    send_hwm: int = Field(default=1000, ge=1)
+
+
+class DeadlineAwareConfig(BaseModel):
+    cadence_weight: float = 45.0
+    urgency_weight: float = 25.0
+    steps_weight: float = 15.0
+    priority_weight: float = 10.0
+    age_weight: float = 15.0
+
+
+class BatchConfig(BaseModel):
+    max_batch_size: int = Field(default=8, ge=1)
+    max_wait_ms: float = Field(default=10.0, gt=0)
+
+
+class TieredConfig(BaseModel):
+    num_tiers: int = Field(default=3, ge=1)
+
+
+class SchedulingConfig(BaseModel):
+    strategy: str = "deadline_aware"
+    max_queue_size: int = Field(default=1000, ge=1)
+    request_ttl_ms: float = Field(default=5000.0, gt=0)
+    overflow_policy: Literal["drop_oldest", "reject_newest"] = "drop_oldest"
+    max_retries: int = Field(default=0, ge=0)
+    deadline_aware: DeadlineAwareConfig = DeadlineAwareConfig()
+    batch_optimized: BatchConfig = BatchConfig()
+    priority_tiered: TieredConfig = TieredConfig()
+
+
+class ClientDefaults(BaseModel):
+    latency_budget_ms: float = Field(default=50.0, gt=0)
+    priority: int = Field(default=1, ge=0)
+
+
+class ClientEntry(BaseModel):
+    id: str
+    model: str
+    latency_budget_ms: float | None = None
+    priority: int | None = None
+
+
+class ClientsConfig(BaseModel):
+    defaults: ClientDefaults = ClientDefaults()
+    known: list[ClientEntry] = []
+    accept_unknown: bool = True
+
+
+class ResponseTrackingConfig(BaseModel):
+    overdue_multiplier: float = Field(default=1.5, gt=1.0)
+    disconnect_timeout_s: float = Field(default=10.0, gt=0)
+    cadence_alpha: float = Field(default=0.3, gt=0, lt=1.0)
+
+
+class ObservationConfig(BaseModel):
+    max_staleness_ms: float = Field(default=1000.0, gt=0)
+    temporal_alignment_threshold_ms: float = Field(default=20.0, gt=0)
+
+
+class MetricsConfig(BaseModel):
+    ring_buffer_size: int = Field(default=10000, ge=100)
+    enabled: bool = True
+
+
+class InferentialConfig(BaseModel):
+    ray: dict = Field(default_factory=dict)
+    transport: TransportConfig = TransportConfig()
+    scheduling: SchedulingConfig = SchedulingConfig()
+    clients: ClientsConfig = ClientsConfig()
+    response_tracking: ResponseTrackingConfig = ResponseTrackingConfig()
+    observations: ObservationConfig = ObservationConfig()
+    metrics: MetricsConfig = MetricsConfig()
+
+    def get_client_budget(self, client_id: str) -> float:
+        for entry in self.clients.known:
+            if entry.id == client_id and entry.latency_budget_ms is not None:
+                return entry.latency_budget_ms
+        return self.clients.defaults.latency_budget_ms
+
+    def get_client_priority(self, client_id: str) -> int:
+        for entry in self.clients.known:
+            if entry.id == client_id and entry.priority is not None:
+                return entry.priority
+        return self.clients.defaults.priority
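Every field in the schema has a default, so an InferentialConfig validates from a sparse nested dict, and the get_client_* helpers resolve per-client overrides in clients.known before falling back to clients.defaults. A short sketch with made-up client ids:

from inferential.config import InferentialConfig

cfg = InferentialConfig(
    scheduling={"strategy": "priority_tiered", "max_queue_size": 500},
    clients={
        "defaults": {"latency_budget_ms": 50.0, "priority": 1},
        "known": [{"id": "cam-0", "model": "policy", "latency_budget_ms": 10.0}],
    },
)

print(cfg.transport.bind)                # tcp://*:5555 (default)
print(cfg.get_client_budget("cam-0"))    # 10.0 (per-client override)
print(cfg.get_client_budget("cam-9"))    # 50.0 (falls back to defaults)
print(cfg.get_client_priority("cam-0"))  # 1 (no override set, uses default)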
inferential/config/watcher.py
ADDED
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import signal
+from pathlib import Path
+from typing import Callable
+
+from inferential.config.schema import InferentialConfig
+
+logger = logging.getLogger("inferential.config")
+
+ReloadCallback = Callable[[InferentialConfig], None]
+
+
+class ConfigWatcher:
+    def __init__(
+        self,
+        config_path: str | Path | None = None,
+        poll_interval_s: float = 5.0,
+    ) -> None:
+        self._config_path = Path(config_path) if config_path else None
+        self._poll_interval_s = poll_interval_s
+        self._callbacks: list[ReloadCallback] = []
+        self._last_mtime: float = 0.0
+        self._sighup_received = False
+
+    def on_reload(self, callback: ReloadCallback) -> ReloadCallback:
+        self._callbacks.append(callback)
+        return callback
+
+    def _install_sighup_handler(self) -> None:
+        try:
+            loop = asyncio.get_running_loop()
+            loop.add_signal_handler(signal.SIGHUP, self._handle_sighup)
+        except (NotImplementedError, OSError):
+            pass  # Windows or restricted environment
+
+    def _handle_sighup(self) -> None:
+        self._sighup_received = True
+        logger.info("SIGHUP received, will reload config")
+
+    def _load_config(self) -> InferentialConfig | None:
+        if self._config_path is None or not self._config_path.exists():
+            return None
+        try:
+            data = json.loads(self._config_path.read_text())
+            return InferentialConfig(**data)
+        except Exception as e:
+            logger.error("Failed to reload config from %s: %s", self._config_path, e)
+            return None
+
+    def _fire(self, config: InferentialConfig) -> None:
+        for cb in self._callbacks:
+            try:
+                cb(config)
+            except Exception:
+                logger.exception("Error in config reload callback")
+
+    async def watch(self) -> None:
+        self._install_sighup_handler()
+
+        while True:
+            await asyncio.sleep(self._poll_interval_s)
+
+            should_reload = self._sighup_received
+
+            if self._config_path and self._config_path.exists():
+                try:
+                    mtime = os.path.getmtime(self._config_path)
+                    if mtime > self._last_mtime:
+                        should_reload = True
+                        self._last_mtime = mtime
+                except OSError:
+                    pass
+
+            if should_reload:
+                self._sighup_received = False
+                config = self._load_config()
+                if config is not None:
+                    logger.info("Config reloaded from %s", self._config_path)
+                    self._fire(config)
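ConfigWatcher is purely poll-based: every poll_interval_s it reloads when the file's mtime has advanced or a SIGHUP arrived in the meantime (note that _last_mtime starts at 0.0, so the first poll of an existing file always fires). A usage sketch, assuming a JSON config at a hypothetical ./inferential.json:

import asyncio

from inferential.config import ConfigWatcher, InferentialConfig

watcher = ConfigWatcher("./inferential.json", poll_interval_s=2.0)

@watcher.on_reload
def apply(config: InferentialConfig) -> None:
    # Invoked on the event loop after each successful reload.
    print("new queue size:", config.scheduling.max_queue_size)

asyncio.run(watcher.watch())  # runs forever; cancel the task to stop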
inferential/dispatch/dispatcher.py
ADDED
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import logging
+import time
+import uuid
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+
+from inferential.dispatch.health import EndpointHealth
+from inferential.proto import (
+    RAW,
+    Client,
+    ModelResponse,
+    Observation,
+    Tensor,
+    dtype_from_numpy,
+    dtype_to_numpy,
+)
+from inferential.scheduler.request import InferenceRequest
+
+logger = logging.getLogger("inferential.dispatch")
+
+
+@dataclass
+class DispatchResult:
+    client_id: str
+    response_id: str
+    identity: bytes
+    envelope: bytes
+    payload: bytes
+    latency_ms: float
+    success: bool
+    error: str | None = None
+
+
+class RayDispatcher:
+    def __init__(self, ray_config: dict | None = None) -> None:
+        self._ray_config = ray_config or {}
+        self._handles: dict[str, Any] = {}
+        self._health: dict[str, EndpointHealth] = {}
+
+    def _get_handle(self, model_id: str) -> Any:
+        if model_id not in self._handles:
+            from ray import serve
+
+            self._handles[model_id] = serve.get_deployment_handle(model_id)
+        return self._handles[model_id]
+
+    def _reconstruct_numpy(self, req: InferenceRequest) -> dict:
+        obs = Observation()
+        obs.ParseFromString(req.envelope)
+        result: dict[str, Any] = {}
+        for tensor in obs.tensors:
+            np_dtype = dtype_to_numpy(tensor.dtype)
+            shape = tuple(tensor.shape) if tensor.shape else ()
+            data = req.payload[tensor.byte_offset : tensor.byte_offset + tensor.byte_length]
+            array = np.frombuffer(data, dtype=np_dtype).reshape(shape)
+            result[tensor.key] = array
+        for k, v in obs.metadata.items():
+            result[k] = v
+        return result
+
+    def _ndarray_to_tensor(
+        self, key: str, arr: np.ndarray, offset: int
+    ) -> tuple[Tensor, bytes]:
+        data = arr.tobytes()
+        tensor = Tensor()
+        tensor.key = key
+        tensor.dtype = dtype_from_numpy(arr.dtype)
+        tensor.shape.extend(arr.shape)
+        tensor.byte_offset = offset
+        tensor.byte_length = len(data)
+        tensor.encoding = RAW
+        return tensor, data
+
+    def _build_result(
+        self,
+        req: InferenceRequest,
+        ray_result: Any,
+        latency_ms: float,
+    ) -> DispatchResult:
+        resp = ModelResponse()
+        resp.client.CopyFrom(Client(id=req.client_id))
+        resp.response_id = uuid.uuid4().hex
+        resp.timestamp_ns = int(time.time() * 1_000_000_000)
+        resp.inference_latency_ms = latency_ms
+        resp.model_id = req.model_id
+
+        payload_parts: list[bytes] = []
+        offset = 0
+
+        if isinstance(ray_result, np.ndarray):
+            tensor, data = self._ndarray_to_tensor("actions", ray_result, offset)
+            resp.tensors.append(tensor)
+            payload_parts.append(data)
+        elif isinstance(ray_result, dict):
+            for key, value in ray_result.items():
+                if isinstance(value, np.ndarray):
+                    tensor, data = self._ndarray_to_tensor(key, value, offset)
+                    resp.tensors.append(tensor)
+                    payload_parts.append(data)
+                    offset += len(data)
+                elif isinstance(value, str):
+                    resp.metadata[key] = value
+
+        payload = b"".join(payload_parts)
+
+        return DispatchResult(
+            client_id=req.client_id,
+            response_id=resp.response_id,
+            identity=req.identity,
+            envelope=resp.SerializeToString(),
+            payload=payload,
+            latency_ms=latency_ms,
+            success=True,
+        )
+
+    async def dispatch(
+        self, requests: list[InferenceRequest]
+    ) -> list[DispatchResult]:
+        results: list[DispatchResult] = []
+        by_model: dict[str, list[InferenceRequest]] = defaultdict(list)
+        for req in requests:
+            by_model[req.model_id].append(req)
+
+        for model_id, model_requests in by_model.items():
+            handle = self._get_handle(model_id)
+            for req in model_requests:
+                start = time.monotonic()
+                try:
+                    obs_dict = self._reconstruct_numpy(req)
+                    result = await handle.infer.remote(obs_dict)
+                    latency = (time.monotonic() - start) * 1000
+                    self._update_health(model_id, latency, True)
+                    results.append(self._build_result(req, result, latency))
+                except Exception as e:
+                    latency = (time.monotonic() - start) * 1000
+                    self._update_health(model_id, latency, False)
+                    logger.error("Dispatch failed for client %s: %s", req.client_id, e)
+                    results.append(
+                        DispatchResult(
+                            client_id=req.client_id,
+                            response_id=uuid.uuid4().hex,
+                            identity=req.identity,
+                            envelope=b"",
+                            payload=b"",
+                            latency_ms=latency,
+                            success=False,
+                            error=str(e),
+                        )
+                    )
+        return results
+
+    def _update_health(self, model_id: str, latency_ms: float, success: bool) -> None:
+        if model_id not in self._health:
+            self._health[model_id] = EndpointHealth(model_id=model_id)
+        self._health[model_id].update(latency_ms, success)
+
+    def endpoint_health(self, model_id: str) -> EndpointHealth | None:
+        return self._health.get(model_id)
+
+    def get_handle(self, model_id: str) -> Any | None:
+        return self._handles.get(model_id)
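dispatch() groups queued requests by model id, resolves one Ray Serve handle per model, and awaits handle.infer.remote(obs_dict) for each request, so any deployment exposing an infer method that accepts the reconstructed dict should be compatible. A sketch of such a deployment (illustrative only; the deployment name must match the model id the dispatcher looks up, and exact handle resolution varies by Ray Serve version):

import numpy as np
from ray import serve


@serve.deployment(name="policy")
class Policy:
    async def infer(self, obs: dict) -> dict:
        # obs carries numpy arrays plus string metadata, as rebuilt by
        # _reconstruct_numpy(); return arrays/strings in the same format.
        joints = obs.get("joints", np.zeros(7, dtype=np.float32))
        return {"actions": -joints, "status": "ok"}


serve.run(Policy.bind())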
inferential/dispatch/health.py
ADDED
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class EndpointHealth:
+    model_id: str
+    avg_latency_ms: float = 0.0
+    error_rate: float = 0.0
+    total_requests: int = 0
+    last_error: str | None = None
+    _alpha: float = field(default=0.2, repr=False)
+
+    def update(self, latency_ms: float, success: bool) -> None:
+        self.total_requests += 1
+        self.avg_latency_ms = (
+            self._alpha * latency_ms + (1 - self._alpha) * self.avg_latency_ms
+        )
+        error_val = 0.0 if success else 1.0
+        self.error_rate = self._alpha * error_val + (1 - self._alpha) * self.error_rate
+        if not success:
+            self.last_error = f"request #{self.total_requests}"
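update() keeps exponentially weighted moving averages with alpha = 0.2: each new sample contributes 20% and prior history decays by a factor of 0.8 per update. Working the arithmetic for three requests starting from zero:

from inferential.dispatch.health import EndpointHealth

h = EndpointHealth(model_id="policy")
h.update(10.0, success=True)   # avg = 0.2*10 + 0.8*0.0 = 2.0
h.update(20.0, success=True)   # avg = 0.2*20 + 0.8*2.0 = 5.6
h.update(30.0, success=False)  # avg = 0.2*30 + 0.8*5.6 = 10.48
print(h.avg_latency_ms)        # 10.48
print(h.error_rate)            # 0.2 (one failure, alpha-weighted)
print(h.last_error)            # "request #3"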
inferential/metrics/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from inferential.metrics.callbacks import CallbackRegistry, MetricCallback
+from inferential.metrics.collector import MetricsCollector, MetricSnapshot
+from inferential.metrics.store import MetricPoint, MetricSeries, MetricStats
+
+__all__ = [
+    "CallbackRegistry",
+    "MetricCallback",
+    "MetricPoint",
+    "MetricSeries",
+    "MetricSnapshot",
+    "MetricStats",
+    "MetricsCollector",
+]
inferential/metrics/callbacks.py
ADDED
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from typing import Callable
+
+MetricCallback = Callable[[str, float, dict[str, str]], None]
+
+
+class CallbackRegistry:
+    def __init__(self) -> None:
+        self._callbacks: list[MetricCallback] = []
+
+    def register(self, callback: MetricCallback) -> None:
+        self._callbacks.append(callback)
+
+    def fire(self, name: str, value: float, labels: dict[str, str]) -> None:
+        for cb in self._callbacks:
+            try:
+                cb(name, value, labels)
+            except Exception:
+                pass  # Callbacks must not crash the server
inferential/metrics/collector.py
ADDED
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from inferential.metrics.callbacks import CallbackRegistry, MetricCallback
+from inferential.metrics.store import MetricSeries, MetricStats
+
+if TYPE_CHECKING:
+    from inferential.config.schema import MetricsConfig
+
+
+@dataclass
+class MetricSnapshot:
+    metrics: dict[str, list[dict]]
+
+
+class MetricsCollector:
+    def __init__(self, config: MetricsConfig | None = None) -> None:
+        self._ring_buffer_size = config.ring_buffer_size if config else 10000
+        self._enabled = config.enabled if config else True
+        self._series: dict[str, MetricSeries] = {}
+        self._callbacks = CallbackRegistry()
+
+    def record(
+        self,
+        name: str,
+        value: float,
+        labels: dict[str, str] | None = None,
+    ) -> None:
+        if not self._enabled:
+            return
+        labels = labels or {}
+        if name not in self._series:
+            self._series[name] = MetricSeries(max_size=self._ring_buffer_size)
+        self._series[name].record(value, labels)
+        self._callbacks.fire(name, value, labels)
+
+    def get_latest(
+        self,
+        name: str,
+        labels: dict[str, str] | None = None,
+    ) -> float | None:
+        series = self._series.get(name)
+        if series is None:
+            return None
+        return series.get_latest(labels)
+
+    def get_stats(
+        self,
+        name: str,
+        labels: dict[str, str] | None = None,
+        window_seconds: int = 300,
+    ) -> MetricStats | None:
+        series = self._series.get(name)
+        if series is None:
+            return None
+        return series.get_stats(labels, window_seconds)
+
+    def snapshot(self) -> MetricSnapshot:
+        result: dict[str, list[dict]] = {}
+        for name, series in self._series.items():
+            result[name] = [
+                {"timestamp": p.timestamp, "value": p.value, "labels": p.labels}
+                for p in series._buffer
+            ]
+        return MetricSnapshot(metrics=result)
+
+    def on_metric(self, callback: MetricCallback) -> MetricCallback:
+        self._callbacks.register(callback)
+        return callback
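record() lazily creates one ring-buffered series per metric name, appends the point, and fans it out to registered callbacks inline; get_stats() aggregates over a trailing window. The MetricSeries/MetricStats internals live in store.py (not shown above), so this sketch only relies on the call signatures visible here:

from inferential.metrics import MetricsCollector

collector = MetricsCollector()  # defaults: enabled, 10000-point ring buffer


@collector.on_metric
def alert_slow(name: str, value: float, labels: dict[str, str]) -> None:
    # Runs synchronously inside record(); exceptions are swallowed by the registry.
    if name == "dispatch_latency_ms" and value > 100.0:
        print("slow dispatch:", labels)


collector.record("dispatch_latency_ms", 42.0, {"model": "policy"})
latest = collector.get_latest("dispatch_latency_ms", {"model": "policy"})
stats = collector.get_stats("dispatch_latency_ms", window_seconds=60)
snap = collector.snapshot()  # dump of every buffered point, for export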