synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (48)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +7 -1
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
  19. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
  20. examples/warming_up_to_rl/run_eval.py +127 -18
  21. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  22. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  23. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +232 -193
  25. synth_ai/__init__.py +41 -1
  26. synth_ai/api/train/builders.py +49 -19
  27. synth_ai/api/train/configs/__init__.py +44 -0
  28. synth_ai/api/train/configs/rl.py +133 -0
  29. synth_ai/api/train/configs/sft.py +94 -0
  30. synth_ai/api/train/configs/shared.py +24 -0
  31. synth_ai/cli/demo.py +38 -39
  32. synth_ai/cli/rl_demo.py +81 -102
  33. synth_ai/cli/task_apps.py +3 -0
  34. synth_ai/demos/core/cli.py +121 -159
  35. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  36. synth_ai/evals/__init__.py +15 -0
  37. synth_ai/evals/client.py +85 -0
  38. synth_ai/evals/types.py +42 -0
  39. synth_ai/judge_schemas.py +127 -0
  40. synth_ai/rubrics/__init__.py +22 -0
  41. synth_ai/rubrics/validators.py +126 -0
  42. synth_ai/tracing_v3/serialization.py +130 -0
  43. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +1 -1
  44. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +48 -22
  45. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  46. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  47. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  48. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
examples/agora_ex/agora_ex_task_app.py (the same 590-line file also ships at examples/agora_ex/task_app/agora_ex_task_app.py)
@@ -0,0 +1,590 @@
+ """Task app for the Agora EX landing page generation environment."""
+
+ from __future__ import annotations
+
+ import argparse
+ import importlib.util
+ import json
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+ import httpx
+ from fastapi import FastAPI, HTTPException
+ from fastapi.exceptions import RequestValidationError
+ from fastapi.responses import JSONResponse
+ from starlette.requests import Request
+
+ from synth_ai.task.auth import (
+     is_api_key_header_authorized,
+     normalize_environment_api_key,
+ )
+ from synth_ai.task.contracts import (
+     RolloutMetrics,
+     RolloutRequest,
+     RolloutResponse,
+     RolloutStep,
+     RolloutTrajectory,
+     TaskInfo,
+ )
+ from synth_ai.task.server import TaskAppConfig, create_task_app, run_task_app
+
+ try:  # Optional registry integration
+     from synth_ai.task.apps import TaskAppEntry, register_task_app
+ except ImportError:  # pragma: no cover - registry not available in some test harnesses
+     TaskAppEntry = None  # type: ignore[assignment]
+     register_task_app = None  # type: ignore[assignment]
+
+ LOGGER = logging.getLogger("agora_ex.task_app")
+
+ APP_ID = "agora-ex-landing-page"
+ APP_NAME = "Agora EX Landing Page Task App"
+ APP_DESCRIPTION = (
+     "Single-turn Next.js landing page generation task evaluated by the Eames human judge."
+ )
+ DATASET_ID = "agora_ex_prompts_v1"
+ PROMPTS_FILENAME = "user_prompts_CURRENT.jsonl"
+ SYSTEM_PROMPT_FILENAME = "system_prompt_CURRENT.md"
+ DEFAULT_MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
+ DEFAULT_TEMPERATURE = 0.15
+ DEFAULT_MAX_TOKENS = 3072
+ INFERENCE_TIMEOUT_SECONDS = float(os.getenv("AGORA_EX_INFERENCE_TIMEOUT", "120"))
+
+
+ class AgoraPromptDataset:
+     """JSONL-backed prompt dataset for Agora EX."""
+
+     def __init__(self, path: Path) -> None:
+         self._path = path
+         self._prompts = self._load_prompts(path)
+
+     @staticmethod
+     def _load_prompts(path: Path) -> List[str]:
+         if not path.exists():
+             raise FileNotFoundError(f"Prompt file not found: {path}")
+         prompts: List[str] = []
+         with path.open("r", encoding="utf-8") as handle:
+             for line in handle:
+                 stripped = line.strip()
+                 if not stripped:
+                     continue
+                 try:
+                     payload = json.loads(stripped)
+                 except json.JSONDecodeError as exc:
+                     raise ValueError(f"Invalid JSONL line in {path}: {stripped}") from exc
+                 prompt = payload.get("user_prompt")
+                 if isinstance(prompt, str) and prompt.strip():
+                     prompts.append(prompt.strip())
+         if not prompts:
+             raise ValueError(f"No prompts loaded from {path}")
+         return prompts
+
+     def __len__(self) -> int:
+         return len(self._prompts)
+
+     def resolve(self, raw_index: int) -> Tuple[int, str]:
+         total = len(self._prompts)
+         if total == 0:
+             raise RuntimeError("Prompt dataset is empty")
+         index = raw_index % total if raw_index >= 0 else (total + (raw_index % total)) % total
+         return index, self._prompts[index]
+
+     def describe(self) -> Dict[str, Any]:
+         return {
+             "dataset_id": DATASET_ID,
+             "num_prompts": len(self._prompts),
+             "source": str(self._path),
+             "splits": {"train": len(self._prompts)},
+         }
+
+
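Note on indexing: resolve() maps any integer seed onto a valid prompt index by modular wrap-around, so rollout seeds larger than the dataset are safe. A small illustration against a hypothetical three-prompt file (the file name and prompts are invented for the example; they are not part of the package):

    import json
    from pathlib import Path

    # Write a throwaway three-prompt JSONL dataset.
    path = Path("demo_prompts.jsonl")
    path.write_text("\n".join(json.dumps({"user_prompt": p}) for p in ["a", "b", "c"]))

    ds = AgoraPromptDataset(path)
    assert ds.resolve(0) == (0, "a")
    assert ds.resolve(5) == (2, "c")   # 5 % 3 == 2: seeds wrap around
    assert ds.resolve(-1) == (2, "c")  # negative seeds are folded into range
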
+ def _load_reward_module() -> Any:
+     module_path = Path(__file__).with_name("reward_fn_grpo-human.py")
+     if not module_path.exists():
+         raise FileNotFoundError(f"Missing reward module: {module_path}")
+     spec = importlib.util.spec_from_file_location("agora_ex_reward_module", module_path)
+     if spec is None or spec.loader is None:
+         raise ImportError(f"Unable to load reward module from {module_path}")
+     module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(module)
+     return module
+
+
+ REWARD_MODULE = _load_reward_module()
+ REWARD_FN = getattr(REWARD_MODULE, "reward_fn")
+ REWARD_RUN_TYPE = getattr(REWARD_MODULE, "RUN_TYPE", "rl_training")
+ REWARD_RUN_VERSION = getattr(REWARD_MODULE, "RUN_VERSION", 1.0)
+ REWARD_EXPERIMENT = getattr(REWARD_MODULE, "EXPERIMENT_NAME", APP_ID)
+ REWARD_USER_PROMPT_VERSION = getattr(REWARD_MODULE, "USER_PROMPT_VERSION", "current")
+ REWARD_SYSTEM_PROMPT_VERSION = getattr(REWARD_MODULE, "SYSTEM_PROMPT_VERSION", "current")
+
+
+ def _read_system_prompt(path: Path) -> str:
+     if not path.exists():
+         raise FileNotFoundError(f"System prompt file missing: {path}")
+     return path.read_text(encoding="utf-8").strip()
+
+
+ def _coerce_int(value: Any, default: int) -> int:
+     try:
+         return int(value)
+     except (TypeError, ValueError):
+         return default
+
+
+ def _coerce_float(value: Any, default: float) -> float:
+     try:
+         return float(value)
+     except (TypeError, ValueError):
+         return default
+
+
+ def _resolve_inference_url(policy_config: Dict[str, Any]) -> Optional[str]:
+     candidate = policy_config.get("inference_url")
+     if isinstance(candidate, str) and candidate.strip():
+         return candidate.strip()
+     env_fallback = os.getenv("AGORA_EX_INFERENCE_URL")
+     return env_fallback.strip() if env_fallback else None
+
+
+ def _normalize_chat_url(base_url: str) -> str:
+     base = base_url.rstrip("/")
+     if base.endswith("/v1/chat/completions"):
+         return base
+     return f"{base}/v1/chat/completions"
+
+
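For reference, _normalize_chat_url is idempotent: it appends the OpenAI-style chat path only when it is missing. Illustrative behaviour (the host name is invented):

    assert _normalize_chat_url("http://gpu-host:8000") == "http://gpu-host:8000/v1/chat/completions"
    assert _normalize_chat_url("http://gpu-host:8000/") == "http://gpu-host:8000/v1/chat/completions"
    assert _normalize_chat_url("http://gpu-host:8000/v1/chat/completions") == "http://gpu-host:8000/v1/chat/completions"
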
+ def _base_task_info() -> TaskInfo:
+     return TaskInfo(
+         task={
+             "id": APP_ID,
+             "name": "Agora EX Landing Page Generation",
+             "description": (
+                 "Generate a production-ready Next.js landing page that satisfies the Agora EX brief."
+             ),
+         },
+         environments=["default"],
+         observation={
+             "type": "text",
+             "description": "System prompt plus product brief describing the required landing page.",
+         },
+         action_space={
+             "type": "free_text",
+             "description": "Return one TSX file wrapped in a single ```tsx code fence.",
+         },
+         dataset={
+             "id": DATASET_ID,
+             "default_split": "train",
+             "user_prompt_version": REWARD_USER_PROMPT_VERSION,
+         },
+         rubric={
+             "outcome": {"name": "Human Preference Score", "criteria": []},
+             "events": {"name": "Design Compliance", "criteria": []},
+         },
+         inference={"providers": ["vllm", "local"]},
+         capabilities={"tools": []},
+         limits={
+             "max_turns": 1,
+             "max_response_tokens": DEFAULT_MAX_TOKENS,
+         },
+     )
+
+
+ def describe_taskset(dataset: AgoraPromptDataset) -> Dict[str, Any]:
+     return {
+         "task": APP_NAME,
+         "dataset": dataset.describe(),
+         "system_prompt_version": REWARD_SYSTEM_PROMPT_VERSION,
+         "user_prompt_version": REWARD_USER_PROMPT_VERSION,
+     }
+
+
+ def provide_task_instances(dataset: AgoraPromptDataset, seeds: Sequence[int]) -> Iterable[TaskInfo]:
+     base = _base_task_info()
+     for seed in seeds:
+         index, _ = dataset.resolve(seed)
+         yield TaskInfo(
+             task=base.task,
+             environments=base.environments,
+             observation=base.observation,
+             action_space=base.action_space,
+             dataset={**base.dataset, "selected_index": index},
+             rubric=base.rubric,
+             inference=base.inference,
+             capabilities=base.capabilities,
+             limits=base.limits,
+         )
+
+
+ def _invoke_inference(
+     chat_url: str,
+     messages: List[Dict[str, Any]],
+     model: Optional[str],
+     temperature: float,
+     max_tokens: int,
+ ) -> Tuple[Optional[str], Dict[str, Any]]:
+     payload = {
+         "model": model or DEFAULT_MODEL,
+         "messages": messages,
+         "temperature": temperature,
+         "max_tokens": max_tokens,
+     }
+     LOGGER.info(
+         "[AGORA] request inference_url=%s model=%s temperature=%.3f max_tokens=%s",
+         chat_url,
+         payload["model"],
+         temperature,
+         max_tokens,
+     )
+     response = httpx.post(chat_url, json=payload, timeout=INFERENCE_TIMEOUT_SECONDS)
+     info: Dict[str, Any] = {"status_code": response.status_code}
+     if response.status_code != 200:
+         info["error_text"] = response.text[:2000]
+         LOGGER.error(
+             "[AGORA] inference failed status=%s body_preview=%s",
+             response.status_code,
+             info["error_text"],
+         )
+         return None, info
+
+     data = response.json()
+     info["raw_response"] = data
+     choices = data.get("choices") or []
+     primary = choices[0] if choices else {}
+     message = primary.get("message") or {}
+     completion = message.get("content")
+     if isinstance(completion, str):
+         return completion.strip(), info
+     return None, info
+
+
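_invoke_inference only consumes choices[0].message.content from an OpenAI-compatible chat completion; the full body is kept in info["raw_response"] for tracing. A minimal response that parses successfully (illustrative, not a real server reply):

    minimal_response = {
        "choices": [
            {"message": {"role": "assistant", "content": "```tsx\nexport default function Page() { return null }\n```"}}
        ]
    }
    # Any other shape (no choices, missing message, non-string content) makes
    # _invoke_inference return (None, info), and the rollout below scores 0.0.
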
+ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
+     app_state = fastapi_request.app.state
+     dataset: AgoraPromptDataset = app_state.agora_dataset
+     system_prompt: str = app_state.system_prompt
+     system_prompt_version: str = getattr(app_state, "system_prompt_version", REWARD_SYSTEM_PROMPT_VERSION)
+     user_prompt_version: str = getattr(app_state, "user_prompt_version", REWARD_USER_PROMPT_VERSION)
+
+     env_cfg = getattr(request.env, "config", {}) if hasattr(request, "env") else {}
+     env_name = getattr(request.env, "env_name", APP_ID) if hasattr(request, "env") else APP_ID
+     env_index = _coerce_int(env_cfg.get("index"), _coerce_int(getattr(request.env, "seed", 0), 0))
+     resolved_index, user_prompt = dataset.resolve(env_index)
+
+     policy_config = getattr(request.policy, "config", {}) if request.policy else {}
+     policy_config = policy_config or {}
+     policy_model = policy_config.get("model") if isinstance(policy_config, dict) else None
+     policy_name = None
+     if request.policy:
+         policy_name = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
+     policy_id = policy_name or "policy"
+
+     inference_url = _resolve_inference_url(policy_config if isinstance(policy_config, dict) else {})
+     if not inference_url:
+         raise HTTPException(
+             status_code=502,
+             detail="No inference_url provided in policy config",
+         )
+     chat_url = _normalize_chat_url(inference_url)
+     temperature = _coerce_float(policy_config.get("temperature"), DEFAULT_TEMPERATURE)
+     max_tokens = _coerce_int(policy_config.get("max_tokens"), DEFAULT_MAX_TOKENS)
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     completion, inference_info = _invoke_inference(
+         chat_url=chat_url,
+         messages=messages,
+         model=policy_model if isinstance(policy_model, str) else None,
+         temperature=temperature,
+         max_tokens=max_tokens,
+     )
+
+     reward_score = 0.0
+     reward_error: Optional[str] = None
+     reward_metadata = {
+         "prompt_index": resolved_index,
+         "raw_seed": env_index,
+         "policy_id": policy_id,
+         "policy_model": policy_model,
+         "system_prompt_version": system_prompt_version,
+         "user_prompt_version": user_prompt_version,
+         "inference_status": inference_info.get("status_code"),
+         "inference_error": inference_info.get("error_text"),
+     }
+     reward_kwargs = {
+         "run_type": REWARD_RUN_TYPE,
+         "run_version": REWARD_RUN_VERSION,
+         "experiment": REWARD_EXPERIMENT,
+         "model": request.model or policy_model or DEFAULT_MODEL,
+         "user_prompt": user_prompt_version,
+         "system_prompt": system_prompt_version,
+         "metadata": reward_metadata,
+     }
+     if completion:
+         try:
+             reward_score = float(REWARD_FN(completion, **reward_kwargs))
+         except Exception as exc:  # pragma: no cover - reward service failure
+             reward_error = str(exc)
+             LOGGER.exception("Reward evaluation failed: %s", exc)
+             reward_score = 0.0
+     else:
+         reward_error = "empty_completion"
+
+     obs_payload = {
+         "prompt": user_prompt,
+         "prompt_index": resolved_index,
+         "system_prompt_version": system_prompt_version,
+         "user_prompt_version": user_prompt_version,
+         "env_name": env_name,
+     }
+
+     info_payload: Dict[str, Any] = {
+         "completion_preview": completion[:400] if completion else "",
+         "inference": inference_info,
+         "reward_score": reward_score,
+     }
+     if reward_error:
+         info_payload["reward_error"] = reward_error
+
+     step = RolloutStep(
+         obs=obs_payload,
+         tool_calls=[],
+         reward=reward_score,
+         done=True,
+         truncated=False,
+         info=info_payload,
+     )
+
+     final_info = {
+         "score": reward_score,
+         "reward_error": reward_error,
+         "prompt_index": resolved_index,
+         "policy_id": policy_id,
+         "model": request.model or policy_model or DEFAULT_MODEL,
+     }
+     final_observation = {
+         "completion": completion,
+         "prompt": user_prompt,
+         "system_prompt_version": system_prompt_version,
+         "prompt_index": resolved_index,
+         "env_name": env_name,
+     }
+
+     metrics = RolloutMetrics(
+         episode_returns=[reward_score],
+         mean_return=reward_score,
+         num_steps=1,
+         num_episodes=1,
+         outcome_score=reward_score,
+         events_score=None,
+         details={
+             "prompt_index": resolved_index,
+             "policy_id": policy_id,
+             "inference_status": inference_info.get("status_code"),
+             "env_name": env_name,
+         },
+     )
+
+     trajectory = RolloutTrajectory(
+         env_id=str(env_name),
+         policy_id=str(policy_id),
+         steps=[step],
+         final={
+             "observation": final_observation,
+             "reward": reward_score,
+             "done": True,
+             "truncated": False,
+             "info": final_info,
+         },
+         length=1,
+     )
+
+     trace_payload = {
+         "messages": messages,
+         "completion": completion,
+         "reward_score": reward_score,
+         "prompt_index": resolved_index,
+         "policy_id": policy_id,
+         "inference": inference_info,
+         "env_name": env_name,
+     }
+
+     return RolloutResponse(
+         run_id=str(getattr(request, "run_id", "run")),
+         trajectories=[trajectory],
+         metrics=metrics,
+         branches={},
+         aborted=False,
+         ops_executed=0,
+         trace=trace_payload,
+     )
+
+
+ def build_config() -> TaskAppConfig:
+     module_dir = Path(__file__).resolve().parent
+     dataset = AgoraPromptDataset(module_dir / PROMPTS_FILENAME)
+     system_prompt = _read_system_prompt(module_dir / SYSTEM_PROMPT_FILENAME)
+
+     base_info = _base_task_info()
+     app_state: Dict[str, Any] = {
+         "agora_dataset": dataset,
+         "system_prompt": system_prompt,
+         "system_prompt_version": REWARD_SYSTEM_PROMPT_VERSION,
+         "user_prompt_version": REWARD_USER_PROMPT_VERSION,
+     }
+
+     return TaskAppConfig(
+         app_id=APP_ID,
+         name=APP_NAME,
+         description=APP_DESCRIPTION,
+         base_task_info=base_info,
+         describe_taskset=lambda: describe_taskset(dataset),
+         provide_task_instances=lambda seeds: list(provide_task_instances(dataset, seeds)),
+         rollout=rollout_executor,
+         dataset_registry=None,
+         rubrics=None,
+         proxy=None,
+         routers=(),
+         app_state=app_state,
+         cors_origins=["*"],
+     )
+
+
+ def fastapi_app() -> FastAPI:
+     app = create_task_app(build_config())
+
+     filtered_routes = []
+     for route in app.router.routes:
+         path = getattr(route, "path", None)
+         methods = getattr(route, "methods", set()) or set()
+         if path in {"/health", "/health/rollout"} and "GET" in methods:
+             continue
+         filtered_routes.append(route)
+     app.router.routes = filtered_routes
+
+     def _log_env_key_prefix(source: str, env_key: Optional[str]) -> Optional[str]:
+         if not env_key:
+             return None
+         prefix = env_key[: max(1, len(env_key) // 2)]
+         LOGGER.info("[%s] expected ENVIRONMENT_API_KEY prefix: %s", source, prefix)
+         return prefix
+
+     @app.get("/health")
+     async def health(request: Request) -> JSONResponse | Dict[str, Any]:
+         env_key = normalize_environment_api_key()
+         if not env_key:
+             return JSONResponse(
+                 status_code=503,
+                 content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
+             )
+         if not is_api_key_header_authorized(request):
+             prefix = _log_env_key_prefix("health", env_key)
+             content: Dict[str, Any] = {"status": "healthy", "authorized": False}
+             if prefix:
+                 content["expected_api_key_prefix"] = prefix
+             return JSONResponse(status_code=200, content=content)
+         return {"status": "healthy", "authorized": True}
+
+     @app.get("/health/rollout")
+     async def health_rollout(request: Request) -> JSONResponse | Dict[str, Any]:
+         env_key = normalize_environment_api_key()
+         if not env_key:
+             return JSONResponse(
+                 status_code=503,
+                 content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
+             )
+         if not is_api_key_header_authorized(request):
+             prefix = _log_env_key_prefix("health/rollout", env_key)
+             content: Dict[str, Any] = {"status": "healthy", "authorized": False}
+             if prefix:
+                 content["expected_api_key_prefix"] = prefix
+             return JSONResponse(status_code=200, content=content)
+         return {"ok": True, "authorized": True}
+
+     @app.exception_handler(RequestValidationError)
+     async def _on_validation_error(request: Request, exc: RequestValidationError) -> JSONResponse:
+         snapshot = {
+             "path": str(getattr(request, "url").path),
+             "have_x_api_key": bool(request.headers.get("x-api-key")),
+             "have_authorization": bool(request.headers.get("authorization")),
+             "errors": exc.errors()[:5],
+         }
+         LOGGER.warning("[422] validation error %s", snapshot)
+         return JSONResponse(status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]})
+
+     return app
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Run the Agora EX task app locally")
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=8101)
+     parser.add_argument("--reload", action="store_true")
+     parser.add_argument("--env-file", action="append", default=[])
+     args = parser.parse_args()
+
+     module_dir = Path(__file__).resolve().parent
+     default_env = module_dir / ".env"
+     env_files = [str(default_env)] if default_env.exists() else []
+     env_files.extend(args.env_file or [])
+
+     run_task_app(
+         build_config,
+         host=args.host,
+         port=args.port,
+         reload=args.reload,
+         env_files=env_files,
+     )
+
+
+ if register_task_app and TaskAppEntry:
+     try:
+         # Import ModalDeploymentConfig
+         from synth_ai.task.apps import ModalDeploymentConfig
+
+         # Resolve repo root for Modal mounts
+         _HERE = Path(__file__).resolve().parent
+         _REPO_ROOT = _HERE.parent.parent  # examples/agora_ex -> synth-ai
+
+         register_task_app(
+             entry=TaskAppEntry(
+                 app_id="agora-ex",  # Use string literal for AST discovery
+                 description=APP_DESCRIPTION,
+                 config_factory=build_config,
+                 aliases=("agora-ex", "agora-ex-landing-page", APP_ID),
+                 modal=ModalDeploymentConfig(
+                     app_name="agora-ex-task-app",
+                     python_version="3.11",
+                     pip_packages=(
+                         "fastapi>=0.100.0",
+                         "uvicorn>=0.23.0",
+                         "pydantic>=2.0.0",
+                         "httpx>=0.24.0",
+                         "python-dotenv>=1.0.1",
+                         # Tracing/DB runtime deps
+                         "sqlalchemy>=2.0.42",
+                         "aiosqlite>=0.21.0",
+                         "greenlet>=3.2.3",
+                     ),
+                     extra_local_dirs=(
+                         # Mount repo root so local modules resolve when deployed on Modal
+                         (str(_REPO_ROOT), "/opt/synth_ai_repo"),
+                         (str(_REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
+                         (str(_HERE), "/opt/synth_ai_repo/examples/agora_ex"),
+                     ),
+                     secret_names=("groq-api-key", "openai-api-key"),
+                     memory=8192,  # 8GB memory
+                     cpu=2.0,  # 2 CPUs
+                     max_containers=10,
+                 ),
+             )
+         )
+     except Exception as exc:  # pragma: no cover - registry optional
+         LOGGER.warning("Failed to register Agora EX task app: %s", exc)
+
+
+ if __name__ == "__main__":
+     main()
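Taken together, main() serves the app on 0.0.0.0:8101 by default. A hypothetical smoke test against the /health behaviour above (assumes the server is running and ENVIRONMENT_API_KEY is exported; the x-api-key header name matches what the validation handler inspects):

    import os

    import httpx

    headers = {"x-api-key": os.environ["ENVIRONMENT_API_KEY"]}
    resp = httpx.get("http://localhost:8101/health", headers=headers, timeout=10)
    print(resp.status_code, resp.json())
    # Expect 200 with {"status": "healthy", "authorized": true}; without the
    # header the route still returns 200 but with "authorized": false.
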
examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml
@@ -0,0 +1,121 @@
+ # Agora EX RL Training - Qwen3 MoE (30B-A3B) on 2xH200
+ # Small MoE model: 30B total parameters, 3B activated per token
+ # Task app provides rewards via human judge (no external judge service needed)
+
+ [algorithm]
+ type = "online"
+ method = "policy_gradient"
+ variety = "gspo"
+
+ [services]
+ # Task app URL - set via environment or CLI flag
+ # The task app includes the reward function, so no separate judge service needed
+ task_url = "http://localhost:8101"  # Local task app or Modal deployment
+
+ [model]
+ base = "Qwen/Qwen3-30B-A3B"
+ trainer_mode = "lora"
+ label = "agora-ex-qwen3-moe-rl"
+
+ [lora]
+ r = 16
+ alpha = 32
+ dropout = 0.05
+ target_modules = ["all-linear"]  # MoE benefits from wider LoRA coverage
+
+ [policy]
+ # inference_url is auto-configured by the RL orchestrator
+ max_tokens = 3072
+ temperature = 0.15
+ # system_hint is included in the task app's system prompt
+
+ [data]
+ split = "train"
+ dataset_id = "agora_ex_prompts_v1"
+ seed_start = 0
+ episodes_per_iteration = 32  # 16 episodes × 2 batches
+ evaluation_split = "train"
+ evaluation_episodes = 16
+
+ [training]
+ num_epochs = 3
+ gradient_accumulation_steps = 8
+ max_accumulated_minibatch = 16
+ max_turns = 1
+ batch_size = 2
+ group_size = 4  # GSPO group size for advantage computation
+ learning_rate = 3e-5
+ log_interval = 1
+ weight_sync_interval = 1
+ iterations_per_epoch = 4
+ weight_sync_verify_checksums = false
+ warmup_steps = 10
+
+ [training.weight_sync]
+ enable = true
+ targets = ["policy"]
+ mode = "full"  # Full weight sync for LoRA
+ direct = true
+ verify_every_k = 0
+ chunk_bytes = 0
+
+ [compute]
+ gpu_type = "H200"
+ gpu_count = 2
+
+ [topology]
+ type = "single_node_split"
+ gpus_for_vllm = 1  # Inference on GPU 0
+ gpus_for_training = 1  # Training on GPU 1
+ gpus_for_ref = 0  # No reference model
+ tensor_parallel = 1
+
+ [vllm]
+ tensor_parallel_size = 1
+ max_model_len = 4096
+
+ [reference]
+ placement = "none"
+
+ [rollout]
+ env_name = "agora-ex-landing-page"
+ policy_name = "agora-ex-moe-policy"
+ max_turns = 1
+ episodes_per_batch = 16  # 16 episodes per batch
+ max_concurrent_rollouts = 4  # Conservative: human judge takes 5-30 min
+ batches_per_step = 2  # 32 episodes per training step
+ ops = ["agent", "env"]
+
+ [rollout.env_config]
+ # No special env config needed for single-turn generation
+
+ [rollout.policy_config]
+ temperature = 0.15
+ max_tokens = 3072
+
+ [evaluation]
+ instances = 16
+ every_n_iters = 2  # More frequent due to slow judge
+ seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+
+ [tags]
+ experiment = "agora_ex_qwen3_moe_lora_v1"
+ owner = "agora"
+ model_type = "moe"
+
+ [checkpoint]
+ interval = 25
+ directory = "/checkpoints"
+ keep_last_n = 3
+ save_optimizer = true
+ save_scheduler = true
+ enabled = true
+
+ [telemetry]
+ supabase = true
+ device_snapshots = false
+ perf_metrics = true
+ weight_sync = false
+ train_step_interval = 3
+ clickhouse_batch_mode = "rollup"
+
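As a sanity check on the rollout arithmetic encoded above (values copied from this TOML; the snippet is illustrative, not a config loader):

    episodes_per_batch = 16
    batches_per_step = 2
    max_concurrent_rollouts = 4

    episodes_per_step = episodes_per_batch * batches_per_step
    assert episodes_per_step == 32  # matches [data].episodes_per_iteration

    # With a human judge taking 5-30 minutes per episode and only 4 rollouts
    # in flight, each training step needs about 8 sequential judge "waves".
    print(episodes_per_step / max_concurrent_rollouts)  # 8.0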