tokenrail 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokenrail/__init__.py +31 -0
- tokenrail/catalog.py +159 -0
- tokenrail/client.py +55 -0
- tokenrail/executor.py +246 -0
- tokenrail/monitor.py +203 -0
- tokenrail/providers/__init__.py +3 -0
- tokenrail/providers/base.py +20 -0
- tokenrail/providers/openai.py +218 -0
- tokenrail/py.typed +0 -0
- tokenrail/sinks.py +104 -0
- tokenrail/types.py +219 -0
- tokenrail-1.0.0.dist-info/METADATA +143 -0
- tokenrail-1.0.0.dist-info/RECORD +15 -0
- tokenrail-1.0.0.dist-info/WHEEL +4 -0
- tokenrail-1.0.0.dist-info/licenses/LICENSE +21 -0
tokenrail/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Thin client and batch execution helpers for OpenAI Responses API workloads.
|
|
2
|
+
|
|
3
|
+
tokenrail wraps the OpenAI Responses API with a ``client.responses.create(...)``-style
|
|
4
|
+
surface and adds thread-based batch execution, client-side RPM/TPM submit throttling,
|
|
5
|
+
per-model token/cost monitoring, and resumable JSONL / per-request result writing.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .client import RailClient
|
|
9
|
+
from .executor import BatchExecutor, batch_items_from_queries
|
|
10
|
+
from .monitor import RollingMetricsMonitor
|
|
11
|
+
from .providers import OpenAIProvider
|
|
12
|
+
from .sinks import PerRequestJsonSink, ResultsJsonlSink
|
|
13
|
+
from .types import BatchItem, CostBreakdown, NormalizedResponse, StatsSnapshot, UsageBreakdown
|
|
14
|
+
|
|
15
|
+
# Package version string (matches the 1.0.0 wheel this module ships in).
__version__ = "1.0.0"

# Public API: the names exported by ``from tokenrail import *``. Every entry
# other than ``__version__`` is imported above from its submodule.
__all__ = [
    "__version__",
    "BatchExecutor",
    "BatchItem",
    "CostBreakdown",
    "NormalizedResponse",
    "OpenAIProvider",
    "PerRequestJsonSink",
    "RailClient",
    "ResultsJsonlSink",
    "RollingMetricsMonitor",
    "StatsSnapshot",
    "UsageBreakdown",
    "batch_items_from_queries",
]
|
tokenrail/catalog.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from typing import TypeVar
|
|
7
|
+
|
|
8
|
+
from .types import CostBreakdown, UsageBreakdown
|
|
9
|
+
|
|
10
|
+
# Generic payload type of a rule table: ``_match_rule`` returns the payload of
# the winning rule (ModelCapabilities or ModelPricing in this module).
_T = TypeVar("_T")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True, slots=True)
class ModelCapabilities:
    """Which optional request parameters a model family is marked as accepting.

    Each flag is named after the request parameter it describes. How a
    ``False`` flag is acted on (dropping the parameter, raising, ...) is decided
    by the caller. NOTE(review): confirm this in the provider code.
    """

    reasoning_effort: bool
    verbosity: bool
    temperature: bool
    top_p: bool
    max_output_tokens: bool
    response_format: bool
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True, slots=True)
class ModelPricing:
    """Per-million-token prices for one model family.

    Field order is load-bearing: ``_PRICING_RULES`` builds instances
    positionally as ``(input, cached_input, output)``.
    """

    # Price per 1,000,000 uncached input tokens.
    input_per_million: Decimal
    # Price per 1,000,000 cached input tokens; may be ``None`` when no cached rate is listed.
    cached_input_per_million: Decimal | None
    # Price per 1,000,000 output tokens.
    output_per_million: Decimal
    # Tier the prices belong to; the checked-in registry only carries default-tier prices.
    service_tier: str = "default"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Capability registry: (model-name fragments, capabilities). A model matches a
# fragment when the fragment occurs in the name bounded by
# ``_MODEL_NAME_DELIMITERS`` or the string edges (see ``_match_rule``). The
# longest matching fragment wins, and ties go to the earlier rule.
# NOTE(review): the reasoning families below are marked as accepting
# temperature/top_p; confirm this against the provider's current API docs.
_CAPABILITY_RULES: list[tuple[tuple[str, ...], ModelCapabilities]] = [
    (
        ("gpt-5", "o1", "o3", "o4"),
        ModelCapabilities(
            reasoning_effort=True,
            verbosity=True,
            temperature=True,
            top_p=True,
            max_output_tokens=True,
            response_format=True,
        ),
    ),
    (
        ("gpt-4.1", "gpt-4o"),
        ModelCapabilities(
            reasoning_effort=False,
            verbosity=True,
            temperature=True,
            top_p=True,
            max_output_tokens=True,
            response_format=True,
        ),
    ),
]
|
|
55
|
+
|
|
56
|
+
# Fallback returned by ``get_model_capabilities`` when no rule in
# ``_CAPABILITY_RULES`` matches. It disables ``reasoning_effort`` and
# ``verbosity`` and allows every other parameter.
_DEFAULT_CAPABILITIES = ModelCapabilities(
    reasoning_effort=False,
    verbosity=False,
    temperature=True,
    top_p=True,
    max_output_tokens=True,
    response_format=True,
)
|
|
64
|
+
|
|
65
|
+
# Default-tier prices per 1,000,000 tokens: ModelPricing(input, cached input, output).
# ``_match_rule`` prefers the longest matching fragment, so "gpt-5-mini" uses its
# own row rather than "gpt-5", regardless of row order.
# NOTE(review): names without their own row fall back to the nearest shorter
# delimited fragment (e.g. "o3-mini" resolves to the "o3" price). Confirm that
# this approximation is intended.
_PRICING_RULES: list[tuple[tuple[str, ...], ModelPricing]] = [
    (("gpt-5.5",), ModelPricing(Decimal("5.00"), Decimal("0.50"), Decimal("30.00"))),
    (("gpt-5.4-mini",), ModelPricing(Decimal("0.750"), Decimal("0.075"), Decimal("4.500"))),
    (("gpt-5.4-nano",), ModelPricing(Decimal("0.20"), Decimal("0.02"), Decimal("1.25"))),
    (("gpt-5.4",), ModelPricing(Decimal("2.50"), Decimal("0.25"), Decimal("15.00"))),
    (("gpt-5.2",), ModelPricing(Decimal("1.75"), Decimal("0.175"), Decimal("14.00"))),
    (("gpt-5-mini",), ModelPricing(Decimal("0.25"), Decimal("0.025"), Decimal("2.00"))),
    (("gpt-5-nano",), ModelPricing(Decimal("0.05"), Decimal("0.005"), Decimal("0.40"))),
    (("gpt-5",), ModelPricing(Decimal("1.25"), Decimal("0.125"), Decimal("10.00"))),
    (("gpt-4.1-mini",), ModelPricing(Decimal("0.40"), Decimal("0.10"), Decimal("1.60"))),
    (("gpt-4.1-nano",), ModelPricing(Decimal("0.10"), Decimal("0.025"), Decimal("0.40"))),
    (("gpt-4.1",), ModelPricing(Decimal("2.00"), Decimal("0.50"), Decimal("8.00"))),
    (("gpt-4o-mini",), ModelPricing(Decimal("0.15"), Decimal("0.075"), Decimal("0.60"))),
    (("gpt-4o",), ModelPricing(Decimal("2.50"), Decimal("1.25"), Decimal("10.00"))),
    (("o4-mini",), ModelPricing(Decimal("1.10"), Decimal("0.275"), Decimal("4.40"))),
    (("o3",), ModelPricing(Decimal("2.00"), Decimal("0.50"), Decimal("8.00"))),
    (("o1",), ModelPricing(Decimal("15.00"), Decimal("7.50"), Decimal("60.00"))),
]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Characters that may border a registry fragment inside a model name, so that
# "gpt-5" matches "gpt-5-2025-08-07" or "openai/gpt-5" but not "gpt-50".
# A frozenset, because this is a read-only module constant and must not be
# mutated by callers.
_MODEL_NAME_DELIMITERS = frozenset({"-", "_", "/", ":", " ", "."})
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _is_delimited_match(model: str, candidate: str, start: int, end: int) -> bool:
    """Report whether ``model[start:end]`` stands alone as a name fragment.

    The span qualifies when each side is either the edge of ``model`` or a
    character from ``_MODEL_NAME_DELIMITERS``. ``candidate`` (the fragment found
    at ``start``) is accepted for call-site symmetry but not consulted, because
    ``end`` already encodes its length.
    """
    left_ok = start <= 0 or model[start - 1] in _MODEL_NAME_DELIMITERS
    right_ok = end >= len(model) or model[end] in _MODEL_NAME_DELIMITERS
    return left_ok and right_ok
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _match_rule(model: str, rules: Iterable[tuple[tuple[str, ...], _T]]) -> _T | None:
    """Return the payload of the rule whose fragment best matches ``model``.

    A fragment matches when it occurs in ``model`` bounded on both sides by a
    delimiter or the string edge (see ``_is_delimited_match``). The longest
    matching fragment wins, and equal lengths go to the rule listed first in
    ``rules``. Returns ``None`` when no fragment matches.
    """
    best_rank = None
    best_payload = None
    for rule_position, (fragments, payload) in enumerate(rules):
        for fragment in fragments:
            width = len(fragment)
            hit = model.find(fragment)
            # Skip occurrences embedded in a longer token (e.g. "gpt-5" inside "gpt-50").
            while hit != -1 and not _is_delimited_match(model, fragment, hit, hit + width):
                hit = model.find(fragment, hit + 1)
            if hit == -1:
                continue
            rank = (-width, rule_position)
            if best_rank is None or rank < best_rank:
                best_rank, best_payload = rank, payload
    return best_payload
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_model_capabilities(model: str) -> ModelCapabilities:
    """Look up which request parameters ``model`` is marked as accepting.

    ``model`` is matched against the checked-in capability registry with
    delimiter-aware substring matching. When nothing matches, the conservative
    default capability set is returned instead.
    """
    matched = _match_rule(model, _CAPABILITY_RULES)
    if matched:
        return matched
    return _DEFAULT_CAPABILITIES
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_model_pricing(model: str, service_tier: str = "default") -> ModelPricing | None:
    """Look up per-million-token pricing for ``model``.

    Only default-tier prices are registered, so ``service_tier`` is accepted
    for interface stability but does not affect the result: every tier gets the
    default-tier price as an approximation. Returns ``None`` when ``model``
    matches no registry entry.
    """
    pricing = _match_rule(model, _PRICING_RULES)
    return pricing
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def calculate_cost(
    model: str,
    usage: UsageBreakdown,
    payer: str | None,
    service_tier: str = "default",
) -> CostBreakdown | None:
    """Compute the nominal USD cost of ``usage`` and attribute it to a payer.

    Args:
        model: Model name, matched against the pricing registry.
        usage: Token counts. ``cached_tokens`` is treated as the cached portion
            of ``input_tokens``.
        payer: ``"openai"`` attributes the whole cost to OpenAI. Any other value
            (including ``None``) attributes it to the developer.
        service_tier: Forwarded to :func:`get_model_pricing`.

    Returns:
        A :class:`CostBreakdown`, or ``None`` when the model has no pricing entry.
    """
    pricing = get_model_pricing(model, service_tier=service_tier)
    if pricing is None:
        return None

    uncached_input = max(usage.input_tokens - usage.cached_tokens, 0)
    # With no cached-input rate listed, price cached tokens at the regular
    # input rate. Treating the missing rate as zero would bill them as free and
    # undercount the cost.
    cached_rate = (
        pricing.input_per_million
        if pricing.cached_input_per_million is None
        else pricing.cached_input_per_million
    )
    total = (
        (Decimal(uncached_input) * pricing.input_per_million)
        + (Decimal(usage.cached_tokens) * cached_rate)
        + (Decimal(usage.output_tokens) * pricing.output_per_million)
    ) / Decimal("1000000")

    nominal = float(total)
    if payer == "openai":
        return CostBreakdown(nominal_usd=nominal, developer_usd=0.0, openai_usd=nominal, payer=payer)
    return CostBreakdown(nominal_usd=nominal, developer_usd=nominal, openai_usd=0.0, payer=payer)
|
tokenrail/client.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .providers.base import BaseProvider
|
|
6
|
+
from .providers.openai import OpenAIProvider
|
|
7
|
+
from .types import NormalizedResponse
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _ResponsesNamespace:
    """The ``client.responses`` attribute: forwards ``create`` calls to a provider."""

    def __init__(self, provider: BaseProvider) -> None:
        self._provider = provider

    def create(self, **kwargs: Any) -> NormalizedResponse:
        """Pass every keyword argument to the provider's ``create`` and return its result."""
        target = self._provider
        result = target.create(**kwargs)
        return result
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RailClient:
    """Provider-agnostic client exposing a ``client.responses.create(...)`` surface.

    The wrapped :class:`~tokenrail.providers.base.BaseProvider` is available as
    ``provider``. ``responses.create(**kwargs)`` forwards to it, mirroring the
    OpenAI SDK call shape while returning
    :class:`~tokenrail.types.NormalizedResponse` objects.
    """

    def __init__(self, provider: BaseProvider) -> None:
        self.provider = provider
        self.responses = _ResponsesNamespace(provider)

    @classmethod
    def openai(
        cls,
        *,
        api_key: str | None = None,
        organization: str | None = None,
        timeout: float | None = None,
        base_url: str | None = None,
        max_retries: int = 2,
        client: Any | None = None,
    ) -> RailClient:
        """Create a :class:`RailClient` backed by the OpenAI Python SDK.

        Every argument is forwarded unchanged to :class:`OpenAIProvider`.
        ``max_retries`` sets the SDK's own retry behavior; tokenrail adds no
        retry loop of its own. Supply ``client`` to reuse a pre-built (or fake)
        OpenAI client instead of having one constructed.
        """
        provider_options: dict[str, Any] = dict(
            client=client,
            api_key=api_key,
            organization=organization,
            timeout=timeout,
            base_url=base_url,
            max_retries=max_retries,
        )
        return cls(provider=OpenAIProvider(**provider_options))
|
tokenrail/executor.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from collections import deque
|
|
5
|
+
from collections.abc import Callable, Sequence
|
|
6
|
+
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from .monitor import RollingMetricsMonitor
|
|
10
|
+
from .sinks import ResultSink
|
|
11
|
+
from .types import BatchItem, NormalizedResponse, StatsSnapshot, TimingBreakdown, UsageBreakdown
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def batch_items_from_queries(queries: dict[str, Any], **shared_request_kwargs: Any) -> list[BatchItem]:
    """Turn an ``{id: input}`` mapping into a list of :class:`BatchItem` objects.

    Each key is converted with ``str`` to become the item id, and each value
    becomes the request ``input``. ``shared_request_kwargs`` (e.g. ``model``,
    ``reasoning_effort``) are merged into every item's request kwargs after
    ``input``, so a shared ``input`` would take precedence.
    """
    items: list[BatchItem] = []
    for raw_id, payload in queries.items():
        request_kwargs = {"input": payload}
        request_kwargs.update(shared_request_kwargs)
        items.append(BatchItem(id=str(raw_id), request_kwargs=request_kwargs))
    return items
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _error_response(item_id: str, model: str, provider: str, error: Exception) -> NormalizedResponse:
    """Build the placeholder response recorded for a request that raised.

    The result has no output text, raw payload, billing or cost. Usage is
    ``UsageBreakdown.empty()``, every timing field is ``0.0``, and ``error``
    holds ``"<ExceptionType>: <message>"``.
    """
    fields = {
        "id": item_id,
        "model": model,
        "provider": provider,
        "output_text": None,
        "raw_response": None,
        "usage": UsageBreakdown.empty(),
        "billing": None,
        "cost": None,
        "timing": TimingBreakdown(started_at=0.0, completed_at=0.0, latency_seconds=0.0),
        "error": f"{type(error).__name__}: {error}",
    }
    return NormalizedResponse(**fields)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _SubmitRateLimiter:
    """Client-side gate that spaces out submissions to honor RPM/TPM ceilings.

    * RPM: every submission is timestamped. At most ``max_rpm`` submissions may
      fall inside the trailing ``window_seconds``.
    * TPM: the sum of (a) tokens of requests that completed inside the window,
      (b) a token estimate for every in-flight request, and (c) one more
      estimate for the request about to be submitted must not exceed
      ``max_tpm``. The estimate is the ceiling mean of tokens per completed
      request over the limiter's lifetime. Before any request has completed
      there is no estimate, so only a single "probe" request may be in flight.
      When the window holds no completions and nothing in flight carries an
      estimate, submission is always allowed, so one oversized request cannot
      deadlock the limiter.

    Timestamps come from ``time_fn`` and are only compared with each other,
    so the default is a monotonic clock that wall-clock adjustments cannot
    move. The limiter keeps plain instance state without locking. In this
    module it is driven only from the thread running
    ``BatchExecutor._run_threaded``.
    """

    def __init__(
        self,
        *,
        max_rpm: int | None,
        max_tpm: int | None,
        window_seconds: float = 60.0,
        time_fn: Callable[[], float] = time.monotonic,
        sleep_fn: Callable[[float], None] = time.sleep,
    ) -> None:
        if max_rpm is not None and max_rpm < 1:
            raise ValueError("max_rpm must be at least 1")
        if max_tpm is not None and max_tpm < 1:
            raise ValueError("max_tpm must be at least 1")
        self.max_rpm = max_rpm
        self.max_tpm = max_tpm
        self.window_seconds = window_seconds
        self.time_fn = time_fn
        self.sleep_fn = sleep_fn
        # Submission timestamps still inside the window (RPM accounting).
        self._submitted_at: deque[float] = deque()
        # (completion timestamp, tokens) pairs still inside the window (TPM accounting).
        self._completed_events: deque[tuple[float, int]] = deque()
        # Token estimate attached to each in-flight request, oldest first.
        self._inflight_estimates: deque[int] = deque()
        self._inflight_estimated_tokens = 0
        # Lifetime totals feeding the per-request token estimate.
        self._completed_requests = 0
        self._completed_tokens = 0

    def _prune(self, now: float) -> None:
        """Drop submissions and completions stamped at or before ``now - window_seconds``."""
        cutoff = now - self.window_seconds
        while self._submitted_at and self._submitted_at[0] <= cutoff:
            self._submitted_at.popleft()
        while self._completed_events and self._completed_events[0][0] <= cutoff:
            self._completed_events.popleft()

    def _estimated_next_tokens(self) -> int:
        """Ceiling mean of tokens per completed request, or 0 before any completion."""
        if self._completed_requests == 0:
            return 0
        return (self._completed_tokens + self._completed_requests - 1) // self._completed_requests

    def _rolling_completed_tokens(self) -> int:
        """Tokens of requests that completed inside the current window."""
        return sum(tokens for _, tokens in self._completed_events)

    def can_submit(self) -> bool:
        """Return whether one more request may be submitted right now."""
        now = self.time_fn()
        self._prune(now)
        if self.max_rpm is not None and len(self._submitted_at) >= self.max_rpm:
            return False
        if self.max_tpm is not None:
            # Probe phase: with no completed request there is no token estimate,
            # so keep at most one request in flight. Check the in-flight queue
            # rather than the pruned submit window, which empties after
            # ``window_seconds`` even while the probe is still running.
            if self._completed_requests == 0 and self._inflight_estimates:
                return False
            estimated_next = self._estimated_next_tokens()
            if not self._completed_events and self._inflight_estimated_tokens == 0:
                return True
            if self._rolling_completed_tokens() + self._inflight_estimated_tokens + estimated_next > self.max_tpm:
                return False
        return True

    def retry_after(self) -> float | None:
        """Return how long to wait before submission may be allowed again.

        Returns the seconds until the oldest blocking window entry expires,
        ``0.0`` when submission is already allowed, or ``None`` when submission
        is blocked but no window entry is to blame (typically in-flight
        requests). In that case the caller should wait for a completion rather
        than sleep.
        """
        now = self.time_fn()
        self._prune(now)
        waits: list[float] = []
        if self.max_rpm is not None and len(self._submitted_at) >= self.max_rpm:
            waits.append(self._submitted_at[0] + self.window_seconds - now)
        if self.max_tpm is not None and self._completed_events:
            projected = (
                self._rolling_completed_tokens() + self._inflight_estimated_tokens + self._estimated_next_tokens()
            )
            if projected > self.max_tpm:
                waits.append(self._completed_events[0][0] + self.window_seconds - now)
        if waits:
            return max(min(waits), 0.0)
        if not self.can_submit():
            return None
        return 0.0

    def wait_until_allowed(self) -> None:
        """Sleep via ``sleep_fn`` until :meth:`can_submit` returns True."""
        while not self.can_submit():
            # ``retry_after`` may return 0.0 or None; fall back to a short poll.
            self.sleep_fn(self.retry_after() or 0.01)

    def record_submit(self) -> None:
        """Account for a request that has just been submitted."""
        now = self.time_fn()
        self._prune(now)
        self._submitted_at.append(now)
        estimated_tokens = self._estimated_next_tokens() if self.max_tpm is not None else 0
        self._inflight_estimates.append(estimated_tokens)
        self._inflight_estimated_tokens += estimated_tokens

    def record_completion(self, response: NormalizedResponse) -> None:
        """Account for a finished request using the token usage it reports."""
        now = self.time_fn()
        self._prune(now)
        # Estimates are released oldest-first rather than per request; the
        # running total still returns to zero once everything has completed.
        if self._inflight_estimates:
            self._inflight_estimated_tokens -= self._inflight_estimates.popleft()
        # Prefer the reported total; fall back to input + output when it is 0.
        total_tokens = response.usage.total_tokens or (response.usage.input_tokens + response.usage.output_tokens)
        self._completed_events.append((now, total_tokens))
        self._completed_requests += 1
        self._completed_tokens += total_tokens
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class BatchExecutor:
    """Thread-based batch runner for :class:`~tokenrail.client.RailClient` requests.

    Submits items to a thread pool while honoring optional client-side
    ``max_rpm`` / ``max_tpm`` submit limits, writes each result to the
    configured sinks, and records metrics on the monitor. Items whose ids are
    already present in the first sink are skipped, which makes re-runs
    resumable. Request errors are captured as error responses rather than
    raised, so a single failing item does not abort the batch.

    Args:
        client: Object exposing ``client.responses.create(**kwargs)``, normally a
            :class:`~tokenrail.client.RailClient`. Its ``provider`` attribute, if
            present, is only used to label error responses.
        max_workers: Thread-pool size, which is also the cap on requests in flight.
        max_rpm: Optional client-side cap on submissions per rolling window.
        max_tpm: Optional client-side cap on (estimated) tokens per rolling window.
        sinks: Result sinks. Every response is saved to each sink, and the first
            sink supplies the ids to skip.
        monitor: Metrics monitor. A fresh :class:`RollingMetricsMonitor` is used
            when omitted.
    """

    def __init__(
        self,
        *,
        client: Any,
        max_workers: int = 20,
        max_rpm: int | None = None,
        max_tpm: int | None = None,
        sinks: Sequence[ResultSink] | None = None,
        monitor: RollingMetricsMonitor | None = None,
    ) -> None:
        self.client = client
        self.max_workers = max_workers
        self.max_rpm = max_rpm
        self.max_tpm = max_tpm
        self.sinks = list(sinks or [])
        self.monitor = monitor if monitor is not None else RollingMetricsMonitor()
        # The submit limiter only compares timestamps with each other, so use a
        # monotonic clock. With the wall clock, an NTP step or manual change
        # could stall or flood submissions.
        self._time_fn = time.monotonic
        self._sleep_fn = time.sleep

    def _save(self, response: NormalizedResponse) -> None:
        """Write ``response`` to every configured sink, in order."""
        for sink in self.sinks:
            sink.save(response)

    def _load_done_ids(self) -> set[str]:
        """Ids already recorded by the first sink (empty when there are no sinks)."""
        if not self.sinks:
            return set()
        return self.sinks[0].load_done_ids()

    def _prepare_items(self, items: Sequence[BatchItem] | dict[str, Any]) -> list[BatchItem]:
        """Normalize input into fresh BatchItems with ``str`` ids and copied kwargs."""
        if isinstance(items, dict):
            return batch_items_from_queries(items)
        return [BatchItem(id=str(item.id), request_kwargs=dict(item.request_kwargs)) for item in items]

    def _request_kwargs(self, item: BatchItem) -> dict[str, Any]:
        """Copy the item's kwargs, defaulting ``request_id`` to the item id."""
        request_kwargs = dict(item.request_kwargs)
        request_kwargs.setdefault("request_id", item.id)
        return request_kwargs

    def _call_single(self, item: BatchItem) -> NormalizedResponse:
        """Run one request on a worker thread; exceptions become error responses."""
        request_kwargs = self._request_kwargs(item)
        try:
            return self.client.responses.create(**request_kwargs)
        except Exception as exc:
            # Deliberate catch-all boundary: anything raised here would surface
            # from ``future.result()`` and abort the batch. ``client`` is typed
            # ``Any``, so look up provider metadata defensively rather than
            # risk a second exception inside this handler.
            provider = getattr(self.client, "provider", None)
            model = str(request_kwargs.get("model") or getattr(provider, "model_id", "unknown"))
            provider_name = getattr(provider, "name", "unknown")
            return _error_response(item.id, model=model, provider=provider_name, error=exc)

    def _run_threaded(self, items: list[BatchItem]) -> None:
        """Submit ``items`` in order under the limiter and drain results as they finish.

        Results are consumed on the calling thread, so sinks, the monitor and
        the limiter are only touched from here. Worker threads only run
        :meth:`_call_single`.
        """
        limiter = _SubmitRateLimiter(
            max_rpm=self.max_rpm,
            max_tpm=self.max_tpm,
            time_fn=self._time_fn,
            sleep_fn=self._sleep_fn,
        )
        next_item = 0
        pending: set[Future[NormalizedResponse]] = set()

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while next_item < len(items) or pending:
                # Fill free worker slots for as long as the limiter allows.
                while next_item < len(items) and len(pending) < self.max_workers and limiter.can_submit():
                    limiter.record_submit()
                    pending.add(executor.submit(self._call_single, items[next_item]))
                    next_item += 1

                if not pending:
                    # Nothing in flight to wait on: sleep until the limiter reopens.
                    if next_item < len(items):
                        limiter.wait_until_allowed()
                    continue

                # When only the limiter (not a full pool) holds back the next
                # submission, wake up once it may reopen even if nothing finishes.
                timeout = None
                if next_item < len(items) and len(pending) < self.max_workers and not limiter.can_submit():
                    timeout = limiter.retry_after()
                done, pending = wait(pending, timeout=timeout, return_when=FIRST_COMPLETED)
                if not done:
                    continue

                for future in done:
                    response = future.result()
                    limiter.record_completion(response)
                    self._save(response)
                    self.monitor.record(response)

    def run(self, items: Sequence[BatchItem] | dict[str, Any]) -> StatsSnapshot:
        """Execute ``items`` (a sequence of :class:`BatchItem` or an ``{id: input}``
        dict) and return the final :class:`~tokenrail.types.StatsSnapshot`."""
        self.monitor.reset()
        normalized_items = self._prepare_items(items)
        done_ids = self._load_done_ids()
        todo = [item for item in normalized_items if item.id not in done_ids]
        skipped = len(normalized_items) - len(todo)
        self.monitor.start(
            total_requests=len(normalized_items),
            todo_requests=len(todo),
            skipped_requests=skipped,
        )

        self._run_threaded(todo)

        return self.monitor.finalize(total_requests=len(normalized_items), skipped_requests=skipped)
|