inferential 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. inferential/__init__.py +14 -0
  2. inferential/_version.py +1 -0
  3. inferential/client.py +156 -0
  4. inferential/config/__init__.py +31 -0
  5. inferential/config/schema.py +96 -0
  6. inferential/config/watcher.py +84 -0
  7. inferential/dispatch/__init__.py +8 -0
  8. inferential/dispatch/dispatcher.py +166 -0
  9. inferential/dispatch/health.py +23 -0
  10. inferential/metrics/__init__.py +13 -0
  11. inferential/metrics/callbacks.py +20 -0
  12. inferential/metrics/collector.py +71 -0
  13. inferential/metrics/store.py +75 -0
  14. inferential/observation/__init__.py +10 -0
  15. inferential/observation/assembler.py +100 -0
  16. inferential/observation/slots.py +30 -0
  17. inferential/proto/__init__.py +124 -0
  18. inferential/proto/inferential.proto +57 -0
  19. inferential/proto/inferential_pb2.py +54 -0
  20. inferential/proto/inferential_pb2.pyi +116 -0
  21. inferential/scheduler/__init__.py +25 -0
  22. inferential/scheduler/base.py +90 -0
  23. inferential/scheduler/batch_optimized.py +126 -0
  24. inferential/scheduler/deadline_aware.py +107 -0
  25. inferential/scheduler/priority_tiered.py +93 -0
  26. inferential/scheduler/request.py +28 -0
  27. inferential/scheduler/round_robin.py +106 -0
  28. inferential/server.py +242 -0
  29. inferential/tracking/__init__.py +12 -0
  30. inferential/tracking/cadence.py +50 -0
  31. inferential/tracking/response.py +62 -0
  32. inferential/tracking/robots.py +41 -0
  33. inferential/transport/__init__.py +8 -0
  34. inferential/transport/messages.py +18 -0
  35. inferential/transport/zmq_transport.py +50 -0
  36. inferential-0.1.0.dist-info/METADATA +281 -0
  37. inferential-0.1.0.dist-info/RECORD +39 -0
  38. inferential-0.1.0.dist-info/WHEEL +4 -0
  39. inferential-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,14 @@
1
+ from inferential._version import __version__
2
+ from inferential.client import Connection, Model
3
+ from inferential.scheduler.base import register_policy
4
+ from inferential.scheduler.request import InferenceRequest
5
+ from inferential.server import Server
6
+
7
+ __all__ = [
8
+ "__version__",
9
+ "Connection",
10
+ "InferenceRequest",
11
+ "Model",
12
+ "Server",
13
+ "register_policy",
14
+ ]
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
inferential/client.py ADDED
@@ -0,0 +1,156 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import zmq
8
+ import zmq.asyncio
9
+
10
+ from inferential.proto import (
11
+ RAW,
12
+ Client,
13
+ ModelResponse,
14
+ Observation,
15
+ Tensor,
16
+ dtype_from_numpy,
17
+ dtype_to_numpy,
18
+ )
19
+
20
+
21
class Connection:
    """Client-side ZeroMQ connection to an inferential server.

    Uses a DEALER socket so several requests can be in flight at once.
    Messages travel as [empty delimiter, envelope, payload] frames —
    NOTE(review): presumably matched by a ROUTER socket on the server
    side; confirm against server.py.
    """

    def __init__(
        self,
        server: str = "localhost:5555",
        client_id: str | int = "",
        client_type: str = "",
        reconnect_ivl_ms: int = 100,
        reconnect_max_ms: int = 5000,
    ) -> None:
        """Connect to *server* (``host:port`` or a full ``tcp://`` URL).

        *client_id* becomes the ZMQ socket identity (an empty id lets ZMQ
        assign one); *reconnect_ivl_ms*/*reconnect_max_ms* bound ZMQ's
        automatic reconnect backoff.
        """
        self._server = server
        self._client_id = str(client_id)
        self._client_type = client_type
        self._ctx = zmq.Context()
        self._socket = self._ctx.socket(zmq.DEALER)
        self._socket.identity = str(client_id).encode()
        self._socket.setsockopt(zmq.RECONNECT_IVL, reconnect_ivl_ms)
        self._socket.setsockopt(zmq.RECONNECT_IVL_MAX, reconnect_max_ms)
        # High-water marks cap queued messages in each direction so a
        # stalled peer cannot make this client buffer without bound.
        self._socket.setsockopt(zmq.RCVHWM, 100)
        self._socket.setsockopt(zmq.SNDHWM, 100)
        # LINGER=0: close() discards unsent messages instead of blocking.
        self._socket.setsockopt(zmq.LINGER, 0)

        if not server.startswith("tcp://"):
            server = f"tcp://{server}"
        self._socket.connect(server)

    def model(
        self,
        model_id: str,
        latency_budget_ms: float = 50.0,
        priority: int = 1,
    ) -> Model:
        """Return a Model handle bound to this connection."""
        return Model(
            connection=self,
            model_id=model_id,
            latency_budget_ms=latency_budget_ms,
            priority=priority,
        )

    def _send(self, envelope: bytes, payload: bytes) -> None:
        # Leading empty frame is the REQ-style delimiter the peer strips.
        self._socket.send_multipart([b"", envelope, payload])

    def _recv(self, timeout_ms: int = 0) -> tuple[bytes, bytes] | None:
        """Receive one (envelope, payload) message, or None if none arrived.

        ``timeout_ms > 0`` waits up to that long; ``timeout_ms == 0``
        polls without blocking.
        """
        if timeout_ms > 0:
            if self._socket.poll(timeout_ms, zmq.POLLIN) == 0:
                return None
        elif not self._socket.poll(0, zmq.POLLIN):
            return None

        frames = self._socket.recv_multipart(zmq.NOBLOCK)
        # frames = [b'', envelope, payload]
        envelope = frames[1] if len(frames) > 1 else b""
        payload = frames[2] if len(frames) > 2 else b""
        return envelope, payload

    def close(self) -> None:
        """Close the socket and terminate the ZMQ context."""
        self._socket.close()
        self._ctx.term()
78
+
79
+
80
class Model:
    """Handle for sending observations to one model over a Connection."""

    def __init__(
        self,
        connection: Connection,
        model_id: str,
        latency_budget_ms: float = 50.0,
        priority: int = 1,
    ) -> None:
        # NOTE(review): latency_budget_ms and priority are stored but never
        # attached to outgoing observations here — presumably the server
        # applies per-client config instead; confirm against server.py.
        self._conn = connection
        self._model_id = model_id
        self._latency_budget_ms = latency_budget_ms
        self._priority = priority

    def observe(
        self,
        urgency: float = 0.0,
        steps_remaining: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Send one observation; does not wait for a reply.

        numpy-array kwargs are shipped as raw tensors inside the binary
        payload; str kwargs go into the metadata map; kwargs of any other
        type are silently dropped.
        """
        obs = Observation()
        obs.client.CopyFrom(Client(id=self._conn._client_id, type=self._conn._client_type))
        obs.model_id = self._model_id
        obs.timestamp_ns = int(time.time() * 1_000_000_000)
        obs.urgency = urgency
        if steps_remaining is not None:
            obs.steps_remaining = steps_remaining

        payload_parts: list[bytes] = []
        offset = 0  # running byte offset of each tensor within the payload

        for key, value in kwargs.items():
            if isinstance(value, np.ndarray):
                data = value.tobytes()
                tensor = Tensor()
                tensor.key = key
                tensor.dtype = dtype_from_numpy(value.dtype)
                tensor.shape.extend(value.shape)
                tensor.byte_offset = offset
                tensor.byte_length = len(data)
                tensor.timestamp_ns = obs.timestamp_ns
                tensor.encoding = RAW
                obs.tensors.append(tensor)
                payload_parts.append(data)
                offset += len(data)
            elif isinstance(value, str):
                obs.metadata[key] = value

        payload = b"".join(payload_parts)
        envelope = obs.SerializeToString()
        self._conn._send(envelope, payload)

    def get_result(self, timeout_ms: int = 100) -> dict | None:
        """Receive the next response, or None on timeout.

        Returns a dict with response_id / model_id / inference_latency_ms,
        one numpy array per response tensor (keyed by tensor.key), and any
        response metadata entries.
        NOTE(review): responses are taken off the shared socket in arrival
        order — if several Model handles share one Connection, a result for
        a different model could be returned here; confirm intended usage.
        """
        result = self._conn._recv(timeout_ms)
        if result is None:
            return None

        envelope, payload = result
        resp = ModelResponse()
        resp.ParseFromString(envelope)

        output: dict[str, Any] = {
            "response_id": resp.response_id,
            "model_id": resp.model_id,
            "inference_latency_ms": resp.inference_latency_ms,
        }

        for tensor in resp.tensors:
            # Skips empty tensors and dtype == 0 — assumes 0 is the proto's
            # "unspecified" dtype; TODO confirm against inferential.proto.
            if tensor.byte_length > 0 and tensor.dtype:
                np_dtype = dtype_to_numpy(tensor.dtype)
                shape = tuple(tensor.shape) if tensor.shape else ()
                data = payload[tensor.byte_offset : tensor.byte_offset + tensor.byte_length]
                output[tensor.key] = np.frombuffer(data, dtype=np_dtype).reshape(shape)

        for k, v in resp.metadata.items():
            output[k] = v

        return output
@@ -0,0 +1,31 @@
1
+ from inferential.config.schema import (
2
+ BatchConfig,
3
+ ClientDefaults,
4
+ ClientEntry,
5
+ ClientsConfig,
6
+ DeadlineAwareConfig,
7
+ InferentialConfig,
8
+ MetricsConfig,
9
+ ObservationConfig,
10
+ ResponseTrackingConfig,
11
+ SchedulingConfig,
12
+ TieredConfig,
13
+ TransportConfig,
14
+ )
15
+ from inferential.config.watcher import ConfigWatcher
16
+
17
+ __all__ = [
18
+ "BatchConfig",
19
+ "ClientDefaults",
20
+ "ClientEntry",
21
+ "ClientsConfig",
22
+ "ConfigWatcher",
23
+ "DeadlineAwareConfig",
24
+ "InferentialConfig",
25
+ "MetricsConfig",
26
+ "ObservationConfig",
27
+ "ResponseTrackingConfig",
28
+ "SchedulingConfig",
29
+ "TieredConfig",
30
+ "TransportConfig",
31
+ ]
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class TransportConfig(BaseModel):
    """ZeroMQ transport settings for the server-side socket."""

    type: Literal["zmq"] = "zmq"  # only ZMQ transport is supported
    bind: str = "tcp://*:5555"  # server bind address
    recv_hwm: int = Field(default=1000, ge=1)  # receive high-water mark (messages)
    send_hwm: int = Field(default=1000, ge=1)  # send high-water mark (messages)


class DeadlineAwareConfig(BaseModel):
    """Scoring weights for the deadline-aware scheduling strategy.

    NOTE(review): the weights sum to 110, not 100 — presumably relative
    weights normalized by the scheduler; confirm in
    scheduler/deadline_aware.py.
    """

    cadence_weight: float = 45.0
    urgency_weight: float = 25.0
    steps_weight: float = 15.0
    priority_weight: float = 10.0
    age_weight: float = 15.0


class BatchConfig(BaseModel):
    """Settings for the batch-optimized scheduling strategy."""

    max_batch_size: int = Field(default=8, ge=1)
    max_wait_ms: float = Field(default=10.0, gt=0)  # max wait while filling a batch


class TieredConfig(BaseModel):
    """Settings for the priority-tiered scheduling strategy."""

    num_tiers: int = Field(default=3, ge=1)
30
+
31
+
32
class SchedulingConfig(BaseModel):
    """Request-queue and strategy-selection settings.

    Per-strategy sub-configs use ``Field(default_factory=...)`` (the
    pydantic-recommended form) instead of a shared class-level instance,
    so each SchedulingConfig gets its own fresh sub-model.
    """

    strategy: str = "deadline_aware"  # name of the scheduling policy to use
    max_queue_size: int = Field(default=1000, ge=1)
    request_ttl_ms: float = Field(default=5000.0, gt=0)  # drop requests older than this
    overflow_policy: Literal["drop_oldest", "reject_newest"] = "drop_oldest"
    max_retries: int = Field(default=0, ge=0)
    deadline_aware: DeadlineAwareConfig = Field(default_factory=DeadlineAwareConfig)
    batch_optimized: BatchConfig = Field(default_factory=BatchConfig)
    priority_tiered: TieredConfig = Field(default_factory=TieredConfig)
41
+
42
+
43
class ClientDefaults(BaseModel):
    """Fallback per-client settings used when a client has no explicit entry."""

    latency_budget_ms: float = Field(default=50.0, gt=0)
    priority: int = Field(default=1, ge=0)


class ClientEntry(BaseModel):
    """A known client; ``None`` overrides fall back to ClientDefaults."""

    id: str
    model: str  # model id this client's requests are routed to
    latency_budget_ms: float | None = None
    priority: int | None = None
53
+
54
+
55
class ClientsConfig(BaseModel):
    """Known-client registry and admission policy.

    Mutable defaults use ``Field(default_factory=...)`` (the
    pydantic-recommended form) rather than a shared class-level
    ``[]``/instance default.
    """

    defaults: ClientDefaults = Field(default_factory=ClientDefaults)
    known: list[ClientEntry] = Field(default_factory=list)
    accept_unknown: bool = True  # whether unregistered clients may connect
59
+
60
+
61
class ResponseTrackingConfig(BaseModel):
    """Settings for tracking per-client response cadence and liveness."""

    overdue_multiplier: float = Field(default=1.5, gt=1.0)  # overdue = multiplier x expected cadence
    disconnect_timeout_s: float = Field(default=10.0, gt=0)  # silence before a client is dropped
    cadence_alpha: float = Field(default=0.3, gt=0, lt=1.0)  # EWMA smoothing factor


class ObservationConfig(BaseModel):
    """Observation assembly / freshness limits."""

    max_staleness_ms: float = Field(default=1000.0, gt=0)
    temporal_alignment_threshold_ms: float = Field(default=20.0, gt=0)


class MetricsConfig(BaseModel):
    """Metrics collection settings."""

    ring_buffer_size: int = Field(default=10000, ge=100)  # samples kept per series
    enabled: bool = True
75
+
76
+
77
class InferentialConfig(BaseModel):
    """Root configuration for the inferential server.

    Sub-model defaults use ``Field(default_factory=...)`` (the
    pydantic-recommended form) so every config instance owns a fresh
    sub-model instead of sharing a class-level default instance.
    """

    ray: dict = Field(default_factory=dict)  # passed through to Ray; schema not validated here
    transport: TransportConfig = Field(default_factory=TransportConfig)
    scheduling: SchedulingConfig = Field(default_factory=SchedulingConfig)
    clients: ClientsConfig = Field(default_factory=ClientsConfig)
    response_tracking: ResponseTrackingConfig = Field(default_factory=ResponseTrackingConfig)
    observations: ObservationConfig = Field(default_factory=ObservationConfig)
    metrics: MetricsConfig = Field(default_factory=MetricsConfig)

    def get_client_budget(self, client_id: str) -> float:
        """Latency budget for *client_id*: its explicit override if known, else the default."""
        for entry in self.clients.known:
            if entry.id == client_id and entry.latency_budget_ms is not None:
                return entry.latency_budget_ms
        return self.clients.defaults.latency_budget_ms

    def get_client_priority(self, client_id: str) -> int:
        """Priority for *client_id*: its explicit override if known, else the default."""
        for entry in self.clients.known:
            if entry.id == client_id and entry.priority is not None:
                return entry.priority
        return self.clients.defaults.priority
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import os
7
+ import signal
8
+ from pathlib import Path
9
+ from typing import Callable
10
+
11
+ from inferential.config.schema import InferentialConfig
12
+
13
+ logger = logging.getLogger("inferential.config")
14
+
15
+ ReloadCallback = Callable[[InferentialConfig], None]
16
+
17
+
18
class ConfigWatcher:
    """Reloads a JSON config file when it changes on disk or on SIGHUP.

    Registered callbacks are invoked with the freshly parsed
    InferentialConfig after every successful reload.
    """

    def __init__(
        self,
        config_path: str | Path | None = None,
        poll_interval_s: float = 5.0,
    ) -> None:
        self._config_path = Path(config_path) if config_path else None
        self._poll_interval_s = poll_interval_s
        self._callbacks: list[ReloadCallback] = []
        self._last_mtime: float = 0.0
        self._sighup_received = False

    def on_reload(self, callback: ReloadCallback) -> ReloadCallback:
        """Register *callback* (usable as a decorator) and return it."""
        self._callbacks.append(callback)
        return callback

    def _install_sighup_handler(self) -> None:
        # add_signal_handler is unavailable on Windows / restricted envs.
        try:
            asyncio.get_running_loop().add_signal_handler(
                signal.SIGHUP, self._handle_sighup
            )
        except (NotImplementedError, OSError):
            pass

    def _handle_sighup(self) -> None:
        self._sighup_received = True
        logger.info("SIGHUP received, will reload config")

    def _load_config(self) -> InferentialConfig | None:
        """Parse the config file; returns None if missing or invalid."""
        path = self._config_path
        if path is None or not path.exists():
            return None
        try:
            return InferentialConfig(**json.loads(path.read_text()))
        except Exception as e:
            logger.error("Failed to reload config from %s: %s", path, e)
            return None

    def _fire(self, config: InferentialConfig) -> None:
        # A broken callback must not prevent the others from running.
        for callback in self._callbacks:
            try:
                callback(config)
            except Exception:
                logger.exception("Error in config reload callback")

    async def watch(self) -> None:
        """Poll forever, reloading on file mtime change or pending SIGHUP."""
        self._install_sighup_handler()

        while True:
            await asyncio.sleep(self._poll_interval_s)

            reload_due = self._sighup_received
            path = self._config_path

            if path and path.exists():
                try:
                    mtime = os.path.getmtime(path)
                except OSError:
                    pass
                else:
                    if mtime > self._last_mtime:
                        self._last_mtime = mtime
                        reload_due = True

            if not reload_due:
                continue

            self._sighup_received = False
            config = self._load_config()
            if config is not None:
                logger.info("Config reloaded from %s", self._config_path)
                self._fire(config)
@@ -0,0 +1,8 @@
1
+ from inferential.dispatch.dispatcher import DispatchResult, RayDispatcher
2
+ from inferential.dispatch.health import EndpointHealth
3
+
4
+ __all__ = [
5
+ "DispatchResult",
6
+ "EndpointHealth",
7
+ "RayDispatcher",
8
+ ]
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+ import uuid
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+
12
+ from inferential.dispatch.health import EndpointHealth
13
+ from inferential.proto import (
14
+ RAW,
15
+ Client,
16
+ ModelResponse,
17
+ Observation,
18
+ Tensor,
19
+ dtype_from_numpy,
20
+ dtype_to_numpy,
21
+ )
22
+ from inferential.scheduler.request import InferenceRequest
23
+
24
+ logger = logging.getLogger("inferential.dispatch")
25
+
26
+
27
@dataclass
class DispatchResult:
    """Outcome of dispatching one InferenceRequest to a model endpoint."""

    client_id: str  # client the response is routed back to
    response_id: str  # uuid4 hex id for this response
    identity: bytes  # ZMQ routing identity of the requesting client
    envelope: bytes  # serialized ModelResponse (empty on failure)
    payload: bytes  # concatenated raw tensor bytes (empty on failure)
    latency_ms: float  # wall-clock dispatch latency in milliseconds
    success: bool
    error: str | None = None  # error message when success is False
37
+
38
+
39
class RayDispatcher:
    """Routes queued inference requests to Ray Serve deployment handles.

    Handles proto <-> numpy conversion on both sides of the call and
    keeps rolling health statistics per model endpoint.
    """

    def __init__(self, ray_config: dict | None = None) -> None:
        self._ray_config = ray_config or {}
        # model_id -> cached Ray Serve deployment handle
        self._handles: dict[str, Any] = {}
        # model_id -> rolling latency/error statistics
        self._health: dict[str, EndpointHealth] = {}

    def _get_handle(self, model_id: str) -> Any:
        """Return (and cache) the Serve handle for deployment *model_id*."""
        if model_id not in self._handles:
            # Lazy import so the module is usable without Ray installed.
            from ray import serve

            self._handles[model_id] = serve.get_deployment_handle(model_id)
        return self._handles[model_id]

    def _reconstruct_numpy(self, req: InferenceRequest) -> dict:
        """Decode the request's Observation into {key: ndarray} plus metadata strings."""
        obs = Observation()
        obs.ParseFromString(req.envelope)
        result: dict[str, Any] = {}
        for tensor in obs.tensors:
            np_dtype = dtype_to_numpy(tensor.dtype)
            shape = tuple(tensor.shape) if tensor.shape else ()
            # Each tensor is the [byte_offset, byte_offset+byte_length)
            # slice of the contiguous payload blob.
            data = req.payload[tensor.byte_offset : tensor.byte_offset + tensor.byte_length]
            array = np.frombuffer(data, dtype=np_dtype).reshape(shape)
            result[tensor.key] = array
        for k, v in obs.metadata.items():
            result[k] = v
        return result

    def _ndarray_to_tensor(
        self, key: str, arr: np.ndarray, offset: int
    ) -> tuple[Tensor, bytes]:
        """Describe *arr* as a Tensor at *offset* and return its raw bytes.

        NOTE(review): timestamp_ns is left unset here (proto default 0),
        unlike tensors built on the client side — confirm whether response
        tensors need a timestamp.
        """
        data = arr.tobytes()
        tensor = Tensor()
        tensor.key = key
        tensor.dtype = dtype_from_numpy(arr.dtype)
        tensor.shape.extend(arr.shape)
        tensor.byte_offset = offset
        tensor.byte_length = len(data)
        tensor.encoding = RAW
        return tensor, data

    def _build_result(
        self,
        req: InferenceRequest,
        ray_result: Any,
        latency_ms: float,
    ) -> DispatchResult:
        """Wrap a Ray return value into a successful DispatchResult.

        A bare ndarray is shipped under the key "actions"; a dict ships
        each ndarray value as a tensor and each str value as metadata
        (other value types are silently dropped).
        """
        resp = ModelResponse()
        resp.client.CopyFrom(Client(id=req.client_id))
        resp.response_id = uuid.uuid4().hex
        resp.timestamp_ns = int(time.time() * 1_000_000_000)
        resp.inference_latency_ms = latency_ms
        resp.model_id = req.model_id

        payload_parts: list[bytes] = []
        offset = 0  # running byte offset within the payload blob

        if isinstance(ray_result, np.ndarray):
            tensor, data = self._ndarray_to_tensor("actions", ray_result, offset)
            resp.tensors.append(tensor)
            payload_parts.append(data)
        elif isinstance(ray_result, dict):
            for key, value in ray_result.items():
                if isinstance(value, np.ndarray):
                    tensor, data = self._ndarray_to_tensor(key, value, offset)
                    resp.tensors.append(tensor)
                    payload_parts.append(data)
                    offset += len(data)
                elif isinstance(value, str):
                    resp.metadata[key] = value

        payload = b"".join(payload_parts)

        return DispatchResult(
            client_id=req.client_id,
            response_id=resp.response_id,
            identity=req.identity,
            envelope=resp.SerializeToString(),
            payload=payload,
            latency_ms=latency_ms,
            success=True,
        )

    async def dispatch(
        self, requests: list[InferenceRequest]
    ) -> list[DispatchResult]:
        """Run every request against its model; never raises per-request.

        Failures become unsuccessful DispatchResults with empty envelope/
        payload. NOTE(review): requests are awaited one at a time (no
        gather), and _get_handle is called outside the try — a handle
        lookup failure would abort the whole batch; confirm if intended.
        """
        results: list[DispatchResult] = []
        by_model: dict[str, list[InferenceRequest]] = defaultdict(list)
        for req in requests:
            by_model[req.model_id].append(req)

        for model_id, model_requests in by_model.items():
            handle = self._get_handle(model_id)
            for req in model_requests:
                start = time.monotonic()
                try:
                    obs_dict = self._reconstruct_numpy(req)
                    result = await handle.infer.remote(obs_dict)
                    latency = (time.monotonic() - start) * 1000
                    self._update_health(model_id, latency, True)
                    results.append(self._build_result(req, result, latency))
                except Exception as e:
                    latency = (time.monotonic() - start) * 1000
                    self._update_health(model_id, latency, False)
                    logger.error("Dispatch failed for client %s: %s", req.client_id, e)
                    results.append(
                        DispatchResult(
                            client_id=req.client_id,
                            response_id=uuid.uuid4().hex,
                            identity=req.identity,
                            envelope=b"",
                            payload=b"",
                            latency_ms=latency,
                            success=False,
                            error=str(e),
                        )
                    )
        return results

    def _update_health(self, model_id: str, latency_ms: float, success: bool) -> None:
        # Create the health record lazily on first use for each model.
        if model_id not in self._health:
            self._health[model_id] = EndpointHealth(model_id=model_id)
        self._health[model_id].update(latency_ms, success)

    def endpoint_health(self, model_id: str) -> EndpointHealth | None:
        """Rolling health stats for *model_id*, or None if never dispatched."""
        return self._health.get(model_id)

    def get_handle(self, model_id: str) -> Any | None:
        """Cached Serve handle for *model_id*, or None if not yet created."""
        return self._handles.get(model_id)
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+
6
@dataclass
class EndpointHealth:
    """Rolling health statistics for one model endpoint.

    Latency and error rate are exponentially weighted moving averages
    (EWMA) with smoothing factor ``_alpha``.
    """

    model_id: str
    avg_latency_ms: float = 0.0  # EWMA of observed latencies
    error_rate: float = 0.0  # EWMA of the failure indicator, in [0, 1]
    total_requests: int = 0
    last_error: str | None = None  # marker for the most recent failed request
    _alpha: float = field(default=0.2, repr=False)  # EWMA smoothing factor

    def update(self, latency_ms: float, success: bool) -> None:
        """Fold one request outcome into the rolling statistics.

        Fix: the first sample now seeds both EWMAs directly instead of
        blending with the 0.0 initial value, which biased early readings
        toward zero (e.g. a first 100 ms request previously reported a
        20 ms average).
        """
        self.total_requests += 1
        error_val = 0.0 if success else 1.0
        if self.total_requests == 1:
            # Seed with the first observation to avoid cold-start bias.
            self.avg_latency_ms = latency_ms
            self.error_rate = error_val
        else:
            self.avg_latency_ms = (
                self._alpha * latency_ms + (1 - self._alpha) * self.avg_latency_ms
            )
            self.error_rate = (
                self._alpha * error_val + (1 - self._alpha) * self.error_rate
            )
        if not success:
            self.last_error = f"request #{self.total_requests}"
@@ -0,0 +1,13 @@
1
+ from inferential.metrics.callbacks import CallbackRegistry, MetricCallback
2
+ from inferential.metrics.collector import MetricsCollector, MetricSnapshot
3
+ from inferential.metrics.store import MetricPoint, MetricSeries, MetricStats
4
+
5
+ __all__ = [
6
+ "CallbackRegistry",
7
+ "MetricCallback",
8
+ "MetricPoint",
9
+ "MetricSeries",
10
+ "MetricSnapshot",
11
+ "MetricStats",
12
+ "MetricsCollector",
13
+ ]
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
# Signature of a metric callback: (name, value, labels) -> None.
MetricCallback = Callable[[str, float, dict[str, str]], None]


class CallbackRegistry:
    """Holds metric callbacks and invokes them with fault isolation."""

    def __init__(self) -> None:
        self._callbacks: list[MetricCallback] = []

    def register(self, callback: MetricCallback) -> None:
        """Add *callback* to be invoked on every fired metric."""
        self._callbacks.append(callback)

    def fire(self, name: str, value: float, labels: dict[str, str]) -> None:
        """Invoke every callback with the sample; callbacks must not crash
        the server, so each one runs in its own try block.

        Fix: failures are now logged instead of silently swallowed, so a
        broken callback is diagnosable; subsequent callbacks still run.
        """
        for cb in self._callbacks:
            try:
                cb(name, value, labels)
            except Exception:
                import logging  # local import: module otherwise has no logging dependency

                logging.getLogger("inferential.metrics").exception(
                    "Metric callback %r failed", cb
                )
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ from inferential.metrics.callbacks import CallbackRegistry, MetricCallback
7
+ from inferential.metrics.store import MetricSeries, MetricStats
8
+
9
+ if TYPE_CHECKING:
10
+ from inferential.config.schema import MetricsConfig
11
+
12
+
13
@dataclass
class MetricSnapshot:
    """Point-in-time dump of all recorded metric series.

    Maps series name to a list of {"timestamp", "value", "labels"} dicts,
    as produced by MetricsCollector.snapshot().
    """

    metrics: dict[str, list[dict]]
16
+
17
+
18
class MetricsCollector:
    """Records metric samples into named ring-buffer series and fans each
    sample out to registered callbacks."""

    def __init__(self, config: MetricsConfig | None = None) -> None:
        if config:
            self._ring_buffer_size = config.ring_buffer_size
            self._enabled = config.enabled
        else:
            self._ring_buffer_size = 10000
            self._enabled = True
        self._series: dict[str, MetricSeries] = {}
        self._callbacks = CallbackRegistry()

    def record(
        self,
        name: str,
        value: float,
        labels: dict[str, str] | None = None,
    ) -> None:
        """Append one sample to *name*'s series and notify callbacks.

        No-op while collection is disabled.
        """
        if not self._enabled:
            return
        labels = labels or {}
        series = self._series.get(name)
        if series is None:
            series = MetricSeries(max_size=self._ring_buffer_size)
            self._series[name] = series
        series.record(value, labels)
        self._callbacks.fire(name, value, labels)

    def get_latest(
        self,
        name: str,
        labels: dict[str, str] | None = None,
    ) -> float | None:
        """Most recent value in the series, or None for an unknown series."""
        series = self._series.get(name)
        return None if series is None else series.get_latest(labels)

    def get_stats(
        self,
        name: str,
        labels: dict[str, str] | None = None,
        window_seconds: int = 300,
    ) -> MetricStats | None:
        """Aggregate stats over the trailing window, or None for an unknown series."""
        series = self._series.get(name)
        return None if series is None else series.get_stats(labels, window_seconds)

    def snapshot(self) -> MetricSnapshot:
        """Dump every series into plain dicts (reads MetricSeries._buffer)."""
        return MetricSnapshot(
            metrics={
                name: [
                    {"timestamp": p.timestamp, "value": p.value, "labels": p.labels}
                    for p in series._buffer
                ]
                for name, series in self._series.items()
            }
        )

    def on_metric(self, callback: MetricCallback) -> MetricCallback:
        """Register *callback* (usable as a decorator) and return it."""
        self._callbacks.register(callback)
        return callback
+ return callback