synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.
This release of synth-ai has been flagged as potentially problematic.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +291 -56
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py (new file, +512 lines)

@@ -0,0 +1,512 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any, Dict, Optional
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIClient:
+    """Async HTTP client for OpenAI-compatible inference servers (vLLM)."""
+
+    def __init__(
+        self,
+        base_url: str,
+        api_key: Optional[str] = None,
+        timeout_s: float = 120.0,
+    ) -> None:
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.timeout_s = timeout_s
+        self.headers = {}
+
+        if api_key:
+            self.headers["Authorization"] = f"Bearer {api_key}"
+
+    def _fix_model_parameters(self, request: Dict[str, Any], target_url: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Fix parameter compatibility for newer OpenAI models.
+
+        Newer models like gpt-5-nano use 'max_completion_tokens' instead of 'max_tokens'.
+        """
+        if not request:
+            return request
+
+        # Make a copy to avoid modifying the original
+        fixed_request = request.copy()
+
+        # Determine if target is OpenAI-compatible (OpenAI, Azure OpenAI, Groq);
+        # strip fields those endpoints don't accept
+        is_openai = False
+        try:
+            if isinstance(target_url, str):
+                low = target_url.lower()
+                is_openai = (
+                    ("openai.com" in low)
+                    or ("azure" in low and ".openai." in low)
+                    or ("groq.com" in low)
+                    or ("/openai" in low)
+                )
+        except Exception:
+            is_openai = False
+
+        model = fixed_request.get("model", "")
+
+        if is_openai:
+            # Remove fields OpenAI/Groq don't accept
+            for k in (
+                "stop_after_tool_calls",
+                "thinking_mode",
+                "thinking_budget",
+                "reasoning",
+                "extra_body",
+                "parallel_tool_calls",
+                "function_call",
+            ):
+                if k in fixed_request:
+                    fixed_request.pop(k, None)
+
+            # GPT-5 family specifics
+            if "gpt-5" in model or "gpt-4.1" in model:
+                # Convert max_tokens to max_completion_tokens for newer models
+                if "max_tokens" in fixed_request:
+                    if "max_completion_tokens" not in fixed_request:
+                        fixed_request["max_completion_tokens"] = fixed_request.pop("max_tokens")
+                        logger.info(f"Converted max_tokens to max_completion_tokens for model {model}")
+                    else:
+                        fixed_request.pop("max_tokens")
+                        logger.info(f"Removed conflicting max_tokens parameter for model {model}")
+                # Some OpenAI endpoints ignore/deny sampling fields for reasoning models
+                for k in ("temperature", "top_p"):
+                    if k in fixed_request:
+                        fixed_request.pop(k, None)
+            # If tools are present, force single tool choice to our function
+            try:
+                tools = fixed_request.get("tools")
+                if isinstance(tools, list) and tools:
+                    fixed_request["tool_choice"] = {
+                        "type": "function",
+                        "function": {"name": "interact_many"},
+                    }
+                    fixed_request["parallel_tool_calls"] = False
+            except Exception:
+                pass
+
+        return fixed_request
+
+    async def generate(
+        self,
+        request: Dict[str, Any],
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Send a chat completion request to the inference server.
+
+        Args:
+            request: OpenAI-compatible chat completion request
+            base_url: Override base URL for this request
+            timeout_s: Override timeout for this request
+            extra_headers: Additional headers to include (e.g., X-Policy-Name)
+
+        Returns:
+            OpenAI-compatible chat completion response
+        """
+        url = (base_url or self.base_url).rstrip("/") + "/v1/chat/completions"
+        timeout = timeout_s or self.timeout_s
+
+        # Merge headers
+        headers = self.headers.copy()
+        if extra_headers:
+            headers.update(extra_headers)
+
+        # Fix parameter compatibility for newer models
+        processed_request = self._fix_model_parameters(request, target_url=url)
+
+        # Log request (redact messages in production)
+        logger.info(f"Inference POST target: {url}")
+        if extra_headers:
+            logger.info(f"Extra headers: {extra_headers}")
+        try:
+            keys_preview = sorted(list(processed_request.keys()))
+            logger.info(f"Request keys: {keys_preview}")
+        except Exception:
+            pass
+
+        # Final hard-guard for OpenAI: ensure unsupported field is not present
+        try:
+            if "openai" in url.lower():
+                if "stop_after_tool_calls" in processed_request:
+                    processed_request.pop("stop_after_tool_calls", None)
+                    logger.info("Removed stop_after_tool_calls for OpenAI request")
+            # Groq-specific requirement: when using JSON mode, one of the messages must contain the word 'json'
+            low_url = url.lower()
+            if ("groq.com" in low_url or "/openai" in low_url) and isinstance(processed_request, dict):
+                rf = processed_request.get("response_format")
+                rf_type = None
+                if isinstance(rf, dict):
+                    rf_type = str(rf.get("type") or "").lower()
+                if rf_type in {"json_object", "json_schema"}:
+                    msgs = processed_request.get("messages")
+                    has_json_word = False
+                    if isinstance(msgs, list):
+                        for m in msgs:
+                            try:
+                                content = m.get("content") if isinstance(m, dict) else None
+                                text = None
+                                if isinstance(content, str):
+                                    text = content
+                                elif isinstance(content, list):
+                                    # Join any text segments
+                                    parts = []
+                                    for seg in content:
+                                        if isinstance(seg, dict) and isinstance(seg.get("text"), str):
+                                            parts.append(seg["text"])
+                                    text = "\n".join(parts)
+                                if isinstance(text, str) and ("json" in text.lower()):
+                                    has_json_word = True
+                                    break
+                            except Exception:
+                                continue
+                    if not has_json_word:
+                        try:
+                            instruction = "Respond in strict JSON only. Output a single valid JSON object."
+                            if not isinstance(msgs, list):
+                                msgs = []
+                            # Prepend a system message to satisfy Groq requirement without changing user intent
+                            prepend = {"role": "system", "content": instruction}
+                            processed_request["messages"] = [prepend] + list(msgs)
+                            logger.info("Injected JSON-mode system instruction for Groq response_format compliance")
+                        except Exception:
+                            pass
+        except Exception:
+            pass
+
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            try:
+                response = await client.post(
+                    url,
+                    json=processed_request,
+                    headers=headers,
+                )
+                response.raise_for_status()
+
+                # Rich response diagnostics
+                content_type = response.headers.get("content-type")
+                body_text = response.text
+                logger.info(
+                    f"Inference response status=200, content-type={content_type}, bytes={len(body_text)}"
+                )
+                if body_text:
+                    preview_len = min(800, len(body_text))
+                    logger.info(f"Inference response preview ({preview_len} bytes): {body_text[:preview_len]}")
+
+                result = response.json()
+                logger.info(f"Inference response parsed_type={type(result).__name__}")
+                return result
+
+            except httpx.TimeoutException:
+                logger.error(f"Request to {url} timed out after {timeout}s")
+                raise
+            except httpx.HTTPStatusError as e:
+                status = e.response.status_code if e.response is not None else None
+                text = e.response.text if e.response is not None else str(e)
+                # Log full body for debugging remote failures
+                try:
+                    logger.error({
+                        "openai_http_error": True,
+                        "status": status,
+                        "url": url,
+                        "body": text,
+                    })
+                except Exception:
+                    logger.error(f"HTTP error from {url}: {status} - {text}")
+                # For 4xx/5xx, print full sanitized request to aid debugging (especially Groq 400s)
+                try:
+                    redacted_headers = dict(headers)
+                    if "Authorization" in redacted_headers:
+                        redacted_headers["Authorization"] = "***REDACTED***"
+                    logger.error({
+                        "request_debug": True,
+                        "status": status,
+                        "target": url,
+                        "headers": redacted_headers,
+                        "payload": processed_request,
+                    })
+                except Exception:
+                    pass
+                # Special case: token budget exceeded (OpenAI-compatible error schema)
+                try:
+                    if status == 400 and e.response is not None:
+                        data = e.response.json()
+                        detail = data.get("detail") if isinstance(data, dict) else None
+                        err_code = (detail or {}).get("error") if isinstance(detail, dict) else None
+                        if err_code == "token_budget_exceeded":
+                            info = (detail or {}).get("details") or {}
+                            messages_tokens = int(info.get("messages_tokens") or 0)
+                            model_limit = int(info.get("model_limit") or 0)
+                            safety = 64
+                            # Compute a conservative new max_tokens
+                            new_max = max(16, model_limit - messages_tokens - safety)
+                            try:
+                                # Update request and retry once immediately with smaller budget
+                                if isinstance(processed_request, dict):
+                                    processed_request = dict(processed_request)
+                                    if "max_completion_tokens" in processed_request:
+                                        processed_request["max_completion_tokens"] = new_max
+                                        processed_request.pop("max_tokens", None)
+                                    else:
+                                        processed_request["max_tokens"] = new_max
+                                    # Remove optional fields that some servers reject
+                                    for k in ("thinking_mode", "thinking_budget", "reasoning"):
+                                        processed_request.pop(k, None)
+                                    # Force structured tool choice
+                                    if processed_request.get("tool_choice") == "required":
+                                        func_name = "interact_many"
+                                        try:
+                                            tools_arr = processed_request.get("tools") or []
+                                            if isinstance(tools_arr, list) and tools_arr:
+                                                f = tools_arr[0].get("function") if isinstance(tools_arr[0], dict) else None
+                                                cand = (f or {}).get("name") if isinstance(f, dict) else None
+                                                if isinstance(cand, str) and cand:
+                                                    func_name = cand
+                                        except Exception:
+                                            pass
+                                        processed_request["tool_choice"] = {"type": "function", "function": {"name": func_name}}
+                                        processed_request["parallel_tool_calls"] = False
+                                logger.warning({
+                                    "token_budget_recovery": True,
+                                    "messages_tokens": messages_tokens,
+                                    "model_limit": model_limit,
+                                    "retry_max_tokens": new_max,
+                                })
+                                # Retry once with reduced budget
+                                async with httpx.AsyncClient(timeout=timeout) as client2:
+                                    r2 = await client2.post(url, json=processed_request, headers=headers)
+                                    r2.raise_for_status()
+                                    return r2.json()
+                            except Exception:
+                                pass
+                except Exception:
+                    pass
+                # Gracefully degrade on 422 so rollouts can still produce a trajectory
+                if status == 422:
+                    try:
+                        # Best-effort parse of error for diagnostics
+                        err = None
+                        try:
+                            err = e.response.json()
+                        except Exception:
+                            err = {"error": "unprocessable", "detail": (text or "")[:200]}
+                        logger.warning({
+                            "inference_422_recovered": True,
+                            "detail": err,
+                        })
+                    except Exception:
+                        pass
+                    # Return a minimal OpenAI-compatible response with no tool_calls/content
+                    import time as _t
+                    return {
+                        "id": f"cmpl-{int(_t.time())}",
+                        "object": "chat.completion",
+                        "created": int(_t.time()),
+                        "model": processed_request.get("model") or "unknown",
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {"role": "assistant", "content": "", "tool_calls": []},
+                                "finish_reason": "stop",
+                            }
+                        ],
+                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+                    }
+                raise
+            except Exception as e:
+                logger.error(f"Unexpected error calling {url}: {e}")
+                raise
+
+    async def check_health(
+        self,
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+    ) -> Dict[str, Any]:
+        """
+        Check if the inference service is healthy.
+
+        Args:
+            base_url: Override base URL for this request
+            timeout_s: Override timeout for this request
+
+        Returns:
+            Health status dict with 'status' field
+        """
+        url = (base_url or self.base_url).rstrip("/") + "/health"
+        timeout = timeout_s or 10.0
+
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                response = await client.get(url, headers=self.headers)
+                response.raise_for_status()
+                return response.json()
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 400:
+                # Service is overloaded but still responding
+                try:
+                    data = e.response.json()
+                    if data.get("status") == "overloaded":
+                        return {"status": "overloaded", "retry_after": data.get("retry_after", 1)}
+                except Exception:
+                    pass
+            return {"status": "unhealthy", "error": str(e)}
+        except Exception as e:
+            return {"status": "unhealthy", "error": str(e)}
+
+    async def generate_with_retries(
+        self,
+        request: Dict[str, Any],
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: int = 4,
+        backoff_factor: float = 2.0,
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Generate with exponential backoff retries for transient errors.
+
+        Args:
+            request: OpenAI-compatible chat completion request
+            base_url: Override base URL
+            timeout_s: Override timeout
+            max_retries: Maximum number of retry attempts
+            backoff_factor: Exponential backoff multiplier
+            extra_headers: Additional headers to include (e.g., X-Policy-Name)
+
+        Returns:
+            OpenAI-compatible chat completion response
+        """
+        last_error = None
+        wait_time = 1.0
+
+        for attempt in range(max_retries + 1):
+            try:
+                # Apply parameter fixes to the request
+                processed_request = self._fix_model_parameters(
+                    request,
+                    target_url=(base_url or self.base_url).rstrip("/") + "/v1/chat/completions",
+                )
+                return await self.generate(
+                    request=processed_request,
+                    base_url=base_url,
+                    timeout_s=timeout_s,
+                    extra_headers=extra_headers,
+                )
+            except httpx.HTTPStatusError as e:
+                # Retry on 400 (overloaded), 429 (rate limit), 500 (internal error), 503 (service unavailable)
+                if e.response.status_code not in [400, 429, 500, 503]:
+                    raise
+                last_error = e
+                if e.response.status_code == 400:
+                    # Check if this is an overload error by looking at response content
+                    try:
+                        response_data = e.response.json()
+                        if response_data.get("status") == "overloaded":
+                            retry_after = response_data.get("retry_after", 1)
+                            # Use the suggested retry_after time instead of exponential backoff for overload
+                            wait_time = max(wait_time, float(retry_after))
+                            logger.warning(f"Inference service overloaded (400). {response_data} Retrying after {wait_time}s...")
+                        else:
+                            # This is a different type of 400 error, don't retry
+                            try:
+                                redacted_headers = {}
+                                try:
+                                    redacted_headers = dict(self.headers)
+                                    if "Authorization" in redacted_headers:
+                                        redacted_headers["Authorization"] = "***REDACTED***"
+                                except Exception:
+                                    redacted_headers = {}
+                                logger.error({
+                                    "non_overload_400": True,
+                                    "target": (base_url or self.base_url),
+                                    "payload": processed_request,
+                                    "headers": redacted_headers,
+                                    "body": e.response.text if e.response is not None else None,
+                                })
+                            except Exception:
+                                pass
+                            raise RuntimeError(
+                                f"Inference 400 response: {e.response.text if e.response is not None else 'Bad Request'}"
+                            ) from e
+                    except Exception:
+                        # If we can't parse the response, don't retry 400 errors
+                        try:
+                            logger.error({
+                                "non_overload_400_unparsed": True,
+                                "target": (base_url or self.base_url),
+                                "payload": processed_request,
+                            })
+                        except Exception:
+                            pass
+                        raise RuntimeError(
+                            f"Inference 400 response (unparsed): {e.response.text if e.response is not None else 'Bad Request'}"
+                        ) from e
+                elif e.response.status_code == 503:
+                    # Avoid referencing undefined response_data
+                    try:
+                        preview = (e.response.text or "")[:200]
+                    except Exception:
+                        preview = ""
+                    logger.warning(
+                        f"Flash returned 503; container may be cold starting. Retrying... body={preview}"
+                    )
+                elif e.response.status_code == 500:
+                    try:
+                        preview = (e.response.text or "")[:200]
+                    except Exception:
+                        preview = ""
+                    logger.warning(
+                        f"Flash returned 500; inference service error. Retrying... body={preview}"
+                    )
+            except httpx.TimeoutException as e:
+                last_error = e
+
+            if attempt < max_retries:
+                logger.warning(
+                    f"Inference request failed (attempt {attempt + 1}/{max_retries + 1}), "
+                    f"retrying in {wait_time}s..."
+                )
+                await asyncio.sleep(wait_time)
+                wait_time *= backoff_factor
+
+        raise last_error
+
+
+def create_inference_client(
+    task_app: Any,
+    api_key: Optional[str] = None,
+) -> OpenAIClient:
+    """
+    Create an inference client using TaskApp configuration.
+
+    Args:
+        task_app: TaskApp instance with vllm_base_url
+        api_key: Optional API key for authentication
+
+    Returns:
+        Configured OpenAIClient instance
+    """
+    # Fallback to environment if caller didn't provide an API key
+    if api_key is None:
+        try:
+            import os as _os  # local import to avoid module-level side effects
+            api_key = _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
+        except Exception:
+            api_key = None
+
+    return OpenAIClient(
+        base_url=task_app.vllm_base_url,
+        api_key=api_key,
+    )
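For orientation, the sketch below shows one way the client added above might be called. It is not part of the diff: the local server URL, the model name, and the import path are all assumptions (the file ships inside the example task app rather than as an installed package).

```python
# Minimal usage sketch (assumptions: an OpenAI-compatible server such as a local
# vLLM instance at http://localhost:8000, and a placeholder model name).
import asyncio

from openai_client import OpenAIClient  # import path is an assumption


async def main() -> None:
    client = OpenAIClient(base_url="http://localhost:8000", timeout_s=60.0)
    request = {
        "model": "Qwen/Qwen2.5-7B-Instruct",  # placeholder model name
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 64,
    }
    # generate_with_retries applies _fix_model_parameters and retries on 400/429/500/503.
    response = await client.generate_with_retries(request, max_retries=2)
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())
```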
examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py (new file, +102 lines)

@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+"""
+Main entry point for the GRPO Synth Envs Hosted Service.
+
+For local development:
+    uvicorn main:app --reload --port 8000
+
+For Modal deployment:
+    modal deploy main.py
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Optional
+
+import modal
+
+# Try to import Modal-specific features
+try:
+    from modal import App, Image, Volume, asgi_app
+
+    MODAL_AVAILABLE = True
+except ImportError:
+    MODAL_AVAILABLE = False
+
+from synth_envs_hosted.hosted_app import create_app
+
+
+# Local development mode
+if __name__ == "__main__":
+    import uvicorn
+
+    # Create the FastAPI app
+    app = create_app()
+
+    # Run with uvicorn
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", "8000")),
+        reload=True,
+    )
+
+# Modal deployment mode
+elif MODAL_AVAILABLE:
+    # Define Modal app
+    modal_app = App("grpo-synth-envs-hosted")
+
+    # Define the container image
+    image = Image.debian_slim().pip_install(
+        "fastapi",
+        "uvicorn[standard]",
+        "httpx",
+        "pydantic",
+        "synth-ai",
+    )
+
+    # Create or get the volume for state storage
+    state_volume = Volume.from_name("synth-env-state", create_if_missing=True)
+
+    # Define the ASGI app function
+    @modal_app.function(
+        image=image,
+        min_containers=1,
+        volumes={"/data/state": state_volume},
+        secrets=[
+            modal.Secret.from_name("vllm-config"),
+        ],
+    )
+    @asgi_app()
+    def fastapi_app():
+        """Modal ASGI app factory."""
+        return create_app()
+
+    # Optional: Add a scheduled cleanup job
+    @modal_app.function(
+        schedule=modal.Period(hours=24),
+        volumes={"/data/state": state_volume},
+    )
+    def cleanup_old_snapshots(max_age_hours: int = 48):
+        """Periodic cleanup of old snapshots."""
+        import shutil
+        from datetime import datetime, timedelta
+        from pathlib import Path
+
+        base_path = Path("/data/state/runs")
+        if not base_path.exists():
+            return
+
+        cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
+
+        for run_dir in base_path.iterdir():
+            if run_dir.is_dir():
+                # Check modification time
+                mtime = datetime.fromtimestamp(run_dir.stat().st_mtime)
+                if mtime < cutoff_time:
+                    print(f"Removing old run directory: {run_dir}")
+                    shutil.rmtree(run_dir)
+
+    # Export for Modal
+    app = fastapi_app
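Once the hosted service is running locally (uvicorn main:app per the module docstring), a quick smoke test could reuse the client from the first hunk. This is a sketch only: it assumes the app listens on port 8000 and exposes a /health route compatible with OpenAIClient.check_health, neither of which is shown in this diff.

```python
# Hypothetical smoke test (assumptions: service on localhost:8000 with a /health route).
import asyncio

from openai_client import OpenAIClient  # import path is an assumption


async def main() -> None:
    client = OpenAIClient(base_url="http://localhost:8000")
    status = await client.check_health()
    # Expected shape: {"status": "..."} or {"status": "unhealthy", "error": "..."}
    print(status)


asyncio.run(main())
```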