tokenrail 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tokenrail/__init__.py ADDED
@@ -0,0 +1,31 @@
1
"""Thin client and batch execution helpers for OpenAI Responses API workloads.

tokenrail wraps the OpenAI Responses API with a ``client.responses.create(...)``-style
surface and adds thread-based batch execution, client-side RPM/TPM submit throttling,
per-model token/cost monitoring, and resumable JSONL / per-request result writing.
"""

# Re-export the public API at package level so callers can write
# ``from tokenrail import RailClient, BatchExecutor`` without knowing submodule layout.
from .client import RailClient
from .executor import BatchExecutor, batch_items_from_queries
from .monitor import RollingMetricsMonitor
from .providers import OpenAIProvider
from .sinks import PerRequestJsonSink, ResultsJsonlSink
from .types import BatchItem, CostBreakdown, NormalizedResponse, StatsSnapshot, UsageBreakdown

# Package version string (matches the 1.0.0 wheel this module ships in).
__version__ = "1.0.0"

# Names exported by ``from tokenrail import *``; every re-export above is listed.
__all__ = [
    "__version__",
    "BatchExecutor",
    "BatchItem",
    "CostBreakdown",
    "NormalizedResponse",
    "OpenAIProvider",
    "PerRequestJsonSink",
    "RailClient",
    "ResultsJsonlSink",
    "RollingMetricsMonitor",
    "StatsSnapshot",
    "UsageBreakdown",
    "batch_items_from_queries",
]
tokenrail/catalog.py ADDED
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+ from dataclasses import dataclass
5
+ from decimal import Decimal
6
+ from typing import TypeVar
7
+
8
+ from .types import CostBreakdown, UsageBreakdown
9
+
10
# Generic payload type for ``_match_rule``: the rule tables in this module map model-name
# patterns to either ModelCapabilities or ModelPricing, and the lookup returns that payload.
_T = TypeVar("_T")
11
+
12
+
13
@dataclass(frozen=True, slots=True)
class ModelCapabilities:
    """Boolean flags recording which optional request parameters a model family supports.

    Each field is named after the request parameter it describes. Instances are
    immutable (``frozen=True``) and are shared from the module-level capability
    registry. NOTE(review): presumably the provider consults these flags to decide
    which parameters to forward — confirm in ``providers/openai.py``.
    """

    reasoning_effort: bool  # reasoning-effort parameter
    verbosity: bool  # output verbosity parameter
    temperature: bool  # sampling temperature
    top_p: bool  # nucleus-sampling cutoff
    max_output_tokens: bool  # output-token cap
    response_format: bool  # structured / formatted output
21
+
22
+
23
@dataclass(frozen=True, slots=True)
class ModelPricing:
    """USD prices per one million tokens for a model.

    Positional field order is (input, cached input, output), which is how the
    entries in ``_PRICING_RULES`` construct it. ``service_tier`` labels the tier the
    prices belong to; every checked-in entry uses the ``"default"`` tier.
    """

    input_per_million: Decimal  # uncached input tokens
    cached_input_per_million: Decimal | None  # cached input tokens; None = no cached rate recorded
    output_per_million: Decimal  # output tokens
    service_tier: str = "default"
29
+
30
+
31
# Capability registry consulted by get_model_capabilities(). Each entry maps a tuple of
# model-name patterns to the flags for that family. _match_rule searches for each pattern
# anywhere in the model id (delimiter-bounded), keeps the longest match, and breaks
# length ties by list order.
# NOTE(review): the reasoning families in the first entry (gpt-5, o1, o3, o4) are flagged as
# accepting temperature/top_p, and o1/o3/o4 as accepting verbosity; the gpt-4.1/gpt-4o entry
# is also flagged for verbosity. Confirm the API actually accepts these parameters for those
# models — if it does not, requests built from these flags may be rejected.
_CAPABILITY_RULES: list[tuple[tuple[str, ...], ModelCapabilities]] = [
    (
        ("gpt-5", "o1", "o3", "o4"),
        ModelCapabilities(
            reasoning_effort=True,
            verbosity=True,
            temperature=True,
            top_p=True,
            max_output_tokens=True,
            response_format=True,
        ),
    ),
    (
        ("gpt-4.1", "gpt-4o"),
        ModelCapabilities(
            reasoning_effort=False,
            verbosity=True,
            temperature=True,
            top_p=True,
            max_output_tokens=True,
            response_format=True,
        ),
    ),
]
55
+
56
# Fallback used by get_model_capabilities() for models that match no capability rule:
# reasoning_effort and verbosity are disabled, every other parameter is enabled.
_DEFAULT_CAPABILITIES = ModelCapabilities(
    reasoning_effort=False,
    verbosity=False,
    temperature=True,
    top_p=True,
    max_output_tokens=True,
    response_format=True,
)
64
+
65
# Default-tier USD prices per one million tokens, positional as (input, cached input, output).
# _match_rule prefers the longest delimiter-bounded match, so "gpt-5.4-mini" wins over
# "gpt-5.4" and "gpt-5" regardless of where the entries sit in this list.
# Models without an entry of their own fall back to the longest shorter match, e.g. an
# unlisted "gpt-5.1" is priced as "gpt-5" and "o3-mini" as "o3".
# NOTE(review): those fallbacks are approximations — confirm prices for any variant that is
# billed differently from its base family before trusting its cost numbers.
_PRICING_RULES: list[tuple[tuple[str, ...], ModelPricing]] = [
    (("gpt-5.5",), ModelPricing(Decimal("5.00"), Decimal("0.50"), Decimal("30.00"))),
    (("gpt-5.4-mini",), ModelPricing(Decimal("0.750"), Decimal("0.075"), Decimal("4.500"))),
    (("gpt-5.4-nano",), ModelPricing(Decimal("0.20"), Decimal("0.02"), Decimal("1.25"))),
    (("gpt-5.4",), ModelPricing(Decimal("2.50"), Decimal("0.25"), Decimal("15.00"))),
    (("gpt-5.2",), ModelPricing(Decimal("1.75"), Decimal("0.175"), Decimal("14.00"))),
    (("gpt-5-mini",), ModelPricing(Decimal("0.25"), Decimal("0.025"), Decimal("2.00"))),
    (("gpt-5-nano",), ModelPricing(Decimal("0.05"), Decimal("0.005"), Decimal("0.40"))),
    (("gpt-5",), ModelPricing(Decimal("1.25"), Decimal("0.125"), Decimal("10.00"))),
    (("gpt-4.1-mini",), ModelPricing(Decimal("0.40"), Decimal("0.10"), Decimal("1.60"))),
    (("gpt-4.1-nano",), ModelPricing(Decimal("0.10"), Decimal("0.025"), Decimal("0.40"))),
    (("gpt-4.1",), ModelPricing(Decimal("2.00"), Decimal("0.50"), Decimal("8.00"))),
    (("gpt-4o-mini",), ModelPricing(Decimal("0.15"), Decimal("0.075"), Decimal("0.60"))),
    (("gpt-4o",), ModelPricing(Decimal("2.50"), Decimal("1.25"), Decimal("10.00"))),
    (("o4-mini",), ModelPricing(Decimal("1.10"), Decimal("0.275"), Decimal("4.40"))),
    (("o3",), ModelPricing(Decimal("2.00"), Decimal("0.50"), Decimal("8.00"))),
    (("o1",), ModelPricing(Decimal("15.00"), Decimal("7.50"), Decimal("60.00"))),
]
83
+
84
+
85
+ _MODEL_NAME_DELIMITERS = {"-", "_", "/", ":", " ", "."}
86
+
87
+
88
def _is_delimited_match(model: str, candidate: str, start: int, end: int) -> bool:
    """Report whether ``model[start:end]`` sits on model-name token boundaries.

    The span qualifies when the character just before ``start`` and the character
    at ``end`` are each either missing (the span touches a string edge) or one of
    ``_MODEL_NAME_DELIMITERS``. ``candidate`` is accepted for call-site symmetry but
    is not inspected.
    """
    left_edge = start <= 0 or model[start - 1] in _MODEL_NAME_DELIMITERS
    right_edge = end >= len(model) or model[end] in _MODEL_NAME_DELIMITERS
    return left_edge and right_edge
94
+
95
+
96
def _match_rule(model: str, rules: Iterable[tuple[tuple[str, ...], _T]]) -> _T | None:
    """Return the payload of the rule whose pattern best matches ``model``.

    Each pattern is searched for anywhere in ``model``; an occurrence only counts
    when ``_is_delimited_match`` accepts it (bounded by delimiters or string edges).
    Among matching patterns the longest wins; equal lengths are resolved in favour
    of the rule listed first. Returns ``None`` when no pattern matches.
    """
    best_rank: tuple[int, int] | None = None
    best_payload: _T | None = None
    for position, (patterns, payload) in enumerate(rules):
        for pattern in patterns:
            # Walk successive occurrences until one is delimiter-bounded (or none remain).
            hit = model.find(pattern)
            while hit != -1 and not _is_delimited_match(model, pattern, hit, hit + len(pattern)):
                hit = model.find(pattern, hit + 1)
            if hit == -1:
                continue
            # Higher rank wins: longer pattern first, then the earlier rule.
            rank = (len(pattern), -position)
            if best_rank is None or rank > best_rank:
                best_rank, best_payload = rank, payload
    return best_payload
113
+
114
+
115
def get_model_capabilities(model: str) -> ModelCapabilities:
    """Look up which optional request parameters ``model`` supports.

    Uses delimiter-aware substring matching against ``_CAPABILITY_RULES``; a model
    that matches no rule gets the conservative ``_DEFAULT_CAPABILITIES``.
    """
    matched = _match_rule(model, _CAPABILITY_RULES)
    if matched is None:
        return _DEFAULT_CAPABILITIES
    return matched
122
+
123
+
124
def get_model_pricing(model: str, service_tier: str = "default") -> ModelPricing | None:
    """Look up per-million-token USD pricing for ``model``.

    Returns ``None`` when no pricing rule matches. ``service_tier`` is currently
    ignored by the lookup: the registry only carries default-tier prices, so every
    tier receives the default-tier entry as an approximation.
    """
    pricing = _match_rule(model, _PRICING_RULES)
    return pricing
131
+
132
+
133
def calculate_cost(
    model: str,
    usage: UsageBreakdown,
    payer: str | None,
    service_tier: str = "default",
) -> CostBreakdown | None:
    """Compute the nominal USD cost of ``usage`` and attribute it to a payer.

    Input tokens are split into uncached (``input_tokens - cached_tokens``, floored
    at 0) and cached tokens. Cached tokens are billed at the model's cached-input
    rate; when the pricing entry records no cached rate (``None``) they are billed
    at the regular input rate instead of for free. Output tokens are billed at the
    output rate. Arithmetic is done in ``Decimal`` and converted to ``float`` once.

    Args:
        model: Model id, matched against the pricing registry.
        usage: Token counts for one response.
        payer: ``"openai"`` attributes the whole cost to OpenAI; any other value
            (including ``None``) attributes it to the developer.
        service_tier: Forwarded to :func:`get_model_pricing`; only default-tier
            prices are registered, so other tiers are approximated.

    Returns:
        A :class:`CostBreakdown`, or ``None`` when the model has no pricing entry.
    """
    pricing = get_model_pricing(model, service_tier=service_tier)
    if pricing is None:
        return None

    uncached_input = max(usage.input_tokens - usage.cached_tokens, 0)
    # ``None`` means "no cached-input discount recorded", not "cached input is free":
    # cached tokens are still input tokens, so fall back to the full input rate.
    cached_rate = (
        pricing.input_per_million
        if pricing.cached_input_per_million is None
        else pricing.cached_input_per_million
    )
    total = (
        Decimal(uncached_input) * pricing.input_per_million
        + Decimal(usage.cached_tokens) * cached_rate
        + Decimal(usage.output_tokens) * pricing.output_per_million
    ) / Decimal("1000000")

    nominal = float(total)
    if payer == "openai":
        return CostBreakdown(nominal_usd=nominal, developer_usd=0.0, openai_usd=nominal, payer=payer)
    return CostBreakdown(nominal_usd=nominal, developer_usd=nominal, openai_usd=0.0, payer=payer)
tokenrail/client.py ADDED
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from .providers.base import BaseProvider
6
+ from .providers.openai import OpenAIProvider
7
+ from .types import NormalizedResponse
8
+
9
+
10
+ class _ResponsesNamespace:
11
+ def __init__(self, provider: BaseProvider) -> None:
12
+ self._provider = provider
13
+
14
+ def create(self, **kwargs: Any) -> NormalizedResponse:
15
+ return self._provider.create(**kwargs)
16
+
17
+
18
class RailClient:
    """Provider-agnostic client with a ``client.responses.create(...)`` surface.

    Holds a :class:`~tokenrail.providers.base.BaseProvider` as ``provider`` and
    exposes ``responses.create(**kwargs)``. That call mirrors the OpenAI SDK call
    shape and forwards to ``provider.create``, returning its
    :class:`~tokenrail.types.NormalizedResponse`.
    """

    def __init__(self, provider: BaseProvider) -> None:
        self.provider = provider
        # OpenAI-SDK-shaped namespace sharing the same provider instance.
        self.responses = _ResponsesNamespace(provider)

    @classmethod
    def openai(
        cls,
        *,
        api_key: str | None = None,
        organization: str | None = None,
        timeout: float | None = None,
        base_url: str | None = None,
        max_retries: int = 2,
        client: Any | None = None,
    ) -> RailClient:
        """Create a :class:`RailClient` backed by :class:`OpenAIProvider`.

        Every argument is handed to ``OpenAIProvider`` unchanged. ``max_retries`` is
        the OpenAI SDK's own retry count; tokenrail adds no retry loop of its own.
        ``client`` injects a ready-made (or fake) OpenAI client instead of having
        one constructed.
        """
        provider_options: dict[str, Any] = {
            "client": client,
            "api_key": api_key,
            "organization": organization,
            "timeout": timeout,
            "base_url": base_url,
            "max_retries": max_retries,
        }
        return cls(provider=OpenAIProvider(**provider_options))
tokenrail/executor.py ADDED
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from collections import deque
5
+ from collections.abc import Callable, Sequence
6
+ from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
7
+ from typing import Any
8
+
9
+ from .monitor import RollingMetricsMonitor
10
+ from .sinks import ResultSink
11
+ from .types import BatchItem, NormalizedResponse, StatsSnapshot, TimingBreakdown, UsageBreakdown
12
+
13
+
14
def batch_items_from_queries(queries: dict[str, Any], **shared_request_kwargs: Any) -> list[BatchItem]:
    """Turn an ``{id: input}`` mapping into a list of :class:`BatchItem`.

    Each key is converted with ``str`` to become the item id, and each value becomes
    the request ``input``. Every keyword in ``shared_request_kwargs`` (e.g. ``model``,
    ``reasoning_effort``) is copied into each item's request kwargs; an ``input``
    passed there overrides the mapping value. Items keep the mapping's order.
    """
    items: list[BatchItem] = []
    for raw_id, query_input in queries.items():
        # A fresh dict per item so items never share mutable request kwargs.
        request_kwargs: dict[str, Any] = {"input": query_input}
        request_kwargs.update(shared_request_kwargs)
        items.append(BatchItem(id=str(raw_id), request_kwargs=request_kwargs))
    return items
24
+
25
+
26
def _error_response(item_id: str, model: str, provider: str, error: Exception) -> NormalizedResponse:
    """Build a placeholder response describing a failed request.

    Output text, raw response, billing and cost are ``None``; usage is
    ``UsageBreakdown.empty()``; every timing field is zero. The error is rendered as
    ``"<ExceptionClass>: <message>"``.
    """
    message = f"{type(error).__name__}: {error}"
    zero_timing = TimingBreakdown(started_at=0.0, completed_at=0.0, latency_seconds=0.0)
    return NormalizedResponse(
        id=item_id,
        model=model,
        provider=provider,
        output_text=None,
        raw_response=None,
        usage=UsageBreakdown.empty(),
        billing=None,
        cost=None,
        timing=zero_timing,
        error=message,
    )
39
+
40
+
41
+ class _SubmitRateLimiter:
42
+ def __init__(
43
+ self,
44
+ *,
45
+ max_rpm: int | None,
46
+ max_tpm: int | None,
47
+ window_seconds: float = 60.0,
48
+ time_fn: Callable[[], float] = time.time,
49
+ sleep_fn: Callable[[float], None] = time.sleep,
50
+ ) -> None:
51
+ if max_rpm is not None and max_rpm < 1:
52
+ raise ValueError("max_rpm must be at least 1")
53
+ if max_tpm is not None and max_tpm < 1:
54
+ raise ValueError("max_tpm must be at least 1")
55
+ self.max_rpm = max_rpm
56
+ self.max_tpm = max_tpm
57
+ self.window_seconds = window_seconds
58
+ self.time_fn = time_fn
59
+ self.sleep_fn = sleep_fn
60
+ self._submitted_at: deque[float] = deque()
61
+ self._completed_events: deque[tuple[float, int]] = deque()
62
+ self._inflight_estimates: deque[int] = deque()
63
+ self._inflight_estimated_tokens = 0
64
+ self._completed_requests = 0
65
+ self._completed_tokens = 0
66
+
67
+ def _prune(self, now: float) -> None:
68
+ cutoff = now - self.window_seconds
69
+ while self._submitted_at and self._submitted_at[0] <= cutoff:
70
+ self._submitted_at.popleft()
71
+ while self._completed_events and self._completed_events[0][0] <= cutoff:
72
+ self._completed_events.popleft()
73
+
74
+ def _estimated_next_tokens(self) -> int:
75
+ if self._completed_requests == 0:
76
+ return 0
77
+ return (self._completed_tokens + self._completed_requests - 1) // self._completed_requests
78
+
79
+ def _rolling_completed_tokens(self) -> int:
80
+ return sum(tokens for _, tokens in self._completed_events)
81
+
82
+ def can_submit(self) -> bool:
83
+ now = self.time_fn()
84
+ self._prune(now)
85
+ if self.max_rpm is not None and len(self._submitted_at) >= self.max_rpm:
86
+ return False
87
+ if self.max_tpm is not None:
88
+ if self._completed_requests == 0 and self._submitted_at:
89
+ return False
90
+ estimated_next = self._estimated_next_tokens()
91
+ if not self._completed_events and self._inflight_estimated_tokens == 0:
92
+ return True
93
+ if self._rolling_completed_tokens() + self._inflight_estimated_tokens + estimated_next > self.max_tpm:
94
+ return False
95
+ return True
96
+
97
+ def retry_after(self) -> float | None:
98
+ now = self.time_fn()
99
+ self._prune(now)
100
+ waits: list[float] = []
101
+ if self.max_rpm is not None and len(self._submitted_at) >= self.max_rpm:
102
+ waits.append(self._submitted_at[0] + self.window_seconds - now)
103
+ if self.max_tpm is not None and self._completed_events:
104
+ projected = (
105
+ self._rolling_completed_tokens() + self._inflight_estimated_tokens + self._estimated_next_tokens()
106
+ )
107
+ if projected > self.max_tpm:
108
+ waits.append(self._completed_events[0][0] + self.window_seconds - now)
109
+ if waits:
110
+ return max(min(waits), 0.0)
111
+ if not self.can_submit():
112
+ return None
113
+ return 0.0
114
+
115
+ def wait_until_allowed(self) -> None:
116
+ while not self.can_submit():
117
+ self.sleep_fn(self.retry_after() or 0.01)
118
+
119
+ def record_submit(self) -> None:
120
+ now = self.time_fn()
121
+ self._prune(now)
122
+ self._submitted_at.append(now)
123
+ estimated_tokens = self._estimated_next_tokens() if self.max_tpm is not None else 0
124
+ self._inflight_estimates.append(estimated_tokens)
125
+ self._inflight_estimated_tokens += estimated_tokens
126
+
127
+ def record_completion(self, response: NormalizedResponse) -> None:
128
+ now = self.time_fn()
129
+ self._prune(now)
130
+ if self._inflight_estimates:
131
+ self._inflight_estimated_tokens -= self._inflight_estimates.popleft()
132
+ total_tokens = response.usage.total_tokens or (response.usage.input_tokens + response.usage.output_tokens)
133
+ self._completed_events.append((now, total_tokens))
134
+ self._completed_requests += 1
135
+ self._completed_tokens += total_tokens
136
+
137
+
138
class BatchExecutor:
    """Thread-based batch runner for :class:`~tokenrail.client.RailClient` requests.

    Submits items to a thread pool while honoring optional client-side
    ``max_rpm`` / ``max_tpm`` submit limits, writes each result to the
    configured sinks, and records metrics on the monitor. Items whose ids are
    already present in the first sink are skipped, which makes re-runs
    resumable. Request errors are captured as error responses rather than
    raised, so a single failing item does not abort the batch.

    Args:
        client: Object exposing ``responses.create(**kwargs)``, normally a
            :class:`~tokenrail.client.RailClient`. Its ``provider`` attribute,
            when present, supplies ``name`` / ``model_id`` for error responses.
        max_workers: Thread-pool size and maximum number of in-flight requests.
        max_rpm: Optional cap on submits per 60-second window; ``None`` disables it.
        max_tpm: Optional cap on estimated tokens per 60-second window; ``None``
            disables it.
        sinks: Result sinks. Every response is saved to each sink; the first one
            also supplies the already-done ids used for resuming.
        monitor: Metrics monitor; a fresh :class:`RollingMetricsMonitor` by default.
    """

    def __init__(
        self,
        *,
        client: Any,
        max_workers: int = 20,
        max_rpm: int | None = None,
        max_tpm: int | None = None,
        sinks: Sequence[ResultSink] | None = None,
        monitor: RollingMetricsMonitor | None = None,
    ) -> None:
        self.client = client
        self.max_workers = max_workers
        self.max_rpm = max_rpm
        self.max_tpm = max_tpm
        self.sinks = list(sinks or [])
        self.monitor = monitor or RollingMetricsMonitor()
        # Clock and sleep handed to the rate limiter (overridable, e.g. for tests).
        # Monotonic because the limiter only measures elapsed time; a wall clock that
        # steps backwards would keep old submits "inside" the window and stall the batch.
        self._time_fn = time.monotonic
        self._sleep_fn = time.sleep

    def _save(self, response: NormalizedResponse) -> None:
        """Write ``response`` to every configured sink, in order."""
        for sink in self.sinks:
            sink.save(response)

    def _load_done_ids(self) -> set[str]:
        """Return ids already recorded by the first sink (empty when there are no sinks)."""
        if not self.sinks:
            return set()
        return self.sinks[0].load_done_ids()

    def _prepare_items(self, items: Sequence[BatchItem] | dict[str, Any]) -> list[BatchItem]:
        """Normalize input into fresh BatchItems with ``str`` ids and copied request kwargs."""
        if isinstance(items, dict):
            return batch_items_from_queries(items)
        return [BatchItem(id=str(item.id), request_kwargs=dict(item.request_kwargs)) for item in items]

    def _request_kwargs(self, item: BatchItem) -> dict[str, Any]:
        """Copy the item's kwargs, defaulting ``request_id`` to the item id."""
        request_kwargs = dict(item.request_kwargs)
        request_kwargs.setdefault("request_id", item.id)
        return request_kwargs

    def _call_single(self, item: BatchItem) -> NormalizedResponse:
        """Run one request, converting any ``Exception`` into an error response."""
        request_kwargs = self._request_kwargs(item)
        try:
            return self.client.responses.create(**request_kwargs)
        except Exception as exc:
            # Look the provider up defensively: an AttributeError raised here would
            # escape the worker, re-raise from future.result() in _run_threaded, and
            # abort the whole batch instead of recording this item as failed.
            provider = getattr(self.client, "provider", None)
            model = str(request_kwargs.get("model") or getattr(provider, "model_id", "unknown"))
            provider_name = getattr(provider, "name", "unknown")
            return _error_response(item.id, model=model, provider=provider_name, error=exc)

    def _run_threaded(self, items: list[BatchItem]) -> None:
        """Run ``items`` on a thread pool, saving and recording each result as it completes.

        Keeps at most ``max_workers`` requests in flight and consults a fresh
        ``_SubmitRateLimiter`` before every submit. Results are handled in
        completion order, not input order.
        """
        limiter = _SubmitRateLimiter(
            max_rpm=self.max_rpm,
            max_tpm=self.max_tpm,
            time_fn=self._time_fn,
            sleep_fn=self._sleep_fn,
        )
        next_item = 0
        pending: set[Future[NormalizedResponse]] = set()

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while next_item < len(items) or pending:
                # Fill free worker slots for as long as the limiter allows.
                while next_item < len(items) and len(pending) < self.max_workers and limiter.can_submit():
                    limiter.record_submit()
                    pending.add(executor.submit(self._call_single, items[next_item]))
                    next_item += 1

                if not pending:
                    # Nothing in flight to wait on: sleep until the limiter reopens.
                    if next_item < len(items):
                        limiter.wait_until_allowed()
                    continue

                # When only the limiter holds back further submits, wake up once it may
                # reopen; otherwise (timeout None) wait for the next completion.
                timeout = None
                if next_item < len(items) and len(pending) < self.max_workers and not limiter.can_submit():
                    timeout = limiter.retry_after()
                done, pending = wait(pending, timeout=timeout, return_when=FIRST_COMPLETED)
                if not done:
                    # Timed out with nothing finished: loop back and retry submitting.
                    continue

                for future in done:
                    # _call_single turns request exceptions into error responses.
                    response = future.result()
                    limiter.record_completion(response)
                    self._save(response)
                    self.monitor.record(response)

    def run(self, items: Sequence[BatchItem] | dict[str, Any]) -> StatsSnapshot:
        """Execute ``items`` (a sequence of :class:`BatchItem` or an ``{id: input}``
        dict) and return the final :class:`~tokenrail.types.StatsSnapshot`.

        Resets the monitor, skips items whose ids the first sink already reports as
        done, runs the rest, and returns the monitor's finalized snapshot.
        """
        self.monitor.reset()
        normalized_items = self._prepare_items(items)
        done_ids = self._load_done_ids()
        todo = [item for item in normalized_items if item.id not in done_ids]
        skipped = len(normalized_items) - len(todo)
        self.monitor.start(
            total_requests=len(normalized_items),
            todo_requests=len(todo),
            skipped_requests=skipped,
        )

        self._run_threaded(todo)

        return self.monitor.finalize(total_requests=len(normalized_items), skipped_requests=skipped)