synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai may be problematic.

Files changed (107)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  98. synth_ai/api/train/config_finder.py +18 -18
  99. synth_ai/api/train/env_resolver.py +28 -1
  100. synth_ai/cli/task_apps.py +264 -55
  101. synth_ai/task/apps/__init__.py +54 -13
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +107 -12
  104. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py
@@ -0,0 +1,512 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any, Dict, Optional
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIClient:
+    """Async HTTP client for OpenAI-compatible inference servers (vLLM)."""
+
+    def __init__(
+        self,
+        base_url: str,
+        api_key: Optional[str] = None,
+        timeout_s: float = 120.0,
+    ) -> None:
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.timeout_s = timeout_s
+        self.headers = {}
+
+        if api_key:
+            self.headers["Authorization"] = f"Bearer {api_key}"
+
+    def _fix_model_parameters(self, request: Dict[str, Any], target_url: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Fix parameter compatibility for newer OpenAI models.
+
+        Newer models like gpt-5-nano use 'max_completion_tokens' instead of 'max_tokens'.
+        """
+        if not request:
+            return request
+
+        # Make a copy to avoid modifying the original
+        fixed_request = request.copy()
+
+        # Determine if target is OpenAI-compatible (OpenAI, Azure OpenAI, Groq);
+        # strip fields those endpoints don't accept
+        is_openai = False
+        try:
+            if isinstance(target_url, str):
+                low = target_url.lower()
+                is_openai = (
+                    ("openai.com" in low)
+                    or ("azure" in low and ".openai." in low)
+                    or ("groq.com" in low)
+                    or ("/openai" in low)
+                )
+        except Exception:
+            is_openai = False
+
+        model = fixed_request.get("model", "")
+
+        if is_openai:
+            # Remove fields OpenAI/Groq don't accept
+            for k in (
+                "stop_after_tool_calls",
+                "thinking_mode",
+                "thinking_budget",
+                "reasoning",
+                "extra_body",
+                "parallel_tool_calls",
+                "function_call",
+            ):
+                if k in fixed_request:
+                    fixed_request.pop(k, None)
+
+            # GPT-5 family specifics
+            if "gpt-5" in model or "gpt-4.1" in model:
+                # Convert max_tokens to max_completion_tokens for newer models
+                if "max_tokens" in fixed_request:
+                    if "max_completion_tokens" not in fixed_request:
+                        fixed_request["max_completion_tokens"] = fixed_request.pop("max_tokens")
+                        logger.info(f"Converted max_tokens to max_completion_tokens for model {model}")
+                    else:
+                        fixed_request.pop("max_tokens")
+                        logger.info(f"Removed conflicting max_tokens parameter for model {model}")
+                # Some OpenAI endpoints ignore/deny sampling fields for reasoning models
+                for k in ("temperature", "top_p"):
+                    if k in fixed_request:
+                        fixed_request.pop(k, None)
+                # If tools are present, force single tool choice to our function
+                try:
+                    tools = fixed_request.get("tools")
+                    if isinstance(tools, list) and tools:
+                        fixed_request["tool_choice"] = {
+                            "type": "function",
+                            "function": {"name": "interact_many"},
+                        }
+                        fixed_request["parallel_tool_calls"] = False
+                except Exception:
+                    pass
+
+        return fixed_request
+
+    async def generate(
+        self,
+        request: Dict[str, Any],
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Send a chat completion request to the inference server.
+
+        Args:
+            request: OpenAI-compatible chat completion request
+            base_url: Override base URL for this request
+            timeout_s: Override timeout for this request
+            extra_headers: Additional headers to include (e.g., X-Policy-Name)
+
+        Returns:
+            OpenAI-compatible chat completion response
+        """
+        url = (base_url or self.base_url).rstrip("/") + "/v1/chat/completions"
+        timeout = timeout_s or self.timeout_s
+
+        # Merge headers
+        headers = self.headers.copy()
+        if extra_headers:
+            headers.update(extra_headers)
+
+        # Fix parameter compatibility for newer models
+        processed_request = self._fix_model_parameters(request, target_url=url)
+
+        # Log request (redact messages in production)
+        logger.info(f"Inference POST target: {url}")
+        if extra_headers:
+            logger.info(f"Extra headers: {extra_headers}")
+        try:
+            keys_preview = sorted(list(processed_request.keys()))
+            logger.info(f"Request keys: {keys_preview}")
+        except Exception:
+            pass
+
+        # Final hard-guard for OpenAI: ensure unsupported field is not present
+        try:
+            if "openai" in url.lower():
+                if "stop_after_tool_calls" in processed_request:
+                    processed_request.pop("stop_after_tool_calls", None)
+                    logger.info("Removed stop_after_tool_calls for OpenAI request")
+            # Groq-specific requirement: when using JSON mode, one of the messages must contain the word 'json'
+            low_url = url.lower()
+            if ("groq.com" in low_url or "/openai" in low_url) and isinstance(processed_request, dict):
+                rf = processed_request.get("response_format")
+                rf_type = None
+                if isinstance(rf, dict):
+                    rf_type = str(rf.get("type") or "").lower()
+                if rf_type in {"json_object", "json_schema"}:
+                    msgs = processed_request.get("messages")
+                    has_json_word = False
+                    if isinstance(msgs, list):
+                        for m in msgs:
+                            try:
+                                content = m.get("content") if isinstance(m, dict) else None
+                                text = None
+                                if isinstance(content, str):
+                                    text = content
+                                elif isinstance(content, list):
+                                    # Join any text segments
+                                    parts = []
+                                    for seg in content:
+                                        if isinstance(seg, dict) and isinstance(seg.get("text"), str):
+                                            parts.append(seg["text"])
+                                    text = "\n".join(parts)
+                                if isinstance(text, str) and ("json" in text.lower()):
+                                    has_json_word = True
+                                    break
+                            except Exception:
+                                continue
+                    if not has_json_word:
+                        try:
+                            instruction = "Respond in strict JSON only. Output a single valid JSON object."
+                            if not isinstance(msgs, list):
+                                msgs = []
+                            # Prepend a system message to satisfy Groq requirement without changing user intent
+                            prepend = {"role": "system", "content": instruction}
+                            processed_request["messages"] = [prepend] + list(msgs)
+                            logger.info("Injected JSON-mode system instruction for Groq response_format compliance")
+                        except Exception:
+                            pass
+        except Exception:
+            pass
+
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            try:
+                response = await client.post(
+                    url,
+                    json=processed_request,
+                    headers=headers,
+                )
+                response.raise_for_status()
+
+                # Rich response diagnostics
+                content_type = response.headers.get("content-type")
+                body_text = response.text
+                logger.info(
+                    f"Inference response status=200, content-type={content_type}, bytes={len(body_text)}"
+                )
+                if body_text:
+                    preview_len = min(800, len(body_text))
+                    logger.info(f"Inference response preview ({preview_len} bytes): {body_text[:preview_len]}")
+
+                result = response.json()
+                logger.info(f"Inference response parsed_type={type(result).__name__}")
+                return result
+
+            except httpx.TimeoutException:
+                logger.error(f"Request to {url} timed out after {timeout}s")
+                raise
+            except httpx.HTTPStatusError as e:
+                status = e.response.status_code if e.response is not None else None
+                text = e.response.text if e.response is not None else str(e)
+                # Log full body for debugging remote failures
+                try:
+                    logger.error({
+                        "openai_http_error": True,
+                        "status": status,
+                        "url": url,
+                        "body": text,
+                    })
+                except Exception:
+                    logger.error(f"HTTP error from {url}: {status} - {text}")
+                # For 4xx/5xx, print full sanitized request to aid debugging (especially Groq 400s)
+                try:
+                    redacted_headers = dict(headers)
+                    if "Authorization" in redacted_headers:
+                        redacted_headers["Authorization"] = "***REDACTED***"
+                    logger.error({
+                        "request_debug": True,
+                        "status": status,
+                        "target": url,
+                        "headers": redacted_headers,
+                        "payload": processed_request,
+                    })
+                except Exception:
+                    pass
+                # Special case: token budget exceeded (OpenAI-compatible error schema)
+                try:
+                    if status == 400 and e.response is not None:
+                        data = e.response.json()
+                        detail = data.get("detail") if isinstance(data, dict) else None
+                        err_code = (detail or {}).get("error") if isinstance(detail, dict) else None
+                        if err_code == "token_budget_exceeded":
+                            info = (detail or {}).get("details") or {}
+                            messages_tokens = int(info.get("messages_tokens") or 0)
+                            model_limit = int(info.get("model_limit") or 0)
+                            safety = 64
+                            # Compute a conservative new max_tokens
+                            new_max = max(16, model_limit - messages_tokens - safety)
+                            try:
+                                # Update request and retry once immediately with smaller budget
+                                if isinstance(processed_request, dict):
+                                    processed_request = dict(processed_request)
+                                    if "max_completion_tokens" in processed_request:
+                                        processed_request["max_completion_tokens"] = new_max
+                                        processed_request.pop("max_tokens", None)
+                                    else:
+                                        processed_request["max_tokens"] = new_max
+                                    # Remove optional fields that some servers reject
+                                    for k in ("thinking_mode", "thinking_budget", "reasoning"):
+                                        processed_request.pop(k, None)
+                                    # Force structured tool choice
+                                    if processed_request.get("tool_choice") == "required":
+                                        func_name = "interact_many"
+                                        try:
+                                            tools_arr = processed_request.get("tools") or []
+                                            if isinstance(tools_arr, list) and tools_arr:
+                                                f = tools_arr[0].get("function") if isinstance(tools_arr[0], dict) else None
+                                                cand = (f or {}).get("name") if isinstance(f, dict) else None
+                                                if isinstance(cand, str) and cand:
+                                                    func_name = cand
+                                        except Exception:
+                                            pass
+                                        processed_request["tool_choice"] = {"type": "function", "function": {"name": func_name}}
+                                        processed_request["parallel_tool_calls"] = False
+                                logger.warning({
+                                    "token_budget_recovery": True,
+                                    "messages_tokens": messages_tokens,
+                                    "model_limit": model_limit,
+                                    "retry_max_tokens": new_max,
+                                })
+                                # Retry once with reduced budget
+                                async with httpx.AsyncClient(timeout=timeout) as client2:
+                                    r2 = await client2.post(url, json=processed_request, headers=headers)
+                                    r2.raise_for_status()
+                                    return r2.json()
+                            except Exception:
+                                pass
+                except Exception:
+                    pass
+                # Gracefully degrade on 422 so rollouts can still produce a trajectory
+                if status == 422:
+                    try:
+                        # Best-effort parse of error for diagnostics
+                        err = None
+                        try:
+                            err = e.response.json()
+                        except Exception:
+                            err = {"error": "unprocessable", "detail": (text or "")[:200]}
+                        logger.warning({
+                            "inference_422_recovered": True,
+                            "detail": err,
+                        })
+                    except Exception:
+                        pass
+                    # Return a minimal OpenAI-compatible response with no tool_calls/content
+                    import time as _t
+                    return {
+                        "id": f"cmpl-{int(_t.time())}",
+                        "object": "chat.completion",
+                        "created": int(_t.time()),
+                        "model": processed_request.get("model") or "unknown",
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {"role": "assistant", "content": "", "tool_calls": []},
+                                "finish_reason": "stop",
+                            }
+                        ],
+                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+                    }
+                raise
+            except Exception as e:
+                logger.error(f"Unexpected error calling {url}: {e}")
+                raise
+
+    async def check_health(
+        self,
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+    ) -> Dict[str, Any]:
+        """
+        Check if the inference service is healthy.
+
+        Args:
+            base_url: Override base URL for this request
+            timeout_s: Override timeout for this request
+
+        Returns:
+            Health status dict with 'status' field
+        """
+        url = (base_url or self.base_url).rstrip("/") + "/health"
+        timeout = timeout_s or 10.0
+
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                response = await client.get(url, headers=self.headers)
+                response.raise_for_status()
+                return response.json()
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 400:
+                # Service is overloaded but still responding
+                try:
+                    data = e.response.json()
+                    if data.get("status") == "overloaded":
+                        return {"status": "overloaded", "retry_after": data.get("retry_after", 1)}
+                except Exception:
+                    pass
+            return {"status": "unhealthy", "error": str(e)}
+        except Exception as e:
+            return {"status": "unhealthy", "error": str(e)}
+
+    async def generate_with_retries(
+        self,
+        request: Dict[str, Any],
+        base_url: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: int = 4,
+        backoff_factor: float = 2.0,
+        extra_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Generate with exponential backoff retries for transient errors.
+
+        Args:
+            request: OpenAI-compatible chat completion request
+            base_url: Override base URL
+            timeout_s: Override timeout
+            max_retries: Maximum number of retry attempts
+            backoff_factor: Exponential backoff multiplier
+            extra_headers: Additional headers to include (e.g., X-Policy-Name)
+
+        Returns:
+            OpenAI-compatible chat completion response
+        """
+        last_error = None
+        wait_time = 1.0
+
+        for attempt in range(max_retries + 1):
+            try:
+                # Apply parameter fixes to the request
+                processed_request = self._fix_model_parameters(
+                    request,
+                    target_url=(base_url or self.base_url).rstrip("/") + "/v1/chat/completions",
+                )
+                return await self.generate(
+                    request=processed_request,
+                    base_url=base_url,
+                    timeout_s=timeout_s,
+                    extra_headers=extra_headers,
+                )
+            except httpx.HTTPStatusError as e:
+                # Retry on 400 (overloaded), 429 (rate limit), 500 (internal error), 503 (service unavailable)
+                if e.response.status_code not in [400, 429, 500, 503]:
+                    raise
+                last_error = e
+                if e.response.status_code == 400:
+                    # Check if this is an overload error by looking at response content
+                    try:
+                        response_data = e.response.json()
+                        if response_data.get("status") == "overloaded":
+                            retry_after = response_data.get("retry_after", 1)
+                            # Use the suggested retry_after time instead of exponential backoff for overload
+                            wait_time = max(wait_time, float(retry_after))
+                            logger.warning(f"Inference service overloaded (400). {response_data} Retrying after {wait_time}s...")
+                        else:
+                            # This is a different type of 400 error, don't retry
+                            try:
+                                redacted_headers = {}
+                                try:
+                                    redacted_headers = dict(self.headers)
+                                    if "Authorization" in redacted_headers:
+                                        redacted_headers["Authorization"] = "***REDACTED***"
+                                except Exception:
+                                    redacted_headers = {}
+                                logger.error({
+                                    "non_overload_400": True,
+                                    "target": (base_url or self.base_url),
+                                    "payload": processed_request,
+                                    "headers": redacted_headers,
+                                    "body": e.response.text if e.response is not None else None,
+                                })
+                            except Exception:
+                                pass
+                            raise RuntimeError(
+                                f"Inference 400 response: {e.response.text if e.response is not None else 'Bad Request'}"
+                            ) from e
+                    except Exception:
+                        # If we can't parse the response, don't retry 400 errors
+                        try:
+                            logger.error({
+                                "non_overload_400_unparsed": True,
+                                "target": (base_url or self.base_url),
+                                "payload": processed_request,
+                            })
+                        except Exception:
+                            pass
+                        raise RuntimeError(
+                            f"Inference 400 response (unparsed): {e.response.text if e.response is not None else 'Bad Request'}"
+                        ) from e
+                elif e.response.status_code == 503:
+                    # Avoid referencing undefined response_data
+                    try:
+                        preview = (e.response.text or "")[:200]
+                    except Exception:
+                        preview = ""
+                    logger.warning(
+                        f"Flash returned 503; container may be cold starting. Retrying... body={preview}"
+                    )
+                elif e.response.status_code == 500:
+                    try:
+                        preview = (e.response.text or "")[:200]
+                    except Exception:
+                        preview = ""
+                    logger.warning(
+                        f"Flash returned 500; inference service error. Retrying... body={preview}"
+                    )
+            except httpx.TimeoutException as e:
+                last_error = e
+
+            if attempt < max_retries:
+                logger.warning(
+                    f"Inference request failed (attempt {attempt + 1}/{max_retries + 1}), "
+                    f"retrying in {wait_time}s..."
+                )
+                await asyncio.sleep(wait_time)
+                wait_time *= backoff_factor
+
+        raise last_error
+
+
+def create_inference_client(
+    task_app: Any,
+    api_key: Optional[str] = None,
+) -> OpenAIClient:
+    """
+    Create an inference client using TaskApp configuration.
+
+    Args:
+        task_app: TaskApp instance with vllm_base_url
+        api_key: Optional API key for authentication
+
+    Returns:
+        Configured OpenAIClient instance
+    """
+    # Fallback to environment if caller didn't provide an API key
+    if api_key is None:
+        try:
+            import os as _os  # local import to avoid module-level side effects
+            api_key = _os.getenv("OPENAI_API_KEY") or getattr(task_app, "openai_api_key", None)
+        except Exception:
+            api_key = None
+
+    return OpenAIClient(
+        base_url=task_app.vllm_base_url,
+        api_key=api_key,
+    )
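
A minimal usage sketch for the client added above (not part of the package diff): it assumes the task app directory is on PYTHONPATH so that synth_envs_hosted.inference.openai_client is importable, that a vLLM server exposes an OpenAI-compatible API at http://localhost:8000, and that the model id shown is a placeholder for whatever that server has loaded.

# Hypothetical usage sketch; URL and model id are placeholders, not values from the diff.
import asyncio

from synth_envs_hosted.inference.openai_client import OpenAIClient


async def main() -> None:
    client = OpenAIClient(base_url="http://localhost:8000", timeout_s=60.0)

    # GET {base_url}/health; check_health returns a status dict even on failure.
    print(await client.check_health())

    request = {
        "model": "Qwen/Qwen2.5-7B-Instruct",  # placeholder: whatever the server has loaded
        "messages": [{"role": "user", "content": "Say hello in one word."}],
        "max_tokens": 16,
    }
    # For a non-OpenAI/Groq target URL, _fix_model_parameters passes the payload through unchanged.
    response = await client.generate_with_retries(
        request,
        max_retries=2,        # retries overload 400s, 429, 500, 503, and timeouts
        backoff_factor=2.0,   # waits 1s, then 2s, between attempts
    )
    print(response["choices"][0]["message"])


if __name__ == "__main__":
    asyncio.run(main())

Because the target here is neither openai.com nor groq.com, none of the OpenAI/Groq-specific field stripping or tool-choice forcing applies; only the retry and health-check paths are exercised.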
examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+"""
+Main entry point for the GRPO Synth Envs Hosted Service.
+
+For local development:
+    uvicorn main:app --reload --port 8000
+
+For Modal deployment:
+    modal deploy main.py
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Optional
+
+import modal
+
+# Try to import Modal-specific features
+try:
+    from modal import App, Image, Volume, asgi_app
+
+    MODAL_AVAILABLE = True
+except ImportError:
+    MODAL_AVAILABLE = False
+
+from synth_envs_hosted.hosted_app import create_app
+
+
+# Local development mode
+if __name__ == "__main__":
+    import uvicorn
+
+    # Create the FastAPI app
+    app = create_app()
+
+    # Run with uvicorn
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", "8000")),
+        reload=True,
+    )
+
+# Modal deployment mode
+elif MODAL_AVAILABLE:
+    # Define Modal app
+    modal_app = App("grpo-synth-envs-hosted")
+
+    # Define the container image
+    image = Image.debian_slim().pip_install(
+        "fastapi",
+        "uvicorn[standard]",
+        "httpx",
+        "pydantic",
+        "synth-ai",
+    )
+
+    # Create or get the volume for state storage
+    state_volume = Volume.from_name("synth-env-state", create_if_missing=True)
+
+    # Define the ASGI app function
+    @modal_app.function(
+        image=image,
+        min_containers=1,
+        volumes={"/data/state": state_volume},
+        secrets=[
+            modal.Secret.from_name("vllm-config"),
+        ],
+    )
+    @asgi_app()
+    def fastapi_app():
+        """Modal ASGI app factory."""
+        return create_app()
+
+    # Optional: Add a scheduled cleanup job
+    @modal_app.function(
+        schedule=modal.Period(hours=24),
+        volumes={"/data/state": state_volume},
+    )
+    def cleanup_old_snapshots(max_age_hours: int = 48):
+        """Periodic cleanup of old snapshots."""
+        import shutil
+        from datetime import datetime, timedelta
+        from pathlib import Path
+
+        base_path = Path("/data/state/runs")
+        if not base_path.exists():
+            return
+
+        cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
+
+        for run_dir in base_path.iterdir():
+            if run_dir.is_dir():
+                # Check modification time
+                mtime = datetime.fromtimestamp(run_dir.stat().st_mtime)
+                if mtime < cutoff_time:
+                    print(f"Removing old run directory: {run_dir}")
+                    shutil.rmtree(run_dir)
+
+    # Export for Modal
+    app = fastapi_app
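
For reference, a standalone sketch (not part of the package diff) of the same age-based pruning that cleanup_old_snapshots performs on the Modal volume, pointed at a temporary directory so it can be tried locally; the directory names and the backdated timestamp are illustrative only.

# Hypothetical local sketch mirroring cleanup_old_snapshots: prune run directories whose
# mtime is older than a cutoff, using a temp directory instead of /data/state/runs.
import os
import shutil
import tempfile
import time
from datetime import datetime, timedelta
from pathlib import Path


def prune_old_runs(base_path: Path, max_age_hours: int = 48) -> None:
    """Remove subdirectories of base_path not modified within max_age_hours."""
    if not base_path.exists():
        return
    cutoff_time = datetime.utcnow() - timedelta(hours=max_age_hours)
    for run_dir in base_path.iterdir():
        if run_dir.is_dir():
            mtime = datetime.fromtimestamp(run_dir.stat().st_mtime)
            if mtime < cutoff_time:
                print(f"Removing old run directory: {run_dir}")
                shutil.rmtree(run_dir)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        runs = Path(tmp) / "runs"
        (runs / "stale_run").mkdir(parents=True)
        (runs / "fresh_run").mkdir(parents=True)
        # Backdate the stale directory by three days so it falls past the 48h cutoff.
        old = time.time() - 3 * 24 * 3600
        os.utime(runs / "stale_run", (old, old))
        prune_old_runs(runs, max_age_hours=48)  # removes stale_run, keeps fresh_run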