synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +5 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +125 -10
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +12 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +58 -1487
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -11
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/validators.py +2 -2
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/utils/env.py +25 -18
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/modal.py +2 -2
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from collections import deque
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Callable
|
|
11
|
+
|
|
12
|
+
import click
|
|
13
|
+
|
|
14
|
+
from .types import StreamMessage, StreamType
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class StreamHandler(ABC):
    """Abstract consumer of :class:`StreamMessage` objects emitted by a streamer.

    Subclasses implement :meth:`handle`; :meth:`should_handle` and
    :meth:`flush` are optional hooks with no-op defaults.
    """

    @abstractmethod
    def handle(self, message: StreamMessage) -> None:
        """Consume a single message from the stream."""

    def should_handle(self, message: StreamMessage) -> bool:  # pragma: no cover - trivial
        """Return ``True`` when *message* should be passed to :meth:`handle`."""
        return True

    def flush(self) -> None:  # pragma: no cover - optional
        """Write out any buffered state; the default implementation is a no-op."""
        return None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class CLIHandler(StreamHandler):
    """Plain-text CLI output mirroring the legacy poller behaviour.

    Status, event, metric and timeline messages are rendered as single
    ``click.echo`` lines prefixed with a wall-clock timestamp.  Events can be
    suppressed either by exact event type or by case-insensitive substring
    match against the event's type, message and payload.
    """

    def __init__(
        self,
        *,
        hidden_event_types: set[str] | None = None,
        hidden_event_substrings: set[str] | None = None,
    ) -> None:
        # Copy the inputs so later mutation by the caller cannot change filtering.
        self._hidden_event_types = set(hidden_event_types or set())
        # Substring matching is case-insensitive, so normalise once here.
        self._hidden_event_substrings = {s.lower() for s in (hidden_event_substrings or set())}

    def handle(self, message: StreamMessage) -> None:
        """Render *message* to stdout, honouring the configured event filters."""
        if not self.should_handle(message):
            return

        timestamp = datetime.now().strftime("%H:%M:%S")
        if message.stream_type is StreamType.STATUS:
            status = str(message.data.get("status") or message.data.get("state") or "unknown")
            click.echo(f"[{timestamp}] status={status}")
            return

        if message.stream_type is StreamType.EVENTS:
            event_type = message.data.get("type", "event")
            if event_type in self._hidden_event_types:
                return
            level = message.data.get("level")
            msg = message.data.get("message") or ""
            # Evaluate substring filters against lower-cased concatenated text
            if self._hidden_event_substrings:
                blob = " ".join(
                    [
                        event_type or "",
                        str(msg),
                        json.dumps(message.data.get("data", "")),
                    ]
                ).lower()
                if any(sub in blob for sub in self._hidden_event_substrings):
                    return
            prefix = f"[{timestamp}] [{message.seq}] {event_type}"
            if level:
                prefix += f" ({level})"
            # Only append the separator when there is a message body.  The old
            # ``f"{prefix}: {msg}".rstrip(": ")`` form also stripped genuine
            # trailing colons/spaces from the message text itself.
            click.echo(f"{prefix}: {msg}" if msg else prefix)
            return

        if message.stream_type is StreamType.METRICS:
            name = message.data.get("name", "metric")
            value = message.data.get("value")
            step = message.data.get("step")
            click.echo(f"[{timestamp}] {name}={value} (step={step})")
            return

        if message.stream_type is StreamType.TIMELINE:
            phase = message.data.get("phase", "phase")
            click.echo(f"[{timestamp}] timeline={phase}")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class JSONHandler(StreamHandler):
    """Emit messages as JSON records suitable for machine parsing.

    When *output_file* is given each record is appended to that file;
    otherwise records are echoed to stdout.  With ``indent=None`` (the
    default) the file output is valid JSON Lines; with an indent the records
    are pretty-printed, one per write, separated by newlines.
    """

    def __init__(self, output_file: str | None = None, *, indent: int | None = None) -> None:
        # ``~``-expansion happens once here so ``handle`` can open directly.
        self.output_file = Path(output_file).expanduser() if output_file else None
        self._indent = indent

    def handle(self, message: StreamMessage) -> None:
        """Serialise *message* and append it to the file or echo it."""
        if not self.should_handle(message):
            return

        payload: dict[str, Any] = {
            "stream_type": message.stream_type.name,
            "timestamp": message.timestamp,
            "job_id": message.job_id,
            "data": message.data,
        }
        # Optional fields are only included when present so consumers can
        # distinguish "absent" from "null".
        if message.seq is not None:
            payload["seq"] = message.seq
        if message.step is not None:
            payload["step"] = message.step
        if message.phase is not None:
            payload["phase"] = message.phase

        line = json.dumps(payload, indent=self._indent)
        if self.output_file:
            with self.output_file.open("a", encoding="utf-8") as fh:
                fh.write(line)
                # Always terminate the record.  Previously the newline was
                # skipped for indented output, concatenating successive
                # pretty-printed records into unparseable "}{"-joined text.
                fh.write("\n")
        else:
            click.echo(line)

    def flush(self) -> None:
        """No-op: every record is written (and the file closed) per message."""
        return None
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class CallbackHandler(StreamHandler):
    """Route each stream type to an optional user-supplied callback."""

    def __init__(
        self,
        *,
        on_status: Callable[[dict[str, Any]], None] | None = None,
        on_event: Callable[[dict[str, Any]], None] | None = None,
        on_metric: Callable[[dict[str, Any]], None] | None = None,
        on_timeline: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        # Map every stream type to its (possibly absent) callback up front so
        # dispatch in ``handle`` is a single dictionary lookup.
        self._callbacks: dict[StreamType, Callable[[dict[str, Any]], None] | None] = {
            StreamType.STATUS: on_status,
            StreamType.EVENTS: on_event,
            StreamType.METRICS: on_metric,
            StreamType.TIMELINE: on_timeline,
        }

    def handle(self, message: StreamMessage) -> None:
        """Invoke the callback registered for the message's stream type, if any."""
        if not self.should_handle(message):
            return

        callback = self._callbacks.get(message.stream_type)
        if callback is not None:
            callback(message.data)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class BufferedHandler(StreamHandler):
    """Accumulate messages and hand them to :meth:`process_batch` in chunks.

    A flush is triggered once the buffer reaches ``max_buffer_size`` entries
    or ``flush_interval`` seconds have elapsed since the previous flush.
    """

    def __init__(self, *, flush_interval: float = 5.0, max_buffer_size: int = 100) -> None:
        self.flush_interval = flush_interval
        self.max_buffer_size = max_buffer_size
        self._buffer: list[StreamMessage] = []
        self._last_flush = time.time()

    def handle(self, message: StreamMessage) -> None:
        """Buffer *message*, flushing when either threshold is exceeded."""
        if not self.should_handle(message):
            return

        self._buffer.append(message)
        size_exceeded = len(self._buffer) >= self.max_buffer_size
        interval_elapsed = time.time() - self._last_flush >= self.flush_interval
        if size_exceeded or interval_elapsed:
            self.flush()

    def flush(self) -> None:
        """Drain the buffer through :meth:`process_batch`, if non-empty."""
        if not self._buffer:
            return
        self.process_batch(self._buffer)
        self._buffer.clear()
        self._last_flush = time.time()

    def process_batch(self, messages: list[StreamMessage]) -> None:  # pragma: no cover - abstract
        """Override to define how buffered messages should be processed."""
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class IntegrationTestHandler(StreamHandler):
    """In-memory sink used by integration tests and programmatic assertions."""

    def __init__(self) -> None:
        # Every handled message is retained, in arrival order.
        self.messages: list[StreamMessage] = []

    def handle(self, message: StreamMessage) -> None:
        """Record *message* for later inspection (no filtering is applied)."""
        self.messages.append(message)

    def clear(self) -> None:
        """Discard all recorded messages."""
        self.messages.clear()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class LossCurveHandler(StreamHandler):
    """Render a live-updating loss chart inside a fixed Rich panel.

    Tracks a single metric (``metric_name``) and redraws a sparkline panel
    whenever the status, latest event, or metric series changes.  Requires
    the optional ``rich`` dependency; the import is deferred to ``__init__``
    so the module stays importable without it.
    """

    def __init__(
        self,
        *,
        metric_name: str = "train.loss",       # metric whose values are plotted
        max_points: int = 200,                 # history cap for the series
        width: int = 60,                       # max sparkline characters rendered
        console: Any | None = None,            # injectable rich Console (tests)
        live: Any | None = None,               # injectable rich Live (tests)
    ) -> None:
        try:
            from rich.console import Console
            from rich.live import Live
            from rich.panel import Panel
            from rich.text import Text
        except ImportError as exc:  # pragma: no cover - optional dependency guard
            raise RuntimeError(
                "LossCurveHandler requires the 'rich' package. Install synth-ai[analytics] or rich>=13."
            ) from exc

        self.metric_name = metric_name
        self.max_points = max_points
        self.width = width

        # Widget classes are stashed so helper methods can build panels/text
        # without re-importing rich.
        self._console_class = Console
        self._panel_class = Panel
        self._text_class = Text

        self._console = console or Console()
        self._live = live or Live(console=self._console, transient=False, refresh_per_second=8)
        self._started = False  # whether the Live display has been started

        self._steps: list[int] = []            # step axis, parallel to _values
        self._values: list[float] = []         # metric series, newest last
        self._status = "waiting"               # shown in the panel title
        self._last_event: str | None = None    # most recent event summary line

    def handle(self, message: StreamMessage) -> None:
        """Fold *message* into panel state; redraw only when something changed."""
        updated = False

        if message.stream_type is StreamType.STATUS:
            status = str(message.data.get("status") or message.data.get("state") or "unknown")
            if status != self._status:
                self._status = status
                updated = True

        elif message.stream_type is StreamType.EVENTS:
            event_type = message.data.get("type", "")
            msg = message.data.get("message") or ""
            level = message.data.get("level")
            summary = f"{event_type}".strip()
            if level:
                summary += f" ({level})"
            if msg:
                summary += f": {msg}"
            # Only a *new* summary triggers a redraw, so repeated identical
            # events don't thrash the terminal.
            if summary != self._last_event:
                self._last_event = summary
                updated = True

        elif message.stream_type is StreamType.METRICS:
            # Ignore metrics other than the one we chart.
            if message.data.get("name") != self.metric_name:
                return
            value = message.data.get("value")
            step = message.data.get("step")
            # Non-numeric values or missing/non-int steps are dropped silently.
            if not isinstance(value, (int, float)) or not isinstance(step, int):
                return
            self._values.append(float(value))
            self._steps.append(step)
            # Trim both parallel lists to the configured history cap.
            if len(self._values) > self.max_points:
                self._values = self._values[-self.max_points :]
                self._steps = self._steps[-self.max_points :]
            updated = True

        elif message.stream_type is StreamType.TIMELINE:
            phase = message.data.get("phase")
            if phase:
                # Timeline phases reuse the status slot in the panel title.
                self._status = str(phase)
                updated = True

        if updated:
            self._refresh()

    def flush(self) -> None:
        """Stop the Live display, swallowing any rich teardown errors."""
        if self._started:
            with contextlib.suppress(Exception):
                self._live.stop()
            self._started = False

    def _ensure_live(self) -> None:
        """Start the Live display on first use; errors are suppressed.

        NOTE(review): ``_started`` is only set after a successful ``start()``,
        so a suppressed failure here means we retry on the next refresh —
        confirm this is the intended recovery behaviour.
        """
        if not self._started:
            with contextlib.suppress(Exception):
                self._live.start()
                self._started = True

    def _refresh(self) -> None:
        """Rebuild the panel body and push it to the Live display."""
        self._ensure_live()
        body = self._build_body()
        title = f"{self.metric_name} | status={self._status}"
        self._live.update(self._panel_class(body, title=title, border_style="cyan"))

    def _build_body(self) -> Any:
        """Return the panel body: placeholder text, or sparkline + latest stats."""
        if not self._values:
            return self._text_class("Waiting for metrics…", style="yellow")

        chart = self._render_sparkline()
        last_value = self._values[-1]
        lines = [
            chart,
            f"latest: {last_value:.4f} (step {self._steps[-1]})",
        ]
        if self._last_event:
            lines.append(f"event: {self._last_event}")
        return "\n".join(lines)

    def _render_sparkline(self) -> str:
        """Render the last ``width`` values as a min/max-annotated block chart."""
        blocks = "▁▂▃▄▅▆▇█"
        tail_len = min(self.width, len(self._values))
        tail = self._values[-tail_len:]
        minimum = min(tail)
        maximum = max(tail)
        # Flat series: avoid division by zero and draw the lowest block.
        if maximum == minimum:
            level = blocks[0]
            return f"{minimum:.2f} {level * tail_len} {maximum:.2f}"
        # Linear map of [minimum, maximum] onto the 8 block glyphs, with
        # +0.5 for round-to-nearest rather than truncation.
        scale = (len(blocks) - 1) / (maximum - minimum)
        chars = "".join(blocks[int((v - minimum) * scale + 0.5)] for v in tail)
        return f"{minimum:.2f} {chars} {maximum:.2f}"

    def __del__(self) -> None:  # pragma: no cover - defensive cleanup
        # Best-effort: make sure the Live display is stopped even if the
        # owner forgot to call flush().
        with contextlib.suppress(Exception):
            self.flush()
|
|
331
|
+
|
|
332
|
+
class RichHandler(StreamHandler):
    """Rich-powered handler with a live progress bar and latest-metrics table.

    Requires the optional ``rich`` dependency; the import is deferred to
    ``__init__`` so the module stays importable without it.
    """

    def __init__(
        self,
        *,
        event_log_size: int = 20,      # how many recent event lines to keep
        console: Any | None = None,    # injectable rich Console (tests)
    ) -> None:
        try:
            from rich.console import Console
            from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
            from rich.table import Table
        except ImportError as exc:  # pragma: no cover - requires optional dependency
            raise RuntimeError(
                "RichHandler requires the 'rich' package. Install synth-ai[analytics] or rich>=13."
            ) from exc

        # Widget classes are stashed so helper methods can build tables
        # without re-importing rich.
        self._console_class = Console
        self._progress_class = Progress
        self._spinner_column = SpinnerColumn
        self._text_column = TextColumn
        self._bar_column = BarColumn
        self._table_class = Table

        self._console = console or Console()
        # NOTE(review): the completed/total column is only rendered when a
        # console was injected — presumably to keep default output compact;
        # confirm this is intentional before changing it.
        self._progress = Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}" if console else ""),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            transient=False,
            console=self._console,
        )
        self._task_id: int | None = None              # single progress task
        self._current_status = "unknown"              # shown in task description
        self._latest_metrics: dict[str, Any] = {}     # last value per metric name
        self._event_log: deque[str] = deque(maxlen=event_log_size)
        self._progress_started = False                # whether Progress is live

    def handle(self, message: StreamMessage) -> None:
        """Update progress, metrics table and event log from *message*."""
        if not self.should_handle(message):
            return

        if message.stream_type is StreamType.STATUS:
            # Fall back to "unknown" instead of rendering the literal string
            # "None" when neither key is present (matches CLIHandler).
            self._current_status = str(
                message.data.get("status") or message.data.get("state") or "unknown"
            )
            self._ensure_progress_started()
            if self._task_id is not None:
                description = f"Status: {self._current_status}"
                self._progress.update(self._task_id, description=description)
            self._render_summary()
            return

        if message.stream_type is StreamType.EVENTS:
            event_type = message.data.get("type", "event")
            summary = message.data.get("message") or ""
            level = message.data.get("level")
            formatted = f"[{event_type}] {summary}".strip()
            if level:
                formatted = f"{formatted} ({level})"
            self._event_log.append(formatted)
            # Some events carry step counters in their payload; use them to
            # drive the progress bar.
            data = message.data.get("data") or {}
            step = data.get("step") or data.get("current_step")
            total_steps = data.get("total_steps") or data.get("max_steps")
            if step and total_steps:
                self._ensure_progress_started(total_steps)
                if self._task_id is not None:
                    self._progress.update(self._task_id, completed=int(step), total=int(total_steps))
            self._render_summary()
            return

        if message.stream_type is StreamType.METRICS:
            name = message.data.get("name", "")
            value = message.data.get("value")
            if name:
                self._latest_metrics[name] = value
            self._render_summary()
            return

        if message.stream_type is StreamType.TIMELINE:
            phase = message.data.get("phase", "")
            # Routine phases would spam the log, so only record transitions.
            if phase and phase.lower() not in {"training", "running"}:
                self._event_log.append(f"[timeline] {phase}")
            self._render_summary()

    def flush(self) -> None:
        """Stop the live progress display and print a final summary."""
        if self._progress_started:
            self._progress.stop()
            self._progress_started = False
        self._render_summary(force=True)

    def _ensure_progress_started(self, total: int | float | None = None) -> None:
        """Lazily start the progress display and create its single task."""
        if not self._progress_started:
            self._progress.start()
            self._progress_started = True
        if self._task_id is None:
            # Default total of 100 until a real total arrives via events.
            self._task_id = self._progress.add_task(
                f"Status: {self._current_status}", total=total or 100
            )
        elif total is not None and self._task_id is not None:
            self._progress.update(self._task_id, total=total)

    def _render_summary(self, force: bool = False) -> None:
        """Print the latest-metrics table and recent events below the bar."""
        if force and self._progress_started:
            self._progress.refresh()

        table = self._table_class(title="Latest Metrics")
        table.add_column("Metric")
        table.add_column("Value")

        if not self._latest_metrics:
            table.add_row("—", "—")
        else:
            for name, value in sorted(self._latest_metrics.items()):
                table.add_row(str(name), str(value))

        # Route through the progress console while live so the table is not
        # clobbered by the redraw.
        if self._progress_started:
            self._progress.console.print(table)
        else:
            self._console.print(table)

        if self._event_log:
            self._console.print("\nRecent events:")
            for entry in list(self._event_log):
                self._console.print(f"  • {entry}")
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
# Public names of this module; kept case-insensitively alphabetised so
# additions are easy to diff (IntegrationTestHandler/JSONHandler were swapped).
__all__ = [
    "BufferedHandler",
    "CallbackHandler",
    "CLIHandler",
    "IntegrationTestHandler",
    "JSONHandler",
    "LossCurveHandler",
    "RichHandler",
    "StreamHandler",
]
|