synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +5 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +125 -10
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +12 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +58 -1487
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -11
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/validators.py +2 -2
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/utils/env.py +25 -18
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/modal.py +2 -2
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/red/environment.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, List, Optional, Union
|
|
4
4
|
import base64
|
|
5
|
+
import time
|
|
5
6
|
from io import BytesIO
|
|
6
7
|
|
|
7
8
|
from pydantic import BaseModel, Field
|
|
@@ -19,6 +20,8 @@ from synth_ai.environments.environment.tools import (
|
|
|
19
20
|
)
|
|
20
21
|
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
|
21
22
|
from synth_ai.environments.stateful.core import StatefulEnvironment
|
|
23
|
+
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, TimeRecord
|
|
24
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
22
25
|
try: # optional for image encoding
|
|
23
26
|
import numpy as _np # type: ignore
|
|
24
27
|
from PIL import Image as _PILImage # type: ignore
|
|
@@ -121,6 +124,7 @@ class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[Pokemon
|
|
|
121
124
|
task_instance: Optional[PokemonRedTaskInstance] = None,
|
|
122
125
|
custom_step_obs: Optional[GetObservationCallable] = None,
|
|
123
126
|
custom_ckpt_obs: Optional[GetObservationCallable] = None,
|
|
127
|
+
tracer: Optional[SessionTracer] = None,
|
|
124
128
|
):
|
|
125
129
|
self.name = "PokemonRed"
|
|
126
130
|
self.task_instance = task_instance or DEFAULT_TASK_INSTANCE
|
|
@@ -129,6 +133,7 @@ class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[Pokemon
|
|
|
129
133
|
custom_ckpt_obs or PokemonRedObservationCallable()
|
|
130
134
|
)
|
|
131
135
|
self.engine = PokemonRedEngine(self.task_instance)
|
|
136
|
+
self.tracer = tracer
|
|
132
137
|
|
|
133
138
|
# Register tools
|
|
134
139
|
self._press_button_tool = PressButtonTool(self.engine)
|
|
@@ -203,6 +208,27 @@ class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[Pokemon
|
|
|
203
208
|
if tool_result.error and hasattr(pub_state, "error_info"):
|
|
204
209
|
pub_state.error_info = tool_result.error
|
|
205
210
|
|
|
211
|
+
# Record EnvironmentEvent for tracing if tracer is available
|
|
212
|
+
if self.tracer and hasattr(priv_state, 'reward_last_step'):
|
|
213
|
+
# Get state information for the event
|
|
214
|
+
prev_state = getattr(self.engine, '_previous_state', None)
|
|
215
|
+
terminated = getattr(priv_state, 'terminated', False)
|
|
216
|
+
truncated = getattr(priv_state, 'truncated', False)
|
|
217
|
+
|
|
218
|
+
# Convert states to dict for serialization
|
|
219
|
+
pub_state_dict = pub_state.__dict__ if hasattr(pub_state, '__dict__') else pub_state
|
|
220
|
+
|
|
221
|
+
env_event = EnvironmentEvent(
|
|
222
|
+
system_instance_id="pokemon_red_env",
|
|
223
|
+
time_record=TimeRecord(event_time=time.time()),
|
|
224
|
+
reward=float(priv_state.reward_last_step),
|
|
225
|
+
terminated=terminated,
|
|
226
|
+
truncated=truncated,
|
|
227
|
+
system_state_before=prev_state if prev_state else None,
|
|
228
|
+
system_state_after=pub_state_dict,
|
|
229
|
+
)
|
|
230
|
+
await self.tracer.record_event(env_event)
|
|
231
|
+
|
|
206
232
|
return await self._to_observation(
|
|
207
233
|
priv_state, pub_state, self.custom_step_observation_callable
|
|
208
234
|
)
|
synth_ai/environments/examples/red/trace_hooks_v3.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trace hooks for Pokemon Red environment - v3 version.
|
|
3
|
+
Captures reward information and saves to Turso database.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Dict, Optional
|
|
8
|
+
|
|
9
|
+
from synth_ai.tracing_v3.abstractions import BaseEvent, EnvironmentEvent
|
|
10
|
+
from synth_ai.tracing_v3.hooks import HookManager
|
|
11
|
+
|
|
12
|
+
# Pokemon Red achievement categories by reward value
|
|
13
|
+
EXPLORATION_ACHIEVEMENTS = {
|
|
14
|
+
0.02: "explore_new_area",
|
|
15
|
+
0.04: "explore_multiple_areas",
|
|
16
|
+
1.0: "leave_starting_area",
|
|
17
|
+
1.5: "enter_new_city",
|
|
18
|
+
2.0: "explore_new_route",
|
|
19
|
+
5.0: "enter_gym_building",
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
TRAINING_ACHIEVEMENTS = {
|
|
23
|
+
0.2: "pokemon_level_up",
|
|
24
|
+
0.3: "reach_power_level",
|
|
25
|
+
3.0: "pokemon_ready_for_battle",
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
BATTLE_ACHIEVEMENTS = {
|
|
29
|
+
0.1: "encounter_wild_pokemon",
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
RESOURCE_ACHIEVEMENTS = {
|
|
33
|
+
0.05: "keep_pokemon_healthy",
|
|
34
|
+
0.5: "find_valuable_item",
|
|
35
|
+
0.8: "visit_pokemon_center",
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
MAJOR_ACHIEVEMENTS = {
|
|
39
|
+
50.0: "defeat_brock_win_badge",
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def track_pokemon_rewards(event_obj: BaseEvent, **kwargs) -> Optional[Dict[str, Any]]:
|
|
44
|
+
"""Hook that captures detailed Pokemon Red reward information."""
|
|
45
|
+
# Only process EnvironmentEvents
|
|
46
|
+
if not isinstance(event_obj, EnvironmentEvent):
|
|
47
|
+
return None
|
|
48
|
+
|
|
49
|
+
reward = event_obj.reward
|
|
50
|
+
if reward is None or reward == 0.0:
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
# Determine achievement type based on reward value
|
|
54
|
+
achievement_type = "unknown"
|
|
55
|
+
achievement_category = "other"
|
|
56
|
+
|
|
57
|
+
# Check each category
|
|
58
|
+
if reward in EXPLORATION_ACHIEVEMENTS:
|
|
59
|
+
achievement_type = EXPLORATION_ACHIEVEMENTS[reward]
|
|
60
|
+
achievement_category = "exploration"
|
|
61
|
+
elif reward in TRAINING_ACHIEVEMENTS:
|
|
62
|
+
achievement_type = TRAINING_ACHIEVEMENTS[reward]
|
|
63
|
+
achievement_category = "training"
|
|
64
|
+
elif reward in BATTLE_ACHIEVEMENTS:
|
|
65
|
+
achievement_type = BATTLE_ACHIEVEMENTS[reward]
|
|
66
|
+
achievement_category = "battle"
|
|
67
|
+
elif reward in RESOURCE_ACHIEVEMENTS:
|
|
68
|
+
achievement_type = RESOURCE_ACHIEVEMENTS[reward]
|
|
69
|
+
achievement_category = "resource"
|
|
70
|
+
elif reward in MAJOR_ACHIEVEMENTS:
|
|
71
|
+
achievement_type = MAJOR_ACHIEVEMENTS[reward]
|
|
72
|
+
achievement_category = "major"
|
|
73
|
+
|
|
74
|
+
return {
|
|
75
|
+
"reward_value": reward,
|
|
76
|
+
"achievement_type": achievement_type,
|
|
77
|
+
"achievement_category": achievement_category,
|
|
78
|
+
"timestamp": datetime.now().isoformat(),
|
|
79
|
+
"system_state_before": event_obj.system_state_before,
|
|
80
|
+
"system_state_after": event_obj.system_state_after,
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def track_pokemon_milestones(event_obj: BaseEvent, **kwargs) -> Optional[Dict[str, Any]]:
|
|
85
|
+
"""Hook that tracks significant Pokemon Red milestones."""
|
|
86
|
+
# Only process EnvironmentEvents
|
|
87
|
+
if not isinstance(event_obj, EnvironmentEvent):
|
|
88
|
+
return None
|
|
89
|
+
|
|
90
|
+
reward = event_obj.reward
|
|
91
|
+
if reward is None:
|
|
92
|
+
return None
|
|
93
|
+
|
|
94
|
+
# Track major milestones
|
|
95
|
+
if reward >= 1.0: # Significant progress rewards
|
|
96
|
+
return {
|
|
97
|
+
"milestone": "major_progress",
|
|
98
|
+
"reward": reward,
|
|
99
|
+
"timestamp": datetime.now().isoformat(),
|
|
100
|
+
}
|
|
101
|
+
elif reward >= 0.5: # Moderate rewards
|
|
102
|
+
return {
|
|
103
|
+
"milestone": "moderate_progress",
|
|
104
|
+
"reward": reward,
|
|
105
|
+
"timestamp": datetime.now().isoformat(),
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
return None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
async def track_pokemon_outcomes(event_obj: BaseEvent, **kwargs) -> Optional[Dict[str, Any]]:
|
|
112
|
+
"""Hook that tracks episode outcomes for Pokemon Red."""
|
|
113
|
+
# Only process EnvironmentEvents
|
|
114
|
+
if not isinstance(event_obj, EnvironmentEvent):
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
# Check for termination conditions
|
|
118
|
+
if event_obj.terminated or event_obj.truncated:
|
|
119
|
+
total_reward = getattr(event_obj, 'total_reward', 0.0)
|
|
120
|
+
steps_taken = getattr(event_obj, 'step_count', 0)
|
|
121
|
+
|
|
122
|
+
# Extract achievement information from system state
|
|
123
|
+
achievements_count = 0
|
|
124
|
+
if event_obj.system_state_after:
|
|
125
|
+
# Count positive rewards as achievements
|
|
126
|
+
# This is a simplified count - in practice you'd track actual achievements
|
|
127
|
+
achievements_count = max(1, int(total_reward / 0.1)) # Rough estimate
|
|
128
|
+
|
|
129
|
+
return {
|
|
130
|
+
"outcome_type": "episode_end",
|
|
131
|
+
"total_reward": total_reward,
|
|
132
|
+
"steps_taken": steps_taken,
|
|
133
|
+
"achievements_count": achievements_count,
|
|
134
|
+
"terminated": event_obj.terminated,
|
|
135
|
+
"truncated": event_obj.truncated,
|
|
136
|
+
"timestamp": datetime.now().isoformat(),
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# Create the global POKEMON_RED_HOOKS instance
|
|
143
|
+
POKEMON_RED_HOOKS = HookManager()
|
|
144
|
+
|
|
145
|
+
# Register all hooks
|
|
146
|
+
POKEMON_RED_HOOKS.register(
|
|
147
|
+
"event_recorded",
|
|
148
|
+
track_pokemon_rewards,
|
|
149
|
+
name="pokemon_rewards",
|
|
150
|
+
priority=10,
|
|
151
|
+
event_types=["environment"],
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
POKEMON_RED_HOOKS.register(
|
|
155
|
+
"event_recorded",
|
|
156
|
+
track_pokemon_milestones,
|
|
157
|
+
name="pokemon_milestones",
|
|
158
|
+
priority=5,
|
|
159
|
+
event_types=["environment"],
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
POKEMON_RED_HOOKS.register(
|
|
163
|
+
"event_recorded",
|
|
164
|
+
track_pokemon_outcomes,
|
|
165
|
+
name="pokemon_outcomes",
|
|
166
|
+
priority=5,
|
|
167
|
+
event_types=["environment"],
|
|
168
|
+
)
|
synth_ai/http.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Backward-compatible HTTP client exports.
|
|
3
|
+
|
|
4
|
+
Historically, some modules imported ``synth_ai.http``. The canonical location
|
|
5
|
+
is ``synth_ai.http_client``; this module simply re-exports the same symbols so
|
|
6
|
+
legacy imports keep working.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from synth_ai.http_client import AsyncHttpClient, HTTPError, sleep
|
|
11
|
+
|
|
12
|
+
__all__ = ["AsyncHttpClient", "HTTPError", "sleep"]
|
synth_ai/judge_schemas.py
CHANGED
|
@@ -9,7 +9,7 @@ This is the canonical contract that the backend MUST conform to.
|
|
|
9
9
|
|
|
10
10
|
from __future__ import annotations
|
|
11
11
|
|
|
12
|
-
from typing import Any, Literal
|
|
12
|
+
from typing import Any, Literal, Optional
|
|
13
13
|
|
|
14
14
|
from pydantic import BaseModel, Field
|
|
15
15
|
|
|
@@ -31,7 +31,7 @@ class ReviewPayload(BaseModel):
|
|
|
31
31
|
description="Map of criterion keys to their scores"
|
|
32
32
|
)
|
|
33
33
|
total: float = Field(default=0.0, description="Aggregated total score")
|
|
34
|
-
summary: str
|
|
34
|
+
summary: Optional[str] = Field(None, description="Optional text summary")
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class JudgeScoreResponse(BaseModel):
|
|
@@ -46,7 +46,7 @@ class JudgeScoreResponse(BaseModel):
|
|
|
46
46
|
default_factory=list,
|
|
47
47
|
description="List of per-event rubric reviews (one per step)"
|
|
48
48
|
)
|
|
49
|
-
outcome_review: ReviewPayload | None = Field(
|
|
49
|
+
outcome_review: Optional[ReviewPayload] = Field(
|
|
50
50
|
None,
|
|
51
51
|
description="Optional outcome-level rubric review"
|
|
52
52
|
)
|
|
@@ -63,7 +63,7 @@ class JudgeScoreResponse(BaseModel):
|
|
|
63
63
|
description="Request metadata (provider, options, etc.)"
|
|
64
64
|
)
|
|
65
65
|
|
|
66
|
-
def aggregate_event_reward(self) -> float | None:
|
|
66
|
+
def aggregate_event_reward(self) -> Optional[float]:
|
|
67
67
|
"""
|
|
68
68
|
Aggregate all event totals into a single reward.
|
|
69
69
|
|
|
@@ -74,7 +74,7 @@ class JudgeScoreResponse(BaseModel):
|
|
|
74
74
|
return None
|
|
75
75
|
return sum(self.event_totals)
|
|
76
76
|
|
|
77
|
-
def aggregate_outcome_reward(self) -> float | None:
|
|
77
|
+
def aggregate_outcome_reward(self) -> Optional[float]:
|
|
78
78
|
"""
|
|
79
79
|
Extract outcome reward from outcome_review.
|
|
80
80
|
|
|
@@ -92,15 +92,15 @@ class JudgeTaskApp(BaseModel):
|
|
|
92
92
|
"""Task application metadata."""
|
|
93
93
|
|
|
94
94
|
id: str = Field(..., description="Task app identifier")
|
|
95
|
-
base_url: str
|
|
95
|
+
base_url: Optional[str] = Field(None, description="Optional base URL for task app")
|
|
96
96
|
|
|
97
97
|
|
|
98
98
|
class JudgeOptions(BaseModel):
|
|
99
99
|
"""Judge provider and configuration options."""
|
|
100
100
|
|
|
101
|
-
provider: str
|
|
102
|
-
model: str
|
|
103
|
-
rubric_id: str
|
|
101
|
+
provider: Optional[str] = Field(None, description="Judge provider (e.g., 'openai', 'groq')")
|
|
102
|
+
model: Optional[str] = Field(None, description="Model identifier")
|
|
103
|
+
rubric_id: Optional[str] = Field(None, description="Rubric identifier")
|
|
104
104
|
event: bool = Field(True, description="Enable event-level judging")
|
|
105
105
|
outcome: bool = Field(True, description="Enable outcome-level judging")
|
|
106
106
|
|
|
@@ -123,5 +123,4 @@ class JudgeScoreRequest(BaseModel):
|
|
|
123
123
|
task_app: JudgeTaskApp = Field(..., description="Task application metadata")
|
|
124
124
|
trace: JudgeTracePayload = Field(..., description="Trajectory trace to evaluate")
|
|
125
125
|
options: JudgeOptions = Field(default_factory=lambda: JudgeOptions(), description="Judge options")
|
|
126
|
-
rubric: dict[str, Any]
|
|
127
|
-
|
|
126
|
+
rubric: Optional[dict[str, Any]] = Field(None, description="Optional explicit rubric criteria")
|
synth_ai/learning/rl/client.py
CHANGED
|
@@ -107,7 +107,9 @@ class RlClient:
|
|
|
107
107
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
|
|
108
108
|
try:
|
|
109
109
|
js = await http.get(
|
|
110
|
-
f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events",
|
|
110
|
+
f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events",
|
|
111
|
+
params=params,
|
|
112
|
+
headers={"accept": "application/json"},
|
|
111
113
|
)
|
|
112
114
|
except HTTPError as he:
|
|
113
115
|
with suppress(Exception):
|
synth_ai/streaming/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .config import StreamConfig
|
|
2
|
+
from .handlers import (
|
|
3
|
+
BufferedHandler,
|
|
4
|
+
CallbackHandler,
|
|
5
|
+
CLIHandler,
|
|
6
|
+
IntegrationTestHandler,
|
|
7
|
+
JSONHandler,
|
|
8
|
+
LossCurveHandler,
|
|
9
|
+
RichHandler,
|
|
10
|
+
StreamHandler,
|
|
11
|
+
)
|
|
12
|
+
from .streamer import JobStreamer, StreamEndpoints
|
|
13
|
+
from .types import StreamMessage, StreamType
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"BufferedHandler",
|
|
17
|
+
"CallbackHandler",
|
|
18
|
+
"CLIHandler",
|
|
19
|
+
"IntegrationTestHandler",
|
|
20
|
+
"JSONHandler",
|
|
21
|
+
"LossCurveHandler",
|
|
22
|
+
"JobStreamer",
|
|
23
|
+
"RichHandler",
|
|
24
|
+
"StreamEndpoints",
|
|
25
|
+
"StreamConfig",
|
|
26
|
+
"StreamHandler",
|
|
27
|
+
"StreamMessage",
|
|
28
|
+
"StreamType",
|
|
29
|
+
]
|
synth_ai/streaming/config.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .types import StreamType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(slots=True)
|
|
10
|
+
class StreamConfig:
|
|
11
|
+
"""Configuration describing which streams to consume and how to filter them."""
|
|
12
|
+
|
|
13
|
+
enabled_streams: set[StreamType] = field(default_factory=lambda: set(StreamType))
|
|
14
|
+
event_types: set[str] | None = None # Whitelist: only include these event types
|
|
15
|
+
event_types_exclude: set[str] | None = None # Blacklist: exclude these event types
|
|
16
|
+
event_levels: set[str] | None = None
|
|
17
|
+
metric_names: set[str] | None = None
|
|
18
|
+
metric_phases: set[str] | None = None
|
|
19
|
+
timeline_phases: set[str] | None = None
|
|
20
|
+
sample_rate: float = 1.0
|
|
21
|
+
max_events_per_poll: int | None = None
|
|
22
|
+
deduplicate: bool = True
|
|
23
|
+
|
|
24
|
+
@classmethod
|
|
25
|
+
def default(cls) -> StreamConfig:
|
|
26
|
+
"""Return a configuration representing the default (all streams) view."""
|
|
27
|
+
return cls(
|
|
28
|
+
event_types_exclude={
|
|
29
|
+
# Filter out noisy events that just announce what metrics already show
|
|
30
|
+
"sft.progress", # Generic "Training progress" with no data
|
|
31
|
+
"sft.loss", # Generic "Loss update" with no data
|
|
32
|
+
"sft.upstream.status", # Very verbose status echo events
|
|
33
|
+
}
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def minimal(cls) -> StreamConfig:
|
|
38
|
+
"""Return a configuration streaming status updates only."""
|
|
39
|
+
return cls(enabled_streams={StreamType.STATUS})
|
|
40
|
+
|
|
41
|
+
@classmethod
|
|
42
|
+
def verbose(cls) -> StreamConfig:
|
|
43
|
+
"""Return a configuration with all streams and events (no filters)."""
|
|
44
|
+
return cls()
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def progress_only(cls) -> StreamConfig:
|
|
48
|
+
"""Return a configuration tailored to show training progress."""
|
|
49
|
+
return cls(
|
|
50
|
+
enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
|
|
51
|
+
event_types={"sft.progress", "rl.train.step", "sft.validation.summary"},
|
|
52
|
+
metric_names={"train.loss", "eval.reward_mean"},
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
@classmethod
|
|
56
|
+
def errors_only(cls) -> StreamConfig:
|
|
57
|
+
"""Return a configuration that focuses on heightened severity signals."""
|
|
58
|
+
return cls(
|
|
59
|
+
enabled_streams={StreamType.STATUS, StreamType.EVENTS},
|
|
60
|
+
event_levels={"error", "warning"},
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
def should_include_event(self, event: dict[str, Any]) -> bool:
|
|
64
|
+
"""Determine whether an event message should be included."""
|
|
65
|
+
event_type = event.get("type")
|
|
66
|
+
|
|
67
|
+
# Apply blacklist first (takes precedence)
|
|
68
|
+
if self.event_types_exclude and event_type in self.event_types_exclude:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
# Then apply whitelist
|
|
72
|
+
if self.event_types and event_type not in self.event_types:
|
|
73
|
+
return False
|
|
74
|
+
|
|
75
|
+
if self.event_levels:
|
|
76
|
+
return event.get("level") in self.event_levels
|
|
77
|
+
return True
|
|
78
|
+
|
|
79
|
+
def should_include_metric(self, metric: dict[str, Any]) -> bool:
|
|
80
|
+
"""Determine whether a metric point should be included."""
|
|
81
|
+
if self.metric_names and metric.get("name") not in self.metric_names:
|
|
82
|
+
return False
|
|
83
|
+
if self.metric_phases:
|
|
84
|
+
return metric.get("phase") in self.metric_phases
|
|
85
|
+
return True
|
|
86
|
+
|
|
87
|
+
def should_include_timeline(self, timeline_entry: dict[str, Any]) -> bool:
|
|
88
|
+
"""Determine whether a timeline entry should be included."""
|
|
89
|
+
if self.timeline_phases:
|
|
90
|
+
return timeline_entry.get("phase") in self.timeline_phases
|
|
91
|
+
return True
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
__all__ = ["StreamConfig"]
|