synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic.

Files changed (48)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +7 -1
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
  19. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
  20. examples/warming_up_to_rl/run_eval.py +127 -18
  21. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  22. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  23. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +232 -193
  25. synth_ai/__init__.py +41 -1
  26. synth_ai/api/train/builders.py +49 -19
  27. synth_ai/api/train/configs/__init__.py +44 -0
  28. synth_ai/api/train/configs/rl.py +133 -0
  29. synth_ai/api/train/configs/sft.py +94 -0
  30. synth_ai/api/train/configs/shared.py +24 -0
  31. synth_ai/cli/demo.py +38 -39
  32. synth_ai/cli/rl_demo.py +81 -102
  33. synth_ai/cli/task_apps.py +3 -0
  34. synth_ai/demos/core/cli.py +121 -159
  35. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  36. synth_ai/evals/__init__.py +15 -0
  37. synth_ai/evals/client.py +85 -0
  38. synth_ai/evals/types.py +42 -0
  39. synth_ai/judge_schemas.py +127 -0
  40. synth_ai/rubrics/__init__.py +22 -0
  41. synth_ai/rubrics/validators.py +126 -0
  42. synth_ai/tracing_v3/serialization.py +130 -0
  43. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +1 -1
  44. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +48 -22
  45. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  46. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  47. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  48. {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
examples/multi_step/sse_metrics_streaming_notes.md
@@ -0,0 +1,357 @@
+ # SSE metrics/event streaming design (RL + FT)
+
+ ## Goals
+ - Near real-time push of job status, metrics, and logs during RL training, evaluation, and fine-tuning (FT)
+ - Single streaming endpoint per job, resumable (Last-Event-ID), low overhead, widely compatible (HTTP/1.1)
+ - Minimal client friction (CLI + Python helper), production-ready (auth, backpressure, rate limit)
+
+ ## Non-goals
+ - Binary/frame multiplexing (use WebSocket if needed later)
+ - Arbitrary high-frequency payloads (we will coalesce/limit ~2–4 Hz for metrics)
+
+ ---
+
+ ## Endpoint
+ - Method: GET `/rl/jobs/{job_id}/stream`
+ - Headers:
+   - Auth: `X-API-Key: <env key>` or `Authorization: Bearer <token>`
+   - Cache: `Cache-Control: no-cache`
+ - Response:
+   - Content-Type: `text/event-stream; charset=utf-8`
+   - Transfer-Encoding: `chunked`
+   - Connection: `keep-alive`
+ - Query params (optional):
+   - `since_id`: int; resume from a specific event id (inclusive)
+   - `types`: comma list `metric,status,log,artifact` (default: all)
+   - `heartbeat`: seconds between heartbeats (default 20)
+   - `split`: `train|eval` (filters metrics only)
+
+ ## Event framing (SSE)
+ - Fields per message:
+   - `id: <int>` monotonically increasing per job
+   - `event: <status|metric|log|artifact|heartbeat>`
+   - `data: <JSON>` single-line JSON (compact)
+ - Heartbeats: comment lines `: keep-alive` at configured interval
+ - Flush: after each event write + heartbeat
+ - Backpressure: if producer > consumer, coalesce metrics, keep status/logs, never buffer unbounded
+
+ ## Payload schemas
+ - `status`
+   - `{ "state": "queued|running|succeeded|failed|canceled", "step": 123, "epoch": 3, "phase": "train|eval|ft", "message": "...", "ts": 173.12 }`
+ - `metric`
+   - `{ "name": "avg_reward|loss|accuracy|success_rate|return", "value": 0.123, "step": 123, "epoch": 3, "split": "train|eval", "window": 100, "mean": 0.42, "std": 0.08, "ts": 173.12 }`
+   - Optional extras: `{ "tags": {"env": "crafter", "policy": "react"} }`
+ - `log`
+   - `{ "level": "INFO|WARN|ERROR", "message": "...", "ts": 173.12 }`
+ - `artifact`
+   - `{ "kind": "checkpoint|trace|plot|jsonl", "url": "/rl/jobs/{id}/artifacts/ckpt_0003.pt", "step": 123, "ts": 173.12 }`
+ - `heartbeat`
+   - `{ "alive": true, "ts": 173.12 }`
+
+ ### Example stream (illustrative)
+ ```
+ id: 101
+ event: status
+ data: {"state":"running","phase":"train","step":820,"epoch":4,"ts":173.12}
+
+ id: 102
+ event: metric
+ data: {"name":"avg_reward","value":0.62,"step":820,"epoch":4,"split":"train","ts":173.13}
+
+ id: 103
+ event: metric
+ data: {"name":"loss","value":1.84,"step":820,"epoch":4,"split":"train","window":100,"mean":1.90,"std":0.15,"ts":173.13}
+
+ id: 104
+ event: log
+ data: {"level":"INFO","message":"checkpoint saved","ts":173.16}
+
+ id: 105
+ event: artifact
+ data: {"kind":"checkpoint","url":"/rl/jobs/j_abc/artifacts/ckpt_0004.pt","step":820,"ts":173.16}
+
+ : keep-alive
+ ```
+
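The framing rules above (monotonic `id`, a named `event`, compact single-line JSON in `data`, and `: keep-alive` comment lines) are easy to encode in a small helper. A minimal sketch using only the standard library; the function names are illustrative, not part of synth-ai:

```python
import json
from typing import Any


def format_sse(event_id: int, event: str, data: dict[str, Any]) -> str:
    """Render one SSE frame: id/event lines, compact JSON data, blank-line terminator."""
    payload = json.dumps(data, separators=(",", ":"))
    return f"id: {event_id}\nevent: {event}\ndata: {payload}\n\n"


def format_heartbeat() -> str:
    """SSE comment used as a keep-alive; clients ignore lines starting with ':'."""
    return ": keep-alive\n\n"


# Reproduces the shape of the first frame in the illustrative stream above.
print(format_sse(101, "status", {"state": "running", "phase": "train", "step": 820, "epoch": 4, "ts": 173.12}), end="")
```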
+ ---
+
+ ## Server architecture
+
+ ### Components
+ - Event bus per `job_id` (async queue) where producers (RL, FT, evaluator) `emit(Event)`
+ - Ring buffer per job for replay (configurable: last N events OR last T minutes)
+ - SSE handler:
+   1) Authenticate, pick job, determine resume cursor (`since_id` or `Last-Event-ID`)
+   2) Replay from ring buffer >= cursor
+   3) Attach to live queue; stream new events
+   4) Emit heartbeats; close after terminal `status` + grace
+
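A minimal sketch of the handler flow above as a FastAPI route, assuming an in-memory bus keyed by `job_id`; the names `RING_BUFFERS` and `LIVE_QUEUES` are illustrative placeholders, not the actual backend state:

```python
import asyncio
import json
from collections import deque

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

# Illustrative in-memory state; a real service would back these with the job event store.
RING_BUFFERS: dict[str, deque] = {}               # job_id -> deque of (event_id, event_type, data)
LIVE_QUEUES: dict[str, list[asyncio.Queue]] = {}  # job_id -> attached subscriber queues


@app.get("/rl/jobs/{job_id}/stream")
async def stream_job(job_id: str, request: Request, since_id: int | None = None, heartbeat: float = 20.0):
    # 1) Determine the resume cursor from since_id or Last-Event-ID.
    cursor = since_id if since_id is not None else int(request.headers.get("Last-Event-ID", 0) or 0)
    queue: asyncio.Queue = asyncio.Queue(maxsize=1024)
    LIVE_QUEUES.setdefault(job_id, []).append(queue)

    async def gen():
        try:
            # 2) Replay buffered events at or after the cursor.
            for event_id, event_type, data in list(RING_BUFFERS.get(job_id, [])):
                if event_id >= cursor:
                    yield f"id: {event_id}\nevent: {event_type}\ndata: {json.dumps(data)}\n\n"
            # 3) Live tail, with 4) periodic heartbeats while idle.
            while not await request.is_disconnected():
                try:
                    event_id, event_type, data = await asyncio.wait_for(queue.get(), timeout=heartbeat)
                    yield f"id: {event_id}\nevent: {event_type}\ndata: {json.dumps(data)}\n\n"
                except asyncio.TimeoutError:
                    yield ": keep-alive\n\n"
        finally:
            LIVE_QUEUES[job_id].remove(queue)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
```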
+ ### Concurrency & ordering
+ - Single writer increments `event_id`
+ - Replay preserves original order; live continues from last id
+ - If consumer slow: drop/coalesce metrics (preserve last per metric name), always deliver status/log/artifact
+
+ ### Rate limiting & coalescing
+ - Default target 2–4 Hz for metrics per split
+ - Coalesce by metric name within a small interval (e.g., 250–500 ms)
+ - Status events limited to phase changes or every 5–10s
+
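One way to implement the coalescing rule above (keep only the latest value per metric name inside a short window, never drop status/log/artifact) is a tiny buffer keyed by metric name. A sketch, assuming events are plain dicts shaped like the payload schemas earlier; this is not the backend's actual implementation:

```python
import time
from typing import Any


class MetricCoalescer:
    """Holds at most one pending value per metric name and flushes at a fixed interval."""

    def __init__(self, interval_s: float = 0.25):
        self.interval_s = interval_s
        self._pending: dict[str, dict[str, Any]] = {}
        self._last_flush = time.monotonic()

    def offer(self, event: dict[str, Any]) -> list[dict[str, Any]]:
        """Pass through non-metric events; coalesce metrics by name until the window elapses."""
        if event.get("event") != "metric":
            return [event]  # status/log/artifact are never coalesced
        self._pending[event["data"]["name"]] = event
        now = time.monotonic()
        if now - self._last_flush >= self.interval_s:
            self._last_flush = now
            out = list(self._pending.values())
            self._pending.clear()
            return out
        return []
```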
+ ### Auth & security
+ - Accept `X-API-Key` or `Authorization: Bearer`
+ - Validate job ownership/visibility
+ - CORS: allow EventSource; set `Access-Control-Allow-Origin` appropriately
+ - Timeouts: server idle timeout > heartbeat * 2; client reconnect on drop
+
+ ### Config knobs (env)
+ - `SSE_HEARTBEAT_SECS` (default 20)
+ - `SSE_RING_BUFFER_EVENTS` (e.g., 2000) OR `SSE_RING_BUFFER_WINDOW_SECS` (e.g., 600)
+ - `SSE_MAX_METRIC_HZ` (e.g., 4)
+ - `SSE_MAX_CLIENTS_PER_JOB` (protect from fan-out)
+
+ ---
+
+ ## Emit points
+
+ ### RL training/eval
+ - On train step end: `metric` avg_reward/return/success_rate; `status` every N steps
+ - On eval step end: `metric` eval_return/success_rate; `artifact` eval JSONL optional
+ - On checkpoint: `artifact` + `log`
+ - On phase transitions: `status` (train→eval, etc.)
+
+ ### Fine-tuning (FT)
+ - On optimizer step: `metric` loss (and optional lr)
+ - On validation: `metric` val_loss/accuracy; optional `artifact` (curves)
+ - On checkpoint: `artifact` + `log`
+
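The emit points above all reduce to appending one typed event per step. A sketch of what a shared wrapper could look like, assuming an `append_event(job_id=..., type_=..., message=..., data=...)` coroutine like the one referenced in the backend notes later in this document; the `emit_metric` name and the train/eval type switch are illustrative only:

```python
import time
from typing import Any, Awaitable, Callable

AppendEvent = Callable[..., Awaitable[Any]]


async def emit_metric(append_event: AppendEvent, job_id: str, *, name: str, value: float,
                      step: int, epoch: int, split: str = "train", **extra: Any) -> None:
    """Append one metric event; the SSE layer maps *.metric event types to `event: metric`."""
    await append_event(
        job_id=job_id,
        type_="rl.step.metric" if split == "train" else "eval.metric",
        message=f"{name}={value:.4f}",
        data={"name": name, "value": value, "step": step, "epoch": epoch,
              "split": split, "ts": time.time(), **extra},
    )


# Usage at the end of a train step (FT loops would use type_="ft.step.metric" instead):
#   await emit_metric(append_event, job_id, name="avg_reward", value=avg_reward, step=step, epoch=epoch)
```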
+ ---
+
+ ## Client (synth-ai)
+
+ ### CLI
+ - `synth-ai jobs stream <job_id> [--jsonl out.jsonl] [--types metric,status]`
+ - Prints compact lines: `t=18:22:40 step=820 avg_reward=0.62 loss=1.84`
+ - Writes raw events to JSONL if specified
+
+ ### Python helper
+ ```python
+ from synth_ai.client import TaskAppClient
+
+ with TaskAppClient(base_url, api_key) as c:
+     for ev in c.stream_job(job_id, types=["metric","status"], since_id=None):
+         handle(ev)
+ ```
+ - Handles reconnect with `Last-Event-ID`
+ - Dedupes by event id; optional local aggregation windows
+
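The reconnect and dedupe behaviour described above can be sketched directly with httpx; `stream_job` in the snippet above is the planned client API, while this shows the underlying loop. It assumes only the endpoint and framing defined earlier in this document:

```python
import asyncio
import json

import httpx


async def read_job_events(base_url: str, api_key: str, job_id: str):
    """Yield parsed SSE events, resuming with Last-Event-ID after any disconnect."""
    last_id = 0
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        while True:
            headers = {"X-API-Key": api_key}
            if last_id:
                headers["Last-Event-ID"] = str(last_id)
            try:
                async with client.stream("GET", f"/rl/jobs/{job_id}/stream", headers=headers) as resp:
                    event_id, event_type = None, "message"
                    async for line in resp.aiter_lines():
                        if line.startswith(":"):
                            continue  # heartbeat comment
                        if line.startswith("id:"):
                            event_id = int(line[3:].strip())
                        elif line.startswith("event:"):
                            event_type = line[6:].strip()
                        elif line.startswith("data:"):
                            data = json.loads(line[5:].strip())
                            if event_id is not None and event_id > last_id:  # dedupe by id
                                last_id = event_id
                                yield {"id": event_id, "event": event_type, "data": data}
                        elif not line:
                            event_id, event_type = None, "message"  # frame boundary
            except httpx.HTTPError:
                await asyncio.sleep(2.0)  # simple pause before reconnecting
```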
+ ---
+
+ ## Error handling & edge cases
+ - 404 (job not found): immediate error, no stream
+ - 401/403: immediate error, no stream
+ - Producer stalls: heartbeats continue; client shows "connected: no new data"
+ - Ring buffer miss (client too far behind): send a summary `status` snapshot, then live only
+ - Buffer overflow: set `dropped_events: true` field in a `status` update
+
+ ---
+
+ ## Testing plan
+ - Unit: SSE formatting, ring buffer indexing, resume semantics, coalescing
+ - Integration: run toy RL + FT loops; assert ordered sequence and types
+ - Load: simulate 100 Hz metric emissions; verify rate limiting and non-blocking behavior
+ - Auth: invalid key, missing key, cross-job access rejection
+
+ ---
+
+ ## Rollout plan
+ 1) Implement event bus + ring buffer per job
+ 2) Add SSE route with auth, resume, heartbeat
+ 3) Emitters in RL/FT loops (train/eval steps, loss/reward, checkpoints)
+ 4) synth-ai client: CLI + Python helper; optional JSONL sink
+ 5) Docs and example snippets (curl, Python, browser EventSource)
+
+ ---
+
+ ## Open questions / decisions
+ - Do we persist events beyond ring buffer? (optional artifact JSONL)
+ - Multi-tenant limits per org/user? (SSE_MAX_CLIENTS_PER_ORG?)
+ - Artifact URLs: signed vs. authenticated direct
+ - Backward compatibility with current polling endpoints
+ I'll scan the backend to find where RL jobs are created, status is served, and whether there's any existing streaming/event infrastructure to extend. Then I'll identify exact files/functions to change for adding an SSE endpoint and metric emitters.
+
+ [2 tools called]
+
+ Here’s the exact backend surface you’ll need to touch to add a single SSE stream for status + metrics, with where to hook metric emitters.
+
+ ### Where jobs and events live
+ - Job and event tables (legacy RL path):
+ ```1:20:monorepo/backend/app/db/models/jobs.py
+ """Job orchestration models (RL/SFT shared).
+ """
+ ```
+ ```31:58:monorepo/backend/app/db/models/jobs.py
+ class Job(Base):
+     __tablename__ = "jobs"
+     job_id = Column(String, primary_key=True)
+     ...
+     status = Column(String, nullable=False, default="QUEUED")
+     ...
+     last_seq = Column(BigInteger, nullable=True, default=0)
+     events = relationship("JobEvent", back_populates="job", cascade="all, delete-orphan", passive_deletes=True)
+ ```
+ ```66:88:monorepo/backend/app/db/models/jobs.py
+ class JobEvent(Base):
+     """Append-only job event log."""
+     __tablename__ = "job_events"
+     job_id = Column(String, ForeignKey("jobs.job_id", ondelete="CASCADE"), primary_key=True)
+     seq = Column(BigInteger, primary_key=True)
+     ts = Column(DateTime(timezone=True), server_default=func.now())
+     type = Column(String, nullable=False)
+     level = Column(String, nullable=False, default="info")
+     message = Column(Text, nullable=False)
+     data = Column(JSONB, nullable=True)
+ ```
+
+ - DB repo (atomic seq increment + append):
+ ```108:146:monorepo/backend/app/orchestration/jobs/repository_db.py
+ async def append_event(...):
+     res = await self.session.execute(
+         update(LearningJob)
+         .where(LearningJob.job_id == job_id)
+         .values(last_seq=func.coalesce(LearningJob.last_seq, 0) + 1, updated_at=func.now())
+         .returning(LearningJob.last_seq)
+     )
+     seq = res.scalar_one()
+     ev = JobEvent(job_id=job_id, seq=seq, type=type_, level=level, message=message, data=(data or {}))
+     ...
+ ```
+
+ - RL jobs REST (create + emit first events via PostgREST emitter):
+ ```214:233:monorepo/backend/app/routes/clustered_training/core/routes.py
+ @router.post("/jobs", ...)
+ async def create_job(...):
+     ...
+ ```
+ ```784:799:monorepo/backend/app/routes/clustered_training/core/routes.py
+ await get_postgrest_emitter().append_event(job_id=job_id, type_="rl.job.created", message="RL job created", data={"work": payload.work})
+ ```
+
+ - Shared storage path showing how “update” and “append event” are broadcast via the PostgREST emitter:
+ ```242:271:monorepo/backend/app/routes/simple_training/services/job_service.py
+ async def update_job(...): ...
+ async def append_job_event(...): ...
+ ```
+
+ What to add/change (SSE + emitters)
+
+ 1) Add SSE endpoint under RL jobs
+ - File: `monorepo/backend/app/routes/clustered_training/core/routes.py`
+ - New route: `GET /api/rl/jobs/{job_id}/stream`
+ - Auth: `ValidatedAPIKey`
+ - Headers: honor `Last-Event-ID`; Query: `since_id`, `types`, `heartbeat`
+ - Flow:
+   - Validate job exists (use the DB-backed repo in production mode)
+   - Determine start seq (from Last-Event-ID or since_id, else 0)
+   - Replay: call repo `list_events(job_id, after=seq)` and stream as SSE (`id`, `event`, `data`)
+   - Live tail: loop with short sleep (e.g., 0.5–1.0s) fetching new events by `after=last_seq`
+   - Heartbeats as comments `: keep-alive` every N seconds
+ - Event mapping:
+   - `JobEvent.type` prefixes map to SSE `event`:
+     - `rl.job.*`, `job.updated` → `status`
+     - `rl.step.metric`, `ft.step.metric`, `eval.metric` → `metric`
+     - `system.log.*` → `log`
+     - `artifact.*` → `artifact`
+   - `Job.status` can be snapshotted once at connect (send a `status`)
+
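A sketch of the flow and mapping just described as an async generator, assuming a repository object with an async `list_events(job_id, after=...)` method returning rows with `seq`, `type`, and `data` attributes (the real repo lives in `repository_db.py`); the mapping helper and names here are illustrative:

```python
import asyncio
import json


def sse_event_name(job_event_type: str) -> str:
    """Map JobEvent.type prefixes onto the SSE event names used by the stream."""
    if job_event_type.endswith(".metric"):
        return "metric"
    if job_event_type.startswith("artifact."):
        return "artifact"
    if job_event_type.startswith("system.log"):
        return "log"
    return "status"  # rl.job.*, job.updated, and other lifecycle events


async def job_event_stream(repo, job_id: str, start_seq: int, heartbeat_s: float = 20.0):
    """Replay events after start_seq, then poll for new rows and interleave heartbeats."""
    last_seq = start_seq
    idle = 0.0
    while True:
        rows = await repo.list_events(job_id, after=last_seq)
        for row in rows:
            last_seq = row.seq
            idle = 0.0
            yield f"id: {row.seq}\nevent: {sse_event_name(row.type)}\ndata: {json.dumps(row.data or {})}\n\n"
        await asyncio.sleep(0.5)  # live-tail polling interval from the flow above
        idle += 0.5
        if idle >= heartbeat_s:
            idle = 0.0
            yield ": keep-alive\n\n"
```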
+ 2) Ensure a consistent event source for reads
+ - Prefer the DB repo (`JobsRepositoryDB`) in prod mode. If the current code path uses the PostgREST emitter for appends, verify that the repo’s `list_events` reads from the same canonical table (it does for `JobEvent`). If your RL path uses the “learning_shared” models instead, use the associated repository there (same pattern: list by job_id + seq).
+ - If you must keep PostgREST for append-only, that’s fine; SSE can still read the DB rows inserted alongside (your outbox/emitter already supports both).
+
+ 3) Emitters in training/FT loops
+ - File(s): `monorepo/backend/app/orchestration/hatchet/workflows.py` (RL workflow nodes), any FT job loops
+ - After each meaningful step:
+   - Train: append `type="rl.step.metric"`, `data={"avg_reward":..., "return":..., "success_rate":..., "step":..., "epoch":..., "split":"train"}`.
+   - Eval: `type="eval.metric"` with eval metrics and split.
+   - FT: `type="ft.step.metric"`, `data={"loss":..., "lr":..., "step":..., "epoch":..., "split":"train"}`; validation as `split="eval"`.
+   - On phase changes/checkpoints: `type="job.updated"` or `artifact.checkpoint` with URLs.
+ - Use the same helper used elsewhere:
+ ```236:276:monorepo/backend/app/routes/simple_training/services/job_service.py
+ async def append_job_event(...): return await get_postgrest_emitter().append_event(...)
+ ```
+
+ 4) Optional shared service abstraction
+ - File: `monorepo/backend/app/routes/simple_training/services/storage_shared.py`
+ - Add a small `stream_job_events(job_id, after)` helper that wraps `repo.list_events(...)` and normalizes schemas (legacy vs learning_shared). The SSE route can call this.
+
+ 5) Wire the router
+ - File: `monorepo/backend/app/routes/main.py`
+ - Include the new GET route (under the RL router you already mount).
+ - CORS: ensure EventSource allowed if frontend will use browser SSE.
+
+ Indexes/Perf you already have
+ - Jobs/events tables include per-job indexes and a `last_seq`; reads by `(job_id, seq)` are efficient:
+ ```84:88:monorepo/backend/app/db/models/jobs.py
+ Index("idx_job_events_job_ts", "job_id", "ts"),
+ Index("idx_job_events_data_gin", "data", postgresql_using="gin"),
+ ```
+ - For learning_shared events (if used), there are sequence indexes too:
+ ```159:165:monorepo/backend/app/db/models/learning_shared.py
+ sa.Index("idx_learning_job_events_job_seq_idx", "job_id", "seq"),
+ ```
+
+ Summary of minimal backend edits
+ - Add SSE route:
+   - `monorepo/backend/app/routes/clustered_training/core/routes.py` (GET `/api/rl/jobs/{job_id}/stream`)
+ - Consume events via repo:
+   - `monorepo/backend/app/orchestration/jobs/repository_db.py` (use `list_events`)
+ - Emit metrics from loops:
+   - `monorepo/backend/app/orchestration/hatchet/workflows.py` (append_event at train/eval steps)
+ - Optional shared helper:
+   - `monorepo/backend/app/routes/simple_training/services/storage_shared.py` (normalize event reads)
+ - Wireup:
+   - `monorepo/backend/app/routes/main.py` (include SSE route)
+ - No schema changes required; you already have `JobEvent`/indexes and event append plumbing.
+
+ ### SDK additions for great terminal polling (with SSE fallback)
+
+ - RlJobsApi extensions (synth_ai/jobs/client.py)
+   - stream(job_id, since_id=None, types=None, heartbeat=None) -> async iterator of events (uses SSE; falls back to polling)
+   - events(job_id, after=None, limit=500) -> list[JobEvent] (poll)
+   - status(job_id) -> JobSummary (single snapshot)
+
+ - Event models (synth_ai/jobs/types.py)
+   - JobEvent base: {id, type, level, message, data, ts}
+   - StatusEvent, MetricEvent, LogEvent, ArtifactEvent (typed helpers)
+
+ - JobsWatcher helper (synth_ai/jobs/watcher.py)
+   - constructor(client, job_id, interval=2.0, prefer_sse=True, jsonl_path=None, types=None)
+   - run(on_event, stop_when=None) → handles SSE connect/reconnect, polling fallback, Last-Event-ID cursor, dedupe
+   - metrics_tracker: rolling windows per metric name (mean/std/min/max, last_value, last_step)
+   - backoff policy: jittered reconnect; rate limiter for render
+
+ - Terminal renderer (synth_ai/jobs/render.py)
+   - RichRenderer (or minimal TTY): compact line updates: t=HH:MM:SS | step/E | key metrics (avg_reward, loss, val_loss, success_rate)
+   - modes: one-line ticker vs. per-event lines; quiet mode; color by level/state
+   - JSONL sink: raw event writes without printing prompts/payloads
+
+ - CLI command (synth_ai/api/train/cli.py)
+   - synth-ai jobs watch <job_id> [--types metric,status] [--interval 2] [--jsonl out.jsonl] [--since-id N] [--no-sse]
+   - exit codes: 0 on succeeded, 1 on failed/canceled, 2 on timeout
+
+ - Utilities (synth_ai/jobs/utils.py)
+   - BackoffPolicy(retry, max) with jitter
+   - EventCursor(last_id, update)
+   - MetricsFormatter(map by job_type: RL vs FT metric labels)
+   - Coalescer: compress frequent metrics to ≤4 Hz
+
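A sketch of how the utilities above could compose into the polling fallback, assuming a client exposing an async `events(job_id, after=...)` call; the class and function names mirror the proposal but are not an existing synth-ai API:

```python
import asyncio
import random


class BackoffPolicy:
    """Jittered exponential backoff between reconnect/poll attempts."""

    def __init__(self, base: float = 1.0, cap: float = 30.0):
        self.base, self.cap, self.attempt = base, cap, 0

    def next_delay(self) -> float:
        self.attempt += 1
        return random.uniform(0, min(self.cap, self.base * 2 ** self.attempt))

    def reset(self) -> None:
        self.attempt = 0


async def poll_events(client, job_id: str, on_event, interval: float = 2.0) -> None:
    """Polling fallback: page through events by cursor, backing off on errors."""
    cursor = 0
    backoff = BackoffPolicy()
    while True:
        try:
            events = await client.events(job_id, after=cursor)
            backoff.reset()
            for ev in events:
                cursor = max(cursor, ev["id"])  # EventCursor-style dedupe
                on_event(ev)
            await asyncio.sleep(interval)
        except Exception:
            await asyncio.sleep(backoff.next_delay())
```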
+ - Defaults/behavior
+   - Prefer SSE; if 404/405/close → fallback to polling events() every interval
+   - Heartbeat support; show “connected/no data” when only heartbeats
+   - Resume: honor --since-id or Last-Event-ID; persist cursor optionally
+
+ - Minimal backend assumptions
+   - GET /api/rl/jobs/{job_id}/stream (SSE) or /api/rl/jobs/{job_id}/events?after=… (poll)
+   - Events include metric/status/log/artifact with seq ids and ts
+
+ - Extensibility
+   - Plugin renderers per job_type (rl, sft/ft)
+   - Hooks: on_status_change, on_metric(name, value), on_artifact(url)
examples/multi_step/task_app_config_notes.md
@@ -485,4 +485,10 @@ payload = {
  }
  ```
 
- Status: Notes committed on branch `friday-cleanup` and pushed.
+ Status: Notes committed on branch `friday-cleanup` and pushed.
+
+ ### Operational guardrails
+
+ - Treat avg_turns == 0 (or a high fraction of episodes with turns == 0) as a failure condition; exit non‑zero.
+ - Fail fast when the first policy step returns a 4xx/5xx from the inference target; include the HTTP status and URL in the error message.
+ - CI hint: a tiny smoke run (2 seeds × 1 rollout) should see turns > 0 in healthy setups.
examples/warming_up_to_rl/configs/eval_stepwise_complex.toml
@@ -4,9 +4,11 @@
  # task_app_url = "https://YOUR-TASK-APP.modal.run"
 
  model = "qwen/qwen3-32b"
- num_episodes = 50
+ # Route inference to local task app Groq proxy
+ inference_url = "http://localhost:8001/proxy/groq"
+ num_episodes = 10
  max_turns = 10
- concurrency = 5
+ concurrency = 10
  # difficulty = "easy" # optional
 
  [rollout]
examples/warming_up_to_rl/configs/eval_stepwise_simple.toml
@@ -4,9 +4,11 @@
  # task_app_url = "https://YOUR-TASK-APP.modal.run"
 
  model = "qwen/qwen3-32b"
- num_episodes = 50
+ # Route inference to local task app Groq proxy
+ inference_url = "http://localhost:8001/proxy/groq"
+ num_episodes = 10
  max_turns = 10
- concurrency = 5
+ concurrency = 10
  # difficulty = "easy" # optional
 
  [rollout]
examples/warming_up_to_rl/run_eval.py
@@ -14,12 +14,14 @@ import contextlib
  import json
  import os
  import re
- import tomllib
- from copy import deepcopy
+ import sys
  from collections import Counter
+ from copy import deepcopy
  from pathlib import Path
  from typing import Any
 
+ import tomllib
+
  import httpx
 
 
@@ -332,6 +334,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
  observation = created.get("observation") if isinstance(created, dict) else None
  if not isinstance(observation, dict):
  observation = {}
+ try:
+ ach_map_initial = observation.get("achievements_status")
+ if isinstance(ach_map_initial, dict):
+ achievements.update(k for k, v in ach_map_initial.items() if v)
+ except Exception:
+ pass
 
@@ -351,6 +359,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
  nxt = step.get("observation")
  if isinstance(nxt, dict):
  observation = nxt
+ try:
+ ach_map = observation.get("achievements_status")
+ if isinstance(ach_map, dict):
+ achievements.update(k for k, v in ach_map.items() if v)
+ except Exception:
+ pass
  finally:
  with contextlib.suppress(Exception):
  await client.terminate(env_name, env_id)
@@ -358,21 +372,45 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
  return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
 
 
- async def main() -> None:
- # Best-effort load local .env if present (ensures ENVIRONMENT_API_KEY for rollout)
+ def _load_dotenv_defaults() -> None:
+ """Load .env-style key/value pairs without clobbering explicit exports."""
  try:
- env_path = Path(__file__).resolve().parent / ".env"
- if env_path.exists():
- for line in env_path.read_text(encoding="utf-8").splitlines():
- line = line.strip()
- if not line or line.startswith("#") or "=" not in line:
- continue
- k, v = line.split("=", 1)
- k = k.strip()
- v = v.strip().strip('"').strip("'")
- os.environ.setdefault(k, v)
+ script_path = Path(__file__).resolve()
  except Exception:
- pass
+ return
+ candidates: list[Path] = []
+ # Prefer the repo root .env, then allow per-directory overrides.
+ for base in [Path.cwd(), script_path.parent, *script_path.parents]:
+ env_path = base / ".env"
+ if env_path not in candidates and env_path.is_file():
+ candidates.append(env_path)
+ seen: set[str] = set()
+ try:
+ for env_path in candidates:
+ try:
+ for raw in env_path.read_text(encoding="utf-8").splitlines():
+ line = raw.strip()
+ if not line or line.startswith("#") or "=" not in line:
+ continue
+ key, value = line.split("=", 1)
+ key = key.strip()
+ if not key or key in seen:
+ continue
+ seen.add(key)
+ val = value.strip().strip('"').strip("'")
+ os.environ.setdefault(key, val)
+ except Exception:
+ continue
+ except Exception:
+ return
+
+
+ async def main() -> None:
+ _load_dotenv_defaults()
+ if not (os.getenv("ENVIRONMENT_API_KEY") or os.getenv("DEV_ENVIRONMENT_API_KEY")):
+ raise RuntimeError(
+ "ENVIRONMENT_API_KEY is required. Export it or add it to your project .env."
+ )
 
  parser = argparse.ArgumentParser(
  description="Baseline eval against task app with optional TOML config"
@@ -497,7 +535,7 @@ async def main() -> None:
  if isinstance(step_block, dict):
  stepwise_details = step_block
  # Extract achievements count if present
- ach = []
+ achieved: set[str] = set()
  try:
  trajs = r.get("trajectories") or []
  final_obs = (
@@ -511,9 +549,29 @@ async def main() -> None:
  else None
  )
  if isinstance(ach_map, dict):
- ach = sorted([k for k, v in ach_map.items() if v])
+ achieved.update(k for k, v in ach_map.items() if v)
+ except Exception:
+ pass
+ try:
+ step_seen = stepwise_details.get("unique_achievements")
+ except Exception:
+ step_seen = None
+ if isinstance(step_seen, (list, tuple, set)):
+ achieved.update(str(a) for a in step_seen)
+ else:
+ try:
+ alt_seen = stepwise_details.get("achievements_seen")
+ if isinstance(alt_seen, (list, tuple, set)):
+ achieved.update(str(a) for a in alt_seen)
+ except Exception:
+ pass
+ try:
+ summary_final = stepwise_details.get("final_achievements")
+ if isinstance(summary_final, (list, tuple, set)):
+ achieved.update(str(a) for a in summary_final)
  except Exception:
  pass
+ ach = sorted(achieved)
  length = 0
  try:
  trajs = r.get("trajectories") or []
@@ -556,7 +614,10 @@ async def main() -> None:
  stepwise_reward_sums: list[float] = []
  stepwise_indicator_sums: list[float] = []
  stepwise_new_ach_totals: list[float] = []
+ stepwise_resource_rewards: list[float] = []
  strategies_seen = Counter()
+ unique_union: set[str] = set()
+ final_union: set[str] = set()
  for r in results:
  if not isinstance(r, dict):
  continue
@@ -577,6 +638,19 @@ async def main() -> None:
  stepwise_new_ach_totals.append(
  float(stepwise_block.get("new_achievements_total"))
  )
+ with contextlib.suppress(Exception):
+ if stepwise_block.get("resource_reward") is not None:
+ stepwise_resource_rewards.append(
+ float(stepwise_block.get("resource_reward"))
+ )
+ with contextlib.suppress(Exception):
+ uniq = stepwise_block.get("unique_achievements") or []
+ if isinstance(uniq, (list, tuple, set)):
+ unique_union.update(str(v) for v in uniq)
+ with contextlib.suppress(Exception):
+ final = stepwise_block.get("final_achievements") or []
+ if isinstance(final, (list, tuple, set)):
+ final_union.update(str(v) for v in final)
  strategy_name = stepwise_block.get("strategy")
  if isinstance(strategy_name, str) and strategy_name:
  strategies_seen[strategy_name] += 1
@@ -603,14 +677,49 @@ async def main() -> None:
  aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
  stepwise_new_ach_totals
  )
+ if stepwise_resource_rewards:
+ aggregate["avg_stepwise_resource_reward"] = (
+ sum(stepwise_resource_rewards) / len(stepwise_resource_rewards)
+ )
  if strategies_seen:
  aggregate["stepwise_strategies"] = dict(strategies_seen)
- aggregate["stepwise_samples"] = len(stepwise_reward_sums)
+ aggregate["stepwise_samples"] = max(
+ len(stepwise_reward_sums),
+ len(stepwise_indicator_sums),
+ len(stepwise_new_ach_totals),
+ len(stepwise_resource_rewards),
+ ) if any(
+ (
+ stepwise_reward_sums,
+ stepwise_indicator_sums,
+ stepwise_new_ach_totals,
+ stepwise_resource_rewards,
+ )
+ ) else 0
+ if not unique_union:
+ for r in results:
+ try:
+ for a in r.get("achievements") or []:
+ unique_union.add(str(a))
+ except Exception:
+ continue
+ if not final_union:
+ final_union.update(unique_union)
+ if unique_union:
+ aggregate["unique_achievements_union"] = sorted(unique_union)
+ if final_union:
+ aggregate["final_achievements_union"] = sorted(final_union)
  summary = {
  "episodes": results,
  "aggregate": aggregate,
  }
  print(json.dumps(summary, indent=2))
+ # Failure guardrails: any error or zero-turn episodes across the board
+ any_errors = any(isinstance(r, dict) and r.get("error") for r in results)
+ all_zero_turns = all((int(r.get("turns") or 0) == 0) for r in results if isinstance(r, dict))
+ if any_errors or all_zero_turns:
+ # Exit non-zero so automation/CI treats this as a failure
+ sys.exit(2)
  else:
 
  async def _run(seed: int):
examples/warming_up_to_rl/task_app/grpo_crafter.py
@@ -93,6 +93,7 @@ TASK_APP_ROOT = _resolve_task_app_root(REPO_ROOT)
  SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()
 
  EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()
+ RUBRICS_ROOT = (EXAMPLES_ROOT / "multi_step" / "rubrics").resolve()
 
  for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
  try:
@@ -344,40 +345,9 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
  )
 
 
- OUTCOME_RUBRIC = load_rubric(
- {
- "version": "1",
- "goal_text": "Reward unlocking Crafter achievements and survival.",
- "aggregation": "weighted_sum",
- "criteria": [
- {
- "id": "achievements",
- "description": "Unlock achievements or crafting milestones.",
- "weight": 1.0,
- },
- {
- "id": "survival",
- "description": "Maintain health, food, and drink levels.",
- "weight": 1.0,
- },
- ],
- }
- )
+ OUTCOME_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_outcome_rubric.json"))
 
- EVENTS_RUBRIC = load_rubric(
- {
- "version": "1",
- "goal_text": "Encourage purposeful step-wise exploration and crafting.",
- "aggregation": "weighted_sum",
- "criteria": [
- {
- "id": "progress_steps",
- "description": "Actions progress quests, crafting, or exploration.",
- "weight": 1.0,
- }
- ],
- }
- )
+ EVENTS_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_events_rubric.json"))
 
 
  def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]: