dataforge-07 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataforge/__init__.py +204 -0
- dataforge/__main__.py +5 -0
- dataforge/agent/__init__.py +16 -0
- dataforge/agent/providers.py +259 -0
- dataforge/agent/scratchpad.py +183 -0
- dataforge/agent/tool_actions.py +343 -0
- dataforge/bench/__init__.py +31 -0
- dataforge/bench/core.py +426 -0
- dataforge/bench/groq_client.py +386 -0
- dataforge/bench/methods.py +443 -0
- dataforge/bench/report.py +309 -0
- dataforge/bench/runner.py +247 -0
- dataforge/causal/__init__.py +21 -0
- dataforge/causal/dag.py +174 -0
- dataforge/causal/pc.py +232 -0
- dataforge/causal/root_cause.py +193 -0
- dataforge/cli/__init__.py +50 -0
- dataforge/cli/audit.py +70 -0
- dataforge/cli/bench.py +154 -0
- dataforge/cli/common.py +267 -0
- dataforge/cli/constraints.py +407 -0
- dataforge/cli/profile.py +147 -0
- dataforge/cli/release.py +166 -0
- dataforge/cli/repair.py +407 -0
- dataforge/cli/revert.py +139 -0
- dataforge/cli/watch.py +144 -0
- dataforge/datasets/__init__.py +25 -0
- dataforge/datasets/embedded/hospital/clean.csv +11 -0
- dataforge/datasets/embedded/hospital/dirty.csv +11 -0
- dataforge/datasets/real_world.py +290 -0
- dataforge/datasets/registry.py +103 -0
- dataforge/detectors/__init__.py +80 -0
- dataforge/detectors/base.py +145 -0
- dataforge/detectors/decimal_shift.py +166 -0
- dataforge/detectors/fd_violation.py +157 -0
- dataforge/detectors/type_mismatch.py +173 -0
- dataforge/engine/__init__.py +39 -0
- dataforge/engine/repair.py +905 -0
- dataforge/env/__init__.py +22 -0
- dataforge/env/environment.py +883 -0
- dataforge/env/observation.py +61 -0
- dataforge/env/openenv_core.py +161 -0
- dataforge/env/reward.py +128 -0
- dataforge/env/server.py +176 -0
- dataforge/evaluation_contract.py +76 -0
- dataforge/fixtures/hospital_10rows.csv +11 -0
- dataforge/fixtures/hospital_schema.yaml +17 -0
- dataforge/http/__init__.py +1 -0
- dataforge/http/problem.py +103 -0
- dataforge/integrations/__init__.py +1 -0
- dataforge/integrations/dbt.py +164 -0
- dataforge/observability.py +76 -0
- dataforge/py.typed +1 -0
- dataforge/release/__init__.py +1 -0
- dataforge/release/doctor.py +367 -0
- dataforge/release/full_vision.py +702 -0
- dataforge/release/gate.py +861 -0
- dataforge/release/playground_check.py +411 -0
- dataforge/repair_contract.py +468 -0
- dataforge/repairers/__init__.py +88 -0
- dataforge/repairers/base.py +77 -0
- dataforge/repairers/decimal_shift.py +43 -0
- dataforge/repairers/fd_violation.py +225 -0
- dataforge/repairers/type_mismatch.py +73 -0
- dataforge/safety/__init__.py +5 -0
- dataforge/safety/adversarial/attack_01_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_02_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_03_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_04_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_05_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_06_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_07_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_08_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_09_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_10_phone_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_11_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_12_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_13_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_14_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_15_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_16_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_17_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_18_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_19_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_20_ssn_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_21_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_22_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_23_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_24_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_25_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_26_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_27_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_28_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_29_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_30_email_pii.yaml +8 -0
- dataforge/safety/adversarial/attack_31_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_32_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_33_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_34_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_35_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_36_row_delete.yaml +11 -0
- dataforge/safety/adversarial/attack_37_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_38_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_39_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_40_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_41_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_42_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_43_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_44_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_45_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_46_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_47_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_48_row_delete.yaml +7 -0
- dataforge/safety/adversarial/attack_49_row_delete.yaml +8 -0
- dataforge/safety/adversarial/attack_50_row_delete.yaml +7 -0
- dataforge/safety/constitution.py +307 -0
- dataforge/safety/constitutions/default.yaml +40 -0
- dataforge/safety/filter.py +134 -0
- dataforge/schema_inference.py +620 -0
- dataforge/stores/__init__.py +46 -0
- dataforge/stores/base.py +73 -0
- dataforge/stores/cloud.py +78 -0
- dataforge/stores/csv.py +94 -0
- dataforge/stores/duckdb.py +313 -0
- dataforge/stores/patch_plan.py +178 -0
- dataforge/stores/registry.py +82 -0
- dataforge/stores/repair.py +121 -0
- dataforge/stores/revert.py +22 -0
- dataforge/stores/sql.py +27 -0
- dataforge/table.py +228 -0
- dataforge/transactions/__init__.py +34 -0
- dataforge/transactions/files.py +96 -0
- dataforge/transactions/log.py +613 -0
- dataforge/transactions/revert.py +102 -0
- dataforge/transactions/txn.py +104 -0
- dataforge/ui/__init__.py +1 -0
- dataforge/ui/profile_view.py +136 -0
- dataforge/ui/repair_diff.py +91 -0
- dataforge/verifier/__init__.py +55 -0
- dataforge/verifier/constraint_ir.py +155 -0
- dataforge/verifier/explain.py +47 -0
- dataforge/verifier/gate.py +5 -0
- dataforge/verifier/schema.py +111 -0
- dataforge/verifier/smt.py +433 -0
- dataforge_07-0.1.0.dist-info/METADATA +436 -0
- dataforge_07-0.1.0.dist-info/RECORD +150 -0
- dataforge_07-0.1.0.dist-info/WHEEL +5 -0
- dataforge_07-0.1.0.dist-info/entry_points.txt +3 -0
- dataforge_07-0.1.0.dist-info/licenses/LICENSE +176 -0
- dataforge_07-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
"""Minimal OpenAI-compatible clients for benchmark-only LLM baselines."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import cast
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ProviderRequestError(RuntimeError):
    """Signal that a provider refused a benchmark request payload.

    Serves as the base class for the more specific provider errors in this
    module.
    """
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ProviderRateLimitError(ProviderRequestError):
    """Signal that the provider's requested wait exceeds the configured cap."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _is_rate_limit_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an ``httpx.HTTPStatusError`` carrying HTTP 429."""
    # Anything that is not an HTTP status error has no response to inspect.
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 429
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _is_retryable_provider_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an HTTP status error with a retryable status.

    Only HTTP 429 and HTTP 503 responses count as retryable.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    status = exc.response.status_code
    return status in {429, 503}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float:
|
|
33
|
+
"""Return provider retry-after delay when present."""
|
|
34
|
+
raw_retry_after = exc.response.headers.get("retry-after")
|
|
35
|
+
if raw_retry_after is None:
|
|
36
|
+
return fallback_s
|
|
37
|
+
try:
|
|
38
|
+
return max(float(raw_retry_after), fallback_s)
|
|
39
|
+
except ValueError:
|
|
40
|
+
return fallback_s
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True, kw_only=True)
|
|
44
|
+
class GroqCompletion:
|
|
45
|
+
"""Completion payload plus conservative usage accounting."""
|
|
46
|
+
|
|
47
|
+
text: str
|
|
48
|
+
prompt_tokens: int
|
|
49
|
+
completion_tokens: int
|
|
50
|
+
warnings: tuple[str, ...]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class OpenAICompatBenchClient:
    """Sequential OpenAI-compatible chat-completions client.

    Requests are spaced at least ``min_interval_s`` apart (measured from the
    previous successful call), and HTTP 429/503 responses are retried up to
    ``max_retries`` attempts. The instance owns an ``httpx.Client``; release
    it with :meth:`close` or by using the instance as a context manager.
    """

    _log = logging.getLogger("dataforge.bench.groq_client")

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        endpoint: str,
        provider: str,
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client."""
        self._api_key = api_key
        self._model = model
        self._endpoint = endpoint
        self._provider = provider
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    def close(self) -> None:
        """Close the underlying HTTP client and its connections."""
        self._client.close()

    def __enter__(self) -> OpenAICompatBenchClient:
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object,
    ) -> None:
        """Close the HTTP client when the ``with`` block ends."""
        self.close()

    @property
    def model(self) -> str:
        """Return the configured provider model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the configured provider identifier."""
        return self._provider

    def _respect_spacing(self) -> None:
        """Sleep until ``min_interval_s`` has passed since the last success."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the chat-completions request, retrying HTTP 429/503.

        Raises:
            ProviderRequestError: On a non-retryable HTTP error, or when the
                final attempt still fails.
            ProviderRateLimitError: When the computed retry delay exceeds
                ``max_retry_after_s``.
            TimeoutError: When the request times out.
            RuntimeError: When no attempt was made (``max_retries`` < 1).
        """
        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": self._max_tokens,
        }
        for attempt in range(self._max_retries):
            try:
                response = self._client.post(self._endpoint, json=payload)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                body = exc.response.text[:500].replace("\n", " ")
                is_last_attempt = attempt == self._max_retries - 1
                if is_last_attempt or not _is_retryable_provider_error(exc):
                    raise ProviderRequestError(
                        f"{self._provider} request rejected with HTTP "
                        f"{exc.response.status_code}: {body}"
                    ) from exc
                # The fallback delay grows linearly with the attempt number.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"{self._provider} rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                    ) from exc
                self._log.warning(
                    "%s_rate_limit attempt=%d retry_after_s=%.2f",
                    self._provider,
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
                continue
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"{self._provider} request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            return dict(response.json())
        # Only reachable when the loop body never ran.
        raise RuntimeError(f"{self._provider} request failed without a response.")

    @staticmethod
    def _token_count(usage: object, key: str) -> int:
        """Read one token counter from a usage mapping, defaulting to 0.

        A missing, null or non-numeric counter yields 0 instead of raising.
        """
        if not isinstance(usage, dict):
            return 0
        raw = usage.get(key)
        if raw is None:
            return 0
        try:
            return int(raw)
        except (TypeError, ValueError):
            return 0

    def _parse_payload(
        self, payload: dict[str, object]
    ) -> tuple[str, int, int, tuple[str, ...]]:
        """Extract ``(text, prompt_tokens, completion_tokens, warnings)``.

        Raises:
            ValueError: When the payload lacks ``choices[0].message.content``.
        """
        warnings: list[str] = []
        usage = payload.get("usage", {})
        prompt_tokens = self._token_count(usage, "prompt_tokens")
        completion_tokens = self._token_count(usage, "completion_tokens")
        if not usage:
            warnings.append("missing_usage_payload")
            self._log.warning("%s_missing_usage_payload", self._provider)

        try:
            choices = cast(list[dict[str, object]], payload["choices"])
            message = cast(dict[str, object], choices[0]["message"])
            raw_content = message["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
            ) from exc

        if raw_content is None:
            # str(None) would hand the literal text "None" to the caller.
            warnings.append("null_content")
            content = ""
        else:
            content = str(raw_content)
        return content, prompt_tokens, completion_tokens, tuple(warnings)

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to the configured provider.

        Raises:
            ValueError: When the response payload has an unexpected shape.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        content, prompt_tokens, completion_tokens, warnings = self._parse_payload(payload)
        return GroqCompletion(
            text=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=warnings,
        )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class GroqBenchClient(OpenAICompatBenchClient):
    """OpenAI-compatible benchmark client pinned to the Groq endpoint."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Forward the settings to the base client with Groq's URL and name."""
        super().__init__(
            # Provider-specific constants come first.
            provider="groq",
            endpoint="https://api.groq.com/openai/v1/chat/completions",
            # Everything else is passed through unchanged.
            api_key=api_key,
            model=model,
            max_tokens=max_tokens,
            min_interval_s=min_interval_s,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class CerebrasBenchClient(OpenAICompatBenchClient):
    """OpenAI-compatible benchmark client pinned to the Cerebras endpoint."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        min_interval_s: float = 0.5,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Forward the settings to the base client with Cerebras' URL and name."""
        super().__init__(
            # Provider-specific constants come first.
            provider="cerebras",
            endpoint="https://api.cerebras.ai/v1/chat/completions",
            # Everything else is passed through unchanged.
            api_key=api_key,
            model=model,
            max_tokens=max_tokens,
            min_interval_s=min_interval_s,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class GeminiBenchClient:
    """Sequential Gemini client adapted to the benchmark completion interface.

    OpenAI-style chat messages are converted to a ``generateContent`` payload.
    Requests are spaced at least ``min_interval_s`` apart (measured from the
    previous successful call), and HTTP 429/503 responses are retried up to
    ``max_retries`` attempts. The instance owns an ``httpx.Client``; release
    it with :meth:`close` or by using the instance as a context manager.
    """

    _log = logging.getLogger("dataforge.bench.groq_client")

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "gemini-3.1-pro-preview",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Store the configuration and open the underlying HTTP client."""
        self._api_key = api_key
        # The endpoint URL already contains "models/", so drop a duplicate prefix.
        self._model = model.removeprefix("models/")
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={
                "Content-Type": "application/json",
                # Sending the key as a header keeps it out of the request URL,
                # which can end up in logs and error messages.
                "x-goog-api-key": self._api_key,
            },
        )

    def close(self) -> None:
        """Close the underlying HTTP client and its connections."""
        self._client.close()

    def __enter__(self) -> GeminiBenchClient:
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object,
    ) -> None:
        """Close the HTTP client when the ``with`` block ends."""
        self.close()

    @property
    def model(self) -> str:
        """Return the configured Gemini model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the provider identifier."""
        return "gemini"

    def _respect_spacing(self) -> None:
        """Sleep until ``min_interval_s`` has passed since the last success."""
        if self._last_success_at is None:
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Convert OpenAI-style chat messages to a generateContent payload.

        ``system`` messages are joined with blank lines into
        ``systemInstruction``; ``assistant`` maps to the ``model`` role and
        every other role maps to ``user``.
        """
        system_texts: list[str] = []
        contents: list[dict[str, object]] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                system_texts.append(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [{"text": content}]})

        payload: dict[str, object] = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.0,
                "maxOutputTokens": self._max_tokens,
            },
        }
        if system_texts:
            payload["systemInstruction"] = {
                "parts": [{"text": "\n\n".join(system_texts)}],
            }
        return payload

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the generateContent request, retrying HTTP 429/503.

        Raises:
            ProviderRequestError: On a non-retryable HTTP error, or when the
                final attempt still fails.
            ProviderRateLimitError: When the computed retry delay exceeds
                ``max_retry_after_s``.
            TimeoutError: When the request times out.
            RuntimeError: When no attempt was made (``max_retries`` < 1).
        """
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
        )
        # The body is identical for every attempt, so build it once.
        request_body = self._payload(messages)
        for attempt in range(self._max_retries):
            try:
                response = self._client.post(endpoint, json=request_body)
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                body = exc.response.text[:500].replace("\n", " ")
                is_last_attempt = attempt == self._max_retries - 1
                if is_last_attempt or not _is_retryable_provider_error(exc):
                    raise ProviderRequestError(
                        f"gemini request rejected with HTTP {exc.response.status_code}: {body}"
                    ) from exc
                # The fallback delay grows linearly with the attempt number.
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    raise ProviderRateLimitError(
                        f"gemini rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                    ) from exc
                self._log.warning(
                    "gemini_rate_limit attempt=%d retry_after_s=%.2f",
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
                continue
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"gemini request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            return dict(response.json())
        # Only reachable when the loop body never ran.
        raise RuntimeError("gemini request failed without a response.")

    @staticmethod
    def _token_count(usage: object, key: str) -> int:
        """Read one token counter from a usage mapping, defaulting to 0.

        A missing, null or non-numeric counter yields 0 instead of raising.
        """
        if not isinstance(usage, dict):
            return 0
        raw = usage.get(key)
        if raw is None:
            return 0
        try:
            return int(raw)
        except (TypeError, ValueError):
            return 0

    def _parse_payload(
        self, payload: dict[str, object]
    ) -> tuple[str, int, int, tuple[str, ...]]:
        """Extract ``(text, prompt_tokens, completion_tokens, warnings)``.

        The text is the concatenation of every part's ``text`` field in the
        first candidate.

        Raises:
            ValueError: When the payload lacks ``candidates[0].content.parts``
                or a part is not a mapping.
        """
        warnings: list[str] = []
        usage = payload.get("usageMetadata", {})
        prompt_tokens = self._token_count(usage, "promptTokenCount")
        completion_tokens = self._token_count(usage, "candidatesTokenCount")
        if not usage:
            warnings.append("missing_usage_payload")
            self._log.warning("gemini_missing_usage_payload")

        try:
            candidates = cast(list[dict[str, object]], payload["candidates"])
            content = cast(dict[str, object], candidates[0]["content"])
            parts = cast(list[dict[str, object]], content["parts"])
            text = "".join(str(part.get("text", "")) for part in parts)
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            # AttributeError covers parts that are not mappings (no .get).
            raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc
        return text, prompt_tokens, completion_tokens, tuple(warnings)

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to Gemini.

        Raises:
            ValueError: When the response payload has an unexpected shape.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        text, prompt_tokens, completion_tokens, warnings = self._parse_payload(payload)
        return GroqCompletion(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=warnings,
        )
|