synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +7 -1
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
- examples/warming_up_to_rl/run_eval.py +127 -18
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +232 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +49 -19
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/task_apps.py +3 -0
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/tracing_v3/serialization.py +130 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +1 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +48 -22
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
+# SSE metrics/event streaming design (RL + FT)
+
+## Goals
+- Near real-time push of job status, metrics, and logs during RL training, evaluation, and fine-tuning (FT)
+- Single streaming endpoint per job, resumable (Last-Event-ID), low overhead, widely compatible (HTTP/1.1)
+- Minimal client friction (CLI + Python helper), production-ready (auth, backpressure, rate limit)
+
+## Non-goals
+- Binary/frame multiplexing (use WebSocket if needed later)
+- Arbitrary high-frequency payloads (we will coalesce/limit ~2–4 Hz for metrics)
+
+---
+
+## Endpoint
+- Method: GET `/rl/jobs/{job_id}/stream`
+- Headers:
+  - Auth: `X-API-Key: <env key>` or `Authorization: Bearer <token>`
+  - Cache: `Cache-Control: no-cache`
+- Response:
+  - Content-Type: `text/event-stream; charset=utf-8`
+  - Transfer-Encoding: `chunked`
+  - Connection: `keep-alive`
+- Query params (optional):
+  - `since_id`: int; resume from a specific event id (inclusive)
+  - `types`: comma list `metric,status,log,artifact` (default: all)
+  - `heartbeat`: seconds between heartbeats (default 20)
+  - `split`: `train|eval` (filters metrics only)
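Editor's note (not part of this package diff): a minimal client-side sketch of calling the endpoint described above with the documented headers and query params. The base URL, API key, and job id are placeholders; `httpx` is assumed only because the repo already depends on it elsewhere in this diff.

```python
# Illustrative sketch only: connect to the documented stream endpoint and print raw SSE lines.
# BASE_URL, API_KEY, and JOB_ID are placeholders.
import httpx

BASE_URL = "https://backend.example.com"  # placeholder
API_KEY = "sk-env-..."                    # placeholder
JOB_ID = "j_abc"                          # placeholder

params = {"types": "metric,status", "heartbeat": 20}
headers = {"X-API-Key": API_KEY, "Cache-Control": "no-cache"}

with httpx.stream(
    "GET",
    f"{BASE_URL}/rl/jobs/{JOB_ID}/stream",
    headers=headers,
    params=params,
    timeout=None,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        # e.g. "id: 101", "event: metric", "data: {...}", or ": keep-alive"
        print(line)
```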
+
+## Event framing (SSE)
+- Fields per message:
+  - `id: <int>` monotonically increasing per job
+  - `event: <status|metric|log|artifact|heartbeat>`
+  - `data: <JSON>` single-line JSON (compact)
+- Heartbeats: comment lines `: keep-alive` at configured interval
+- Flush: after each event write + heartbeat
+- Backpressure: if producer > consumer, coalesce metrics, keep status/logs, never buffer unbounded
+
+## Payload schemas
+- `status`
+  - `{ "state": "queued|running|succeeded|failed|canceled", "step": 123, "epoch": 3, "phase": "train|eval|ft", "message": "...", "ts": 173.12 }`
+- `metric`
+  - `{ "name": "avg_reward|loss|accuracy|success_rate|return", "value": 0.123, "step": 123, "epoch": 3, "split": "train|eval", "window": 100, "mean": 0.42, "std": 0.08, "ts": 173.12 }`
+  - Optional extras: `{ "tags": {"env": "crafter", "policy": "react"} }`
+- `log`
+  - `{ "level": "INFO|WARN|ERROR", "message": "...", "ts": 173.12 }`
+- `artifact`
+  - `{ "kind": "checkpoint|trace|plot|jsonl", "url": "/rl/jobs/{id}/artifacts/ckpt_0003.pt", "step": 123, "ts": 173.12 }`
+- `heartbeat`
+  - `{ "alive": true, "ts": 173.12 }`
+
+### Example stream (illustrative)
+```
+id: 101
+event: status
+data: {"state":"running","phase":"train","step":820,"epoch":4,"ts":173.12}
+
+id: 102
+event: metric
+data: {"name":"avg_reward","value":0.62,"step":820,"epoch":4,"split":"train","ts":173.13}
+
+id: 103
+event: metric
+data: {"name":"loss","value":1.84,"step":820,"epoch":4,"split":"train","window":100,"mean":1.90,"std":0.15,"ts":173.13}
+
+id: 104
+event: log
+data: {"level":"INFO","message":"checkpoint saved","ts":173.16}
+
+id: 105
+event: artifact
+data: {"kind":"checkpoint","url":"/rl/jobs/j_abc/artifacts/ckpt_0004.pt","step":820,"ts":173.16}
+
+: keep-alive
+```
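Editor's note (not part of this package diff): a minimal server-side sketch of producing the frames shown above, following the framing rules in this doc (monotonic `id`, event name, single-line compact JSON `data`, and comment-line heartbeats). Function names are illustrative.

```python
# Illustrative sketch: format one event as an SSE frame per the rules above.
import json
from typing import Any


def format_sse(event_id: int, event_type: str, payload: dict[str, Any]) -> str:
    # data must stay single-line, compact JSON
    data = json.dumps(payload, separators=(",", ":"))
    return f"id: {event_id}\nevent: {event_type}\ndata: {data}\n\n"


def heartbeat_frame() -> str:
    # SSE comment line used as keep-alive
    return ": keep-alive\n\n"


print(format_sse(102, "metric", {"name": "avg_reward", "value": 0.62, "step": 820}))
```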
+
+---
+
+## Server architecture
+
+### Components
+- Event bus per `job_id` (async queue) where producers (RL, FT, evaluator) `emit(Event)`
+- Ring buffer per job for replay (configurable: last N events OR last T minutes)
+- SSE handler:
+  1) Authenticate, pick job, determine resume cursor (`since_id` or `Last-Event-ID`)
+  2) Replay from ring buffer >= cursor
+  3) Attach to live queue; stream new events
+  4) Emit heartbeats; close after terminal `status` + grace
+
+### Concurrency & ordering
+- Single writer increments `event_id`
+- Replay preserves original order; live continues from last id
+- If consumer slow: drop/coalesce metrics (preserve last per metric name), always deliver status/log/artifact
+
+### Rate limiting & coalescing
+- Default target 2–4 Hz for metrics per split
+- Coalesce by metric name within a small interval (e.g., 250–500 ms)
+- Status events limited to phase changes or every 5–10s
+
+### Auth & security
+- Accept `X-API-Key` or `Authorization: Bearer`
+- Validate job ownership/visibility
+- CORS: allow EventSource; set `Access-Control-Allow-Origin` appropriately
+- Timeouts: server idle timeout > heartbeat * 2; client reconnect on drop
+
+### Config knobs (env)
+- `SSE_HEARTBEAT_SECS` (default 20)
+- `SSE_RING_BUFFER_EVENTS` (e.g., 2000) OR `SSE_RING_BUFFER_WINDOW_SECS` (e.g., 600)
+- `SSE_MAX_METRIC_HZ` (e.g., 4)
+- `SSE_MAX_CLIENTS_PER_JOB` (protect from fan-out)
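Editor's note (not part of this package diff): a framework-agnostic sketch of handler steps 1–4 from the Components list above (replay past the resume cursor, then tail a live queue, with heartbeats while idle). The ring buffer, queue wiring, and auth are assumed; only the control flow is shown.

```python
# Illustrative sketch of the SSE handler loop: replay from ring buffer, then live tail.
import asyncio
from collections import deque


async def stream_job_events(ring: deque, live: asyncio.Queue, cursor: int, heartbeat_secs: float = 20.0):
    # 2) replay buffered events newer than the resume cursor
    for ev in list(ring):
        if ev["id"] > cursor:
            yield ev
            cursor = ev["id"]
    # 3) attach to the live queue; 4) emit heartbeats while idle
    while True:
        try:
            ev = await asyncio.wait_for(live.get(), timeout=heartbeat_secs)
        except asyncio.TimeoutError:
            yield {"event": "heartbeat"}
            continue
        if ev["id"] > cursor:  # dedupe against anything already replayed
            yield ev
            cursor = ev["id"]
        if ev.get("terminal"):  # close after terminal status + grace
            break
```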
+
+---
+
+## Emit points
+
+### RL training/eval
+- On train step end: `metric` avg_reward/return/success_rate; `status` every N steps
+- On eval step end: `metric` eval_return/success_rate; `artifact` eval JSONL optional
+- On checkpoint: `artifact` + `log`
+- On phase transitions: `status` (train→eval, etc.)
+
+### Fine-tuning (FT)
+- On optimizer step: `metric` loss (and optional lr)
+- On validation: `metric` val_loss/accuracy; optional `artifact` (curves)
+- On checkpoint: `artifact` + `log`
+
+---
+
+## Client (synth-ai)
+
+### CLI
+- `synth-ai jobs stream <job_id> [--jsonl out.jsonl] [--types metric,status]`
+- Prints compact lines: `t=18:22:40 step=820 avg_reward=0.62 loss=1.84`
+- Writes raw events to JSONL if specified
+
+### Python helper
+```python
+from synth_ai.client import TaskAppClient
+
+with TaskAppClient(base_url, api_key) as c:
+    for ev in c.stream_job(job_id, types=["metric","status"], since_id=None):
+        handle(ev)
+```
+- Handles reconnect with `Last-Event-ID`
+- Dedupes by event id; optional local aggregation windows
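Editor's note (not part of this package diff): the reconnect and dedupe behaviour called out above could look roughly like the following sketch. The URL and API key are placeholders; `httpx` is assumed because the repo already uses it, and the actual helper may parse frames differently.

```python
# Illustrative sketch: reconnect with Last-Event-ID and dedupe by event id.
import time

import httpx


def watch(url: str, api_key: str):
    last_id = 0
    while True:
        headers = {"X-API-Key": api_key}
        if last_id:
            headers["Last-Event-ID"] = str(last_id)  # resume after a drop
        try:
            with httpx.stream("GET", url, headers=headers, timeout=None) as resp:
                event_id = None
                for line in resp.iter_lines():
                    if line.startswith("id:"):
                        event_id = int(line.split(":", 1)[1].strip())
                    elif line.startswith("data:") and event_id and event_id > last_id:
                        last_id = event_id  # dedupe: only forward new ids
                        yield line.split(":", 1)[1].strip()
        except httpx.HTTPError:
            time.sleep(2.0)  # simple backoff before reconnecting
```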
+
+---
+
+## Error handling & edge cases
+- 404 (job not found): immediate error, no stream
+- 401/403: immediate error, no stream
+- Producer stalls: heartbeats continue; client shows "connected: no new data"
+- Ring buffer miss (client too far behind): send a summary `status` snapshot, then live only
+- Buffer overflow: set `dropped_events: true` field in a `status` update
+
+---
+
+## Testing plan
+- Unit: SSE formatting, ring buffer indexing, resume semantics, coalescing
+- Integration: run toy RL + FT loops; assert ordered sequence and types
+- Load: simulate 100 Hz metric emissions; verify rate limiting and non-blocking behavior
+- Auth: invalid key, missing key, cross-job access rejection
+
+---
+
+## Rollout plan
+1) Implement event bus + ring buffer per job
+2) Add SSE route with auth, resume, heartbeat
+3) Emitters in RL/FT loops (train/eval steps, loss/reward, checkpoints)
+4) synth-ai client: CLI + Python helper; optional JSONL sink
+5) Docs and example snippets (curl, Python, browser EventSource)
+
+---
+
+## Open questions / decisions
+- Do we persist events beyond ring buffer? (optional artifact JSONL)
+- Multi-tenant limits per org/user? (SSE_MAX_CLIENTS_PER_ORG?)
+- Artifact URLs: signed vs. authenticated direct
+- Backward compatibility with current polling endpoints
+I'll scan the backend to find where RL jobs are created, status is served, and whether there's any existing streaming/event infrastructure to extend. Then I'll identify exact files/functions to change for adding an SSE endpoint and metric emitters.
+
+[2 tools called]
+
+Here’s the exact backend surface you’ll need to touch to add a single SSE stream for status + metrics, with where to hook metric emitters.
+
+### Where jobs and events live
+- Job and event tables (legacy RL path):
+```1:20:monorepo/backend/app/db/models/jobs.py
+"""Job orchestration models (RL/SFT shared).
+"""
+```
+```31:58:monorepo/backend/app/db/models/jobs.py
+class Job(Base):
+    __tablename__ = "jobs"
+    job_id = Column(String, primary_key=True)
+    ...
+    status = Column(String, nullable=False, default="QUEUED")
+    ...
+    last_seq = Column(BigInteger, nullable=True, default=0)
+    events = relationship("JobEvent", back_populates="job", cascade="all, delete-orphan", passive_deletes=True)
+```
+```66:88:monorepo/backend/app/db/models/jobs.py
+class JobEvent(Base):
+    """Append-only job event log."""
+    __tablename__ = "job_events"
+    job_id = Column(String, ForeignKey("jobs.job_id", ondelete="CASCADE"), primary_key=True)
+    seq = Column(BigInteger, primary_key=True)
+    ts = Column(DateTime(timezone=True), server_default=func.now())
+    type = Column(String, nullable=False)
+    level = Column(String, nullable=False, default="info")
+    message = Column(Text, nullable=False)
+    data = Column(JSONB, nullable=True)
+```
+
+- DB repo (atomic seq increment + append):
+```108:146:monorepo/backend/app/orchestration/jobs/repository_db.py
+async def append_event(...):
+    res = await self.session.execute(
+        update(LearningJob)
+        .where(LearningJob.job_id == job_id)
+        .values(last_seq=func.coalesce(LearningJob.last_seq, 0) + 1, updated_at=func.now())
+        .returning(LearningJob.last_seq)
+    )
+    seq = res.scalar_one()
+    ev = JobEvent(job_id=job_id, seq=seq, type=type_, level=level, message=message, data=(data or {}))
+    ...
+```
+
+- RL jobs REST (create + emit first events via PostgREST emitter):
+```214:233:monorepo/backend/app/routes/clustered_training/core/routes.py
+@router.post("/jobs", ...)
+async def create_job(...):
+    ...
+```
+```784:799:monorepo/backend/app/routes/clustered_training/core/routes.py
+await get_postgrest_emitter().append_event(job_id=job_id, type_="rl.job.created", message="RL job created", data={"work": payload.work})
+```
+
+- Shared storage path showing how “update” and “append event” are broadcast via the PostgREST emitter:
+```242:271:monorepo/backend/app/routes/simple_training/services/job_service.py
+async def update_job(...): ...
+async def append_job_event(...): ...
+```
+
+What to add/change (SSE + emitters)
+
+1) Add SSE endpoint under RL jobs
+- File: `monorepo/backend/app/routes/clustered_training/core/routes.py`
+- New route: `GET /api/rl/jobs/{job_id}/stream`
+- Auth: `ValidatedAPIKey`
+- Headers: honor `Last-Event-ID`; Query: `since_id`, `types`, `heartbeat`
+- Flow:
+  - Validate job exists (use the DB-backed repo in production mode)
+  - Determine start seq (from Last-Event-ID or since_id, else 0)
+  - Replay: call repo `list_events(job_id, after=seq)` and stream as SSE (`id`, `event`, `data`)
+  - Live tail: loop with short sleep (e.g., 0.5–1.0s) fetching new events by `after=last_seq`
+  - Heartbeats as comments `: keep-alive` every N seconds
+- Event mapping:
+  - `JobEvent.type` prefixes map to SSE `event`:
+    - `rl.job.*`, `job.updated` → `status`
+    - `rl.step.metric`, `ft.step.metric`, `eval.metric` → `metric`
+    - `system.log.*` → `log`
+    - `artifact.*` → `artifact`
+- `Job.status` can be snapshotted once at connect (send a `status`)
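Editor's note (not part of this package diff): the type-prefix mapping above, written out as a small helper. The prefixes come from the notes; the function name is illustrative.

```python
# Illustrative mapping from JobEvent.type prefixes to SSE event names, per the notes above.
def sse_event_name(job_event_type: str) -> str:
    if job_event_type.startswith("rl.job.") or job_event_type == "job.updated":
        return "status"
    if job_event_type in ("rl.step.metric", "ft.step.metric", "eval.metric"):
        return "metric"
    if job_event_type.startswith("system.log."):
        return "log"
    if job_event_type.startswith("artifact."):
        return "artifact"
    return "status"  # conservative default for unrecognized types
```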
+
+2) Ensure a consistent event source for reads
+- Prefer the DB repo (`JobsRepositoryDB`) in prod mode. If the current code path uses the PostgREST emitter for appends, verify that the repo’s `list_events` reads from the same canonical table (it does for `JobEvent`). If your RL path uses the “learning_shared” models instead, use the associated repository there (same pattern: list by job_id + seq).
+- If you must keep PostgREST for append-only, that’s fine; SSE can still read the DB rows inserted alongside (your outbox/emitter already supports both).
+
+3) Emitters in training/FT loops
+- File(s): `monorepo/backend/app/orchestration/hatchet/workflows.py` (RL workflow nodes), any FT job loops
+- After each meaningful step:
+  - Train: append `type="rl.step.metric"`, `data={"avg_reward":..., "return":..., "success_rate":..., "step":..., "epoch":..., "split":"train"}`.
+  - Eval: `type="eval.metric"` with eval metrics and split.
+  - FT: `type="ft.step.metric"`, `data={"loss":..., "lr":..., "step":..., "epoch":..., "split":"train"}`; validation as `split="eval"`.
+  - On phase changes/checkpoints: `type="job.updated"` or `artifact.checkpoint` with URLs.
+- Use the same helper used elsewhere:
+```236:276:monorepo/backend/app/routes/simple_training/services/job_service.py
+async def append_job_event(...): return await get_postgrest_emitter().append_event(...)
+```
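Editor's note (not part of this package diff): a sketch of what one train-step emit might look like. `get_postgrest_emitter` and the keyword names mirror the `append_event` call quoted earlier in these notes; the elided parts of that signature are unknown, and the wrapper function plus its arguments are placeholders.

```python
# Illustrative emitter at train-step end; get_postgrest_emitter is assumed importable
# from the backend's job service module, and all loop variables are placeholders.
async def emit_train_step_metrics(job_id: str, step: int, epoch: int, metrics: dict) -> None:
    await get_postgrest_emitter().append_event(
        job_id=job_id,
        type_="rl.step.metric",
        message="train step metrics",
        data={**metrics, "step": step, "epoch": epoch, "split": "train"},
    )
```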
+
+4) Optional shared service abstraction
+- File: `monorepo/backend/app/routes/simple_training/services/storage_shared.py`
+- Add a small `stream_job_events(job_id, after)` helper that wraps `repo.list_events(...)` and normalizes schemas (legacy vs learning_shared). The SSE route can call this.
+
+5) Wire the router
+- File: `monorepo/backend/app/routes/main.py`
+- Include the new GET route (under the RL router you already mount).
+- CORS: ensure EventSource allowed if frontend will use browser SSE.
+
+Indexes/Perf you already have
+- Jobs/events tables include per-job indexes and a `last_seq`; reads by `(job_id, seq)` are efficient:
+```84:88:monorepo/backend/app/db/models/jobs.py
+Index("idx_job_events_job_ts", "job_id", "ts"),
+Index("idx_job_events_data_gin", "data", postgresql_using="gin"),
+```
+- For learning_shared events (if used), there are sequence indexes too:
+```159:165:monorepo/backend/app/db/models/learning_shared.py
+sa.Index("idx_learning_job_events_job_seq_idx", "job_id", "seq"),
+```
+
+Summary of minimal backend edits
+- Add SSE route:
+  - `monorepo/backend/app/routes/clustered_training/core/routes.py` (GET `/api/rl/jobs/{job_id}/stream`)
+- Consume events via repo:
+  - `monorepo/backend/app/orchestration/jobs/repository_db.py` (use `list_events`)
+- Emit metrics from loops:
+  - `monorepo/backend/app/orchestration/hatchet/workflows.py` (append_event at train/eval steps)
+- Optional shared helper:
+  - `monorepo/backend/app/routes/simple_training/services/storage_shared.py` (normalize event reads)
+- Wireup:
+  - `monorepo/backend/app/routes/main.py` (include SSE route)
+- No schema changes required; you already have `JobEvent`/indexes and event append plumbing.
+
+### SDK additions for great terminal polling (with SSE fallback)
+
+- RlJobsApi extensions (synth_ai/jobs/client.py)
+  - stream(job_id, since_id=None, types=None, heartbeat=None) -> async iterator of events (uses SSE; falls back to polling)
+  - events(job_id, after=None, limit=500) -> list[JobEvent] (poll)
+  - status(job_id) -> JobSummary (single snapshot)
+
+- Event models (synth_ai/jobs/types.py)
+  - JobEvent base: {id, type, level, message, data, ts}
+  - StatusEvent, MetricEvent, LogEvent, ArtifactEvent (typed helpers)
+
+- JobsWatcher helper (synth_ai/jobs/watcher.py)
+  - constructor(client, job_id, interval=2.0, prefer_sse=True, jsonl_path=None, types=None)
+  - run(on_event, stop_when=None) → handles SSE connect/reconnect, polling fallback, Last-Event-ID cursor, dedupe
+  - metrics_tracker: rolling windows per metric name (mean/std/min/max, last_value, last_step)
+  - backoff policy: jittered reconnect; rate limiter for render
+
+- Terminal renderer (synth_ai/jobs/render.py)
+  - RichRenderer (or minimal TTY): compact line updates: t=HH:MM:SS | step/E | key metrics (avg_reward, loss, val_loss, success_rate)
+  - modes: one-line ticker vs. per-event lines; quiet mode; color by level/state
+  - JSONL sink: raw event writes without printing prompts/payloads
+
+- CLI command (synth_ai/api/train/cli.py)
+  - synth-ai jobs watch <job_id> [--types metric,status] [--interval 2] [--jsonl out.jsonl] [--since-id N] [--no-sse]
+  - exit codes: 0 on succeeded, 1 on failed/canceled, 2 on timeout
+
+- Utilities (synth_ai/jobs/utils.py)
+  - BackoffPolicy(retry, max) with jitter
+  - EventCursor(last_id, update)
+  - MetricsFormatter(map by job_type: RL vs FT metric labels)
+  - Coalescer: compress frequent metrics to ≤4 Hz
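Editor's note (not part of this package diff): one possible shape for that Coalescer, keeping only the latest sample per metric name and flushing at most a few times per second. The class and method names are illustrative, not taken from the diff.

```python
# Illustrative coalescer: keep the last sample per metric name, flush at <= max_hz.
import time


class Coalescer:
    def __init__(self, max_hz: float = 4.0):
        self._min_interval = 1.0 / max_hz
        self._last_flush = 0.0
        self._pending: dict[str, dict] = {}

    def add(self, metric: dict) -> list[dict]:
        self._pending[metric["name"]] = metric  # newer sample wins
        now = time.monotonic()
        if now - self._last_flush < self._min_interval:
            return []  # too soon; keep buffering
        self._last_flush = now
        out = list(self._pending.values())
        self._pending.clear()
        return out
```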
+
+- Defaults/behavior
+  - Prefer SSE; if 404/405/close → fallback to polling events() every interval
+  - Heartbeat support; show “connected/no data” when only heartbeats
+  - Resume: honor --since-id or Last-Event-ID; persist cursor optionally
+
+- Minimal backend assumptions
+  - GET /api/rl/jobs/{job_id}/stream (SSE) or /api/rl/jobs/{job_id}/events?after=… (poll)
+  - Events include metric/status/log/artifact with seq ids and ts
+
+- Extensibility
+  - Plugin renderers per job_type (rl, sft/ft)
+  - Hooks: on_status_change, on_metric(name, value), on_artifact(url)
@@ -485,4 +485,10 @@ payload = {
 }
 ```

-Status: Notes committed on branch `friday-cleanup` and pushed.
+Status: Notes committed on branch `friday-cleanup` and pushed.
+
+### Operational guardrails
+
+- Treat avg_turns == 0 (or a high fraction of episodes with turns == 0) as a failure condition; exit non‑zero.
+- Fail fast when the first policy step returns a 4xx/5xx from the inference target; include the HTTP status and URL in the error message.
+- CI hint: a tiny smoke run (2 seeds × 1 rollout) should see turns > 0 in healthy setups.
@@ -4,9 +4,11 @@
 # task_app_url = "https://YOUR-TASK-APP.modal.run"

 model = "qwen/qwen3-32b"
-
+# Route inference to local task app Groq proxy
+inference_url = "http://localhost:8001/proxy/groq"
+num_episodes = 10
 max_turns = 10
-concurrency =
+concurrency = 10
 # difficulty = "easy" # optional

 [rollout]
@@ -4,9 +4,11 @@
 # task_app_url = "https://YOUR-TASK-APP.modal.run"

 model = "qwen/qwen3-32b"
-
+# Route inference to local task app Groq proxy
+inference_url = "http://localhost:8001/proxy/groq"
+num_episodes = 10
 max_turns = 10
-concurrency =
+concurrency = 10
 # difficulty = "easy" # optional

 [rollout]
@@ -14,12 +14,14 @@ import contextlib
 import json
 import os
 import re
-import
-from copy import deepcopy
+import sys
 from collections import Counter
+from copy import deepcopy
 from pathlib import Path
 from typing import Any

+import tomllib
+
 import httpx


@@ -332,6 +334,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
     observation = created.get("observation") if isinstance(created, dict) else None
     if not isinstance(observation, dict):
         observation = {}
+    try:
+        ach_map_initial = observation.get("achievements_status")
+        if isinstance(ach_map_initial, dict):
+            achievements.update(k for k, v in ach_map_initial.items() if v)
+    except Exception:
+        pass

     try:
         while turns < MAX_TURNS and not done:
@@ -351,6 +359,12 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
             nxt = step.get("observation")
             if isinstance(nxt, dict):
                 observation = nxt
+                try:
+                    ach_map = observation.get("achievements_status")
+                    if isinstance(ach_map, dict):
+                        achievements.update(k for k, v in ach_map.items() if v)
+                except Exception:
+                    pass
     finally:
         with contextlib.suppress(Exception):
             await client.terminate(env_name, env_id)
@@ -358,21 +372,45 @@ async def eval_episode(client: TaskAppClient, seed: int) -> dict[str, Any]:
     return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}


-
-
+def _load_dotenv_defaults() -> None:
+    """Load .env-style key/value pairs without clobbering explicit exports."""
     try:
-
-        if env_path.exists():
-            for line in env_path.read_text(encoding="utf-8").splitlines():
-                line = line.strip()
-                if not line or line.startswith("#") or "=" not in line:
-                    continue
-                k, v = line.split("=", 1)
-                k = k.strip()
-                v = v.strip().strip('"').strip("'")
-                os.environ.setdefault(k, v)
+        script_path = Path(__file__).resolve()
     except Exception:
-
+        return
+    candidates: list[Path] = []
+    # Prefer the repo root .env, then allow per-directory overrides.
+    for base in [Path.cwd(), script_path.parent, *script_path.parents]:
+        env_path = base / ".env"
+        if env_path not in candidates and env_path.is_file():
+            candidates.append(env_path)
+    seen: set[str] = set()
+    try:
+        for env_path in candidates:
+            try:
+                for raw in env_path.read_text(encoding="utf-8").splitlines():
+                    line = raw.strip()
+                    if not line or line.startswith("#") or "=" not in line:
+                        continue
+                    key, value = line.split("=", 1)
+                    key = key.strip()
+                    if not key or key in seen:
+                        continue
+                    seen.add(key)
+                    val = value.strip().strip('"').strip("'")
+                    os.environ.setdefault(key, val)
+            except Exception:
+                continue
+    except Exception:
+        return
+
+
+async def main() -> None:
+    _load_dotenv_defaults()
+    if not (os.getenv("ENVIRONMENT_API_KEY") or os.getenv("DEV_ENVIRONMENT_API_KEY")):
+        raise RuntimeError(
+            "ENVIRONMENT_API_KEY is required. Export it or add it to your project .env."
+        )

     parser = argparse.ArgumentParser(
         description="Baseline eval against task app with optional TOML config"
@@ -497,7 +535,7 @@ async def main() -> None:
             if isinstance(step_block, dict):
                 stepwise_details = step_block
             # Extract achievements count if present
-
+            achieved: set[str] = set()
             try:
                 trajs = r.get("trajectories") or []
                 final_obs = (
@@ -511,9 +549,29 @@ async def main() -> None:
                     else None
                 )
                 if isinstance(ach_map, dict):
-
+                    achieved.update(k for k, v in ach_map.items() if v)
+            except Exception:
+                pass
+            try:
+                step_seen = stepwise_details.get("unique_achievements")
+            except Exception:
+                step_seen = None
+            if isinstance(step_seen, (list, tuple, set)):
+                achieved.update(str(a) for a in step_seen)
+            else:
+                try:
+                    alt_seen = stepwise_details.get("achievements_seen")
+                    if isinstance(alt_seen, (list, tuple, set)):
+                        achieved.update(str(a) for a in alt_seen)
+                except Exception:
+                    pass
+            try:
+                summary_final = stepwise_details.get("final_achievements")
+                if isinstance(summary_final, (list, tuple, set)):
+                    achieved.update(str(a) for a in summary_final)
             except Exception:
                 pass
+            ach = sorted(achieved)
             length = 0
             try:
                 trajs = r.get("trajectories") or []
@@ -556,7 +614,10 @@ async def main() -> None:
         stepwise_reward_sums: list[float] = []
         stepwise_indicator_sums: list[float] = []
         stepwise_new_ach_totals: list[float] = []
+        stepwise_resource_rewards: list[float] = []
         strategies_seen = Counter()
+        unique_union: set[str] = set()
+        final_union: set[str] = set()
         for r in results:
             if not isinstance(r, dict):
                 continue
@@ -577,6 +638,19 @@ async def main() -> None:
                 stepwise_new_ach_totals.append(
                     float(stepwise_block.get("new_achievements_total"))
                 )
+            with contextlib.suppress(Exception):
+                if stepwise_block.get("resource_reward") is not None:
+                    stepwise_resource_rewards.append(
+                        float(stepwise_block.get("resource_reward"))
+                    )
+            with contextlib.suppress(Exception):
+                uniq = stepwise_block.get("unique_achievements") or []
+                if isinstance(uniq, (list, tuple, set)):
+                    unique_union.update(str(v) for v in uniq)
+            with contextlib.suppress(Exception):
+                final = stepwise_block.get("final_achievements") or []
+                if isinstance(final, (list, tuple, set)):
+                    final_union.update(str(v) for v in final)
             strategy_name = stepwise_block.get("strategy")
             if isinstance(strategy_name, str) and strategy_name:
                 strategies_seen[strategy_name] += 1
@@ -603,14 +677,49 @@ async def main() -> None:
             aggregate["avg_stepwise_new_achievements"] = sum(stepwise_new_ach_totals) / len(
                 stepwise_new_ach_totals
             )
+        if stepwise_resource_rewards:
+            aggregate["avg_stepwise_resource_reward"] = (
+                sum(stepwise_resource_rewards) / len(stepwise_resource_rewards)
+            )
         if strategies_seen:
             aggregate["stepwise_strategies"] = dict(strategies_seen)
-        aggregate["stepwise_samples"] =
+        aggregate["stepwise_samples"] = max(
+            len(stepwise_reward_sums),
+            len(stepwise_indicator_sums),
+            len(stepwise_new_ach_totals),
+            len(stepwise_resource_rewards),
+        ) if any(
+            (
+                stepwise_reward_sums,
+                stepwise_indicator_sums,
+                stepwise_new_ach_totals,
+                stepwise_resource_rewards,
+            )
+        ) else 0
+        if not unique_union:
+            for r in results:
+                try:
+                    for a in r.get("achievements") or []:
+                        unique_union.add(str(a))
+                except Exception:
+                    continue
+        if not final_union:
+            final_union.update(unique_union)
+        if unique_union:
+            aggregate["unique_achievements_union"] = sorted(unique_union)
+        if final_union:
+            aggregate["final_achievements_union"] = sorted(final_union)
         summary = {
             "episodes": results,
             "aggregate": aggregate,
         }
         print(json.dumps(summary, indent=2))
+        # Failure guardrails: any error or zero-turn episodes across the board
+        any_errors = any(isinstance(r, dict) and r.get("error") for r in results)
+        all_zero_turns = all((int(r.get("turns") or 0) == 0) for r in results if isinstance(r, dict))
+        if any_errors or all_zero_turns:
+            # Exit non-zero so automation/CI treats this as a failure
+            sys.exit(2)
     else:

         async def _run(seed: int):
@@ -93,6 +93,7 @@ TASK_APP_ROOT = _resolve_task_app_root(REPO_ROOT)
 SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()

 EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()
+RUBRICS_ROOT = (EXAMPLES_ROOT / "multi_step" / "rubrics").resolve()

 for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
     try:
@@ -344,40 +345,9 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
 )


-OUTCOME_RUBRIC = load_rubric(
-    {
-        "version": "1",
-        "goal_text": "Reward unlocking Crafter achievements and survival.",
-        "aggregation": "weighted_sum",
-        "criteria": [
-            {
-                "id": "achievements",
-                "description": "Unlock achievements or crafting milestones.",
-                "weight": 1.0,
-            },
-            {
-                "id": "survival",
-                "description": "Maintain health, food, and drink levels.",
-                "weight": 1.0,
-            },
-        ],
-    }
-)
+OUTCOME_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_outcome_rubric.json"))

-EVENTS_RUBRIC = load_rubric(
-    {
-        "version": "1",
-        "goal_text": "Encourage purposeful step-wise exploration and crafting.",
-        "aggregation": "weighted_sum",
-        "criteria": [
-            {
-                "id": "progress_steps",
-                "description": "Actions progress quests, crafting, or exploration.",
-                "weight": 1.0,
-            }
-        ],
-    }
-)
+EVENTS_RUBRIC = load_rubric(str(RUBRICS_ROOT / "crafter_events_rubric.json"))


 def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
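Editor's note (not part of this package diff): the hunk above replaces inline rubric dictionaries with JSON files under `RUBRICS_ROOT`. Assuming such a file exists alongside the task app, the removed literal suggests `crafter_outcome_rubric.json` likely looks roughly like the following; the shipped file may differ.

```json
{
  "version": "1",
  "goal_text": "Reward unlocking Crafter achievements and survival.",
  "aggregation": "weighted_sum",
  "criteria": [
    {"id": "achievements", "description": "Unlock achievements or crafting milestones.", "weight": 1.0},
    {"id": "survival", "description": "Maintain health, food, and drink levels.", "weight": 1.0}
  ]
}
```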