synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +6 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -38
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +288 -39
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
- synth_ai/api/train/builders.py +99 -4
- synth_ai/api/train/cli.py +516 -26
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +23 -2
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +61 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/auth/credentials.py +119 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +94 -18
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +30 -158
- synth_ai/cli/deploy/__init__.py +43 -0
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +51 -1480
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -10
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +518 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +45 -9
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +40 -33
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +285 -3
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# Pokémon VL: Vision-Language RL Pipeline
|
|
2
|
+
|
|
3
|
+
This playbook demonstrates end-to-end vision-language reinforcement learning on Pokémon Red using Synth AI's CLI tools. We follow the eval → collect data → SFT → RL → eval pipeline, but with vision models throughout.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
**Model**: Qwen3-VL-4B-Instruct (4B parameter vision-language model via Synth API)
|
|
8
|
+
**Environment**: Pokémon Red (Game Boy emulation with vision support)
|
|
9
|
+
**Benchmark**: Pallet Town progression task (leave bedroom → get starter → win first battle)
|
|
10
|
+
|
|
11
|
+
## Pipeline Steps
|
|
12
|
+
|
|
13
|
+
1. **Deploy Task App** - Host the Pokémon Red environment
|
|
14
|
+
2. **Collect Vision Rollouts** - Generate high-quality demonstrations using Qwen3-VL
|
|
15
|
+
3. **Filter Dataset** - Extract successful trajectories for supervised fine-tuning
|
|
16
|
+
4. **Fine-Tune Qwen3-4B VL** - Train vision-language model on filtered data
|
|
17
|
+
5. **Vision-Language RL** - Bootstrap RL training from SFT checkpoint
|
|
18
|
+
6. **Final Evaluation** - Compare SFT and RL performance
|
|
19
|
+
|
|
20
|
+
## Prerequisites
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
# Install dependencies
|
|
24
|
+
uv pip install -e .
|
|
25
|
+
|
|
26
|
+
# Setup authentication
|
|
27
|
+
uvx synth-ai setup
|
|
28
|
+
|
|
29
|
+
# Copy environment template
|
|
30
|
+
cp examples/blog_posts/pokemon_vl/.env.example .env
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Quick Start
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
# Export trace database path
|
|
37
|
+
export POKEMON_VL_TRACE_DB=traces/v3/pokemon_vl_blog.db
|
|
38
|
+
|
|
39
|
+
# 1. Deploy task app
|
|
40
|
+
uvx synth-ai deploy pokemon_red --runtime modal --name pokemon-vl-blog --env-file .env
|
|
41
|
+
|
|
42
|
+
# 2. Collect vision rollouts with Qwen3-VL
|
|
43
|
+
uvx synth-ai eval pokemon_red --config examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml --trace-db "${POKEMON_VL_TRACE_DB}"
|
|
44
|
+
|
|
45
|
+
# 3. Filter high-reward trajectories
|
|
46
|
+
uvx synth-ai filter --config examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml
|
|
47
|
+
|
|
48
|
+
# 4. Fine-tune Qwen3-4B VL
|
|
49
|
+
uvx synth-ai train --type sft --config examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml --env-file .env --poll
|
|
50
|
+
|
|
51
|
+
# 5. RL from SFT checkpoint (first set model.source in train_rl_from_sft.toml to your SFT job ID)
|
|
52
|
+
uvx synth-ai train --type rl --config examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml --env-file .env --poll
|
|
53
|
+
|
|
54
|
+
# 6. Evaluate final RL model
|
|
55
|
+
uvx synth-ai eval pokemon_red --config examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml --trace-db "${POKEMON_VL_TRACE_DB}"
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Vision Features
|
|
59
|
+
|
|
60
|
+
- **Full Game Boy Frames**: Base64-encoded PNG screenshots (160x144 resolution)
|
|
61
|
+
- **Vision-Only Mode**: Pure image understanding without text state
|
|
62
|
+
- **Vision + Text Mode**: Combined visual and structured state information
|
|
63
|
+
- **Efficient Action Batching**: `execute_sequence` tool for 5-10 actions per inference call
|
|
64
|
+
|
|
65
|
+
## Expected Results
|
|
66
|
+
|
|
67
|
+
| Stage | Model | Mean Reward | Success Rate | Best Achievement |
|
|
68
|
+
|-------|-------|-------------|--------------|------------------|
|
|
69
|
+
| Initial | Qwen3-VL (vision) | ~150 | 60% | Win first battle |
|
|
70
|
+
| SFT | Qwen3-4B VL | ~200 | 75% | Win first battle + explore |
|
|
71
|
+
| RL | Qwen3-4B VL + RL | ~350 | 85% | Complete Pallet Town |
|
|
72
|
+
|
|
73
|
+
## Files
|
|
74
|
+
|
|
75
|
+
- `configs/` - All TOML configuration files
|
|
76
|
+
- `ft_data/` - Filtered datasets for fine-tuning
|
|
77
|
+
- `.env.example` - Environment variables template
|
|
78
|
+
|
|
79
|
+
## Vision Model Configuration
|
|
80
|
+
|
|
81
|
+
The vision models receive:
|
|
82
|
+
- **Input**: Game Boy screenshot + optional structured state (position, HP, party, etc.)
|
|
83
|
+
- **Output**: Sequence of button presses via `execute_sequence` tool
|
|
84
|
+
- **Action Space**: UP, DOWN, LEFT, RIGHT, A, B, START, SELECT with frame counts
|
|
85
|
+
|
|
86
|
+
## Reward Function
|
|
87
|
+
|
|
88
|
+
Dense rewards for Pallet Town progression:
|
|
89
|
+
- Leave bedroom (+20)
|
|
90
|
+
- Exit house (+30)
|
|
91
|
+
- Find Oak's lab (+40)
|
|
92
|
+
- Talk to Oak (+50)
|
|
93
|
+
- Get starter Pokémon (+100)
|
|
94
|
+
- Enter battle (+75)
|
|
95
|
+
- Deal damage (+50 per 10HP)
|
|
96
|
+
- Win battle (+150)
|
|
97
|
+
|
|
98
|
+
Total possible: ~700 points
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
[eval]
|
|
2
|
+
app_id = "pokemon_red"
|
|
3
|
+
task_app_url = "http://127.0.0.1:8914"
|
|
4
|
+
model = "gpt-5-nano"
|
|
5
|
+
seeds = [0] # Single seed for testing
|
|
6
|
+
max_turns = 10 # 10 LLM calls per episode to allow more progress
|
|
7
|
+
concurrency = 1 # Run 1 rollout
|
|
8
|
+
env_name = "pokemon_red"
|
|
9
|
+
policy_name = "pokemon_vl_qwen3_vl" # Reuse policy config, will override model
|
|
10
|
+
trace_format = "full"
|
|
11
|
+
return_trace = true
|
|
12
|
+
|
|
13
|
+
[eval.policy_config]
|
|
14
|
+
provider = "openai" # Use OpenAI API for gpt-5-nano
|
|
15
|
+
model = "gpt-5-nano"
|
|
16
|
+
inference_url = "https://api.openai.com/v1"
|
|
17
|
+
temperature = 0.7
|
|
18
|
+
top_p = 0.95
|
|
19
|
+
max_tokens = 512
|
|
20
|
+
use_vision = true
|
|
21
|
+
image_only_mode = false
|
|
22
|
+
max_llm_calls = 10
|
|
23
|
+
|
|
24
|
+
[eval.env_config.env_params]
|
|
25
|
+
max_steps_per_episode = 100 # Allow time to achieve milestones
|
|
26
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[eval]
|
|
2
|
+
app_id = "pokemon_red"
|
|
3
|
+
task_app_url = "http://127.0.0.1:8914"
|
|
4
|
+
model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
|
|
5
|
+
seeds = [10, 11] # 2 seeds for quick testing
|
|
6
|
+
max_turns = 10 # 10 LLM calls per episode to allow more progress
|
|
7
|
+
concurrency = 2 # Run 2 rollouts in parallel
|
|
8
|
+
env_name = "pokemon_red"
|
|
9
|
+
policy_name = "pokemon_vl_qwen3_vl"
|
|
10
|
+
trace_format = "full"
|
|
11
|
+
return_trace = true
|
|
12
|
+
|
|
13
|
+
[eval.policy_config]
|
|
14
|
+
provider = "synth" # Use Synth internal API for vision models
|
|
15
|
+
model = "Qwen/Qwen3-VL-30B-A3B-Thinking" # Larger thinking variant - needs more time to load
|
|
16
|
+
inference_url = "https://synth-laboratories-dev--learning-v2-service-fastapi-app.modal.run/chat/completions"
|
|
17
|
+
temperature = 1.0 # Higher temperature to encourage exploration
|
|
18
|
+
top_p = 0.95
|
|
19
|
+
max_tokens = 2048 # Reduced to avoid token budget issues
|
|
20
|
+
use_vision = true
|
|
21
|
+
image_only_mode = false
|
|
22
|
+
max_llm_calls = 10
|
|
23
|
+
thinking_mode = "think" # Enable thinking/reasoning mode
|
|
24
|
+
thinking_budget = 3072 # Increased token budget for reasoning
|
|
25
|
+
|
|
26
|
+
[eval.env_config.env_params]
|
|
27
|
+
max_steps_per_episode = 100 # Increased from 3 to allow time to achieve milestones
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[eval]
|
|
2
|
+
app_id = "pokemon_red"
|
|
3
|
+
task_app_url = "http://127.0.0.1:8914"
|
|
4
|
+
model = "fft:REPLACE-WITH-RL-JOB-ID" # Update with final RL job ID
|
|
5
|
+
seeds = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
|
6
|
+
max_turns = 15 # Allow more steps for trained model
|
|
7
|
+
concurrency = 3
|
|
8
|
+
env_name = "pokemon_red"
|
|
9
|
+
policy_name = "pokemon_vl_rl_final"
|
|
10
|
+
trace_format = "full"
|
|
11
|
+
return_trace = true
|
|
12
|
+
|
|
13
|
+
[eval.policy_config]
|
|
14
|
+
provider = "synth"
|
|
15
|
+
model = "fft:REPLACE-WITH-RL-JOB-ID" # Update with final RL job ID
|
|
16
|
+
temperature = 0.1 # Lower temperature for evaluation
|
|
17
|
+
top_p = 0.9
|
|
18
|
+
max_tokens = 4096
|
|
19
|
+
use_vision = true
|
|
20
|
+
image_only_mode = false
|
|
21
|
+
max_llm_calls = 15
|
|
22
|
+
|
|
23
|
+
[eval.env_config.env_params]
|
|
24
|
+
max_steps_per_episode = 15
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Filter high-quality vision-language rollouts for SFT training
|
|
2
|
+
# Assumes traces stored in pokemon_vl_blog.db via eval commands
|
|
3
|
+
|
|
4
|
+
[filter]
|
|
5
|
+
db = "traces/v3/pokemon_vl_blog.db"
|
|
6
|
+
output = "examples/blog_posts/pokemon_vl/ft_data/pokemon_vl_high_reward.jsonl"
|
|
7
|
+
min_official_score = 0.3 # Require at least 30% completion (Pallet Town progression)
|
|
8
|
+
models = ["Qwen/Qwen3-VL-4B-Instruct"] # Must match the model that produced the rollouts (eval_qwen3_vl.toml uses "Qwen/Qwen3-VL-30B-A3B-Thinking"); update if they differ
|
|
9
|
+
shuffle = true
|
|
10
|
+
shuffle_seed = 42
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Vision-Language RL: Continue training Qwen3-4B VL from SFT checkpoint
|
|
2
|
+
# Update task_url with deployed Modal task app URL
|
|
3
|
+
# Set model.source to the SFT job id from `uvx synth-ai train --type sft`
|
|
4
|
+
|
|
5
|
+
type = "rl"
|
|
6
|
+
|
|
7
|
+
[services]
|
|
8
|
+
task_url = "http://127.0.0.1:8914"
|
|
9
|
+
|
|
10
|
+
[compute]
|
|
11
|
+
gpu_type = "H100"
|
|
12
|
+
gpu_count = 8
|
|
13
|
+
|
|
14
|
+
[topology]
|
|
15
|
+
gpus_for_vllm = 4
|
|
16
|
+
gpus_for_training = 3
|
|
17
|
+
gpus_for_ref = 1
|
|
18
|
+
|
|
19
|
+
[vllm]
|
|
20
|
+
tensor_parallel_size = 4
|
|
21
|
+
|
|
22
|
+
[model]
|
|
23
|
+
source = "fft:REPLACE-WITH-SFT-JOB-ID" # Update with actual SFT job ID
|
|
24
|
+
label = "pokemon_vl_rl_blog"
|
|
25
|
+
supports_vision = true
|
|
26
|
+
|
|
27
|
+
[rollout]
|
|
28
|
+
max_turns = 10
|
|
29
|
+
episodes_per_batch = 64
|
|
30
|
+
task_app_origin_rewards_only = true
|
|
31
|
+
|
|
32
|
+
[evaluation]
|
|
33
|
+
instances = 100
|
|
34
|
+
every_n_iters = 20
|
|
35
|
+
seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
|
|
36
|
+
|
|
37
|
+
[training]
|
|
38
|
+
log_interval = 1
|
|
39
|
+
|
|
40
|
+
[training.weight_sync]
|
|
41
|
+
enable = true
|
|
42
|
+
targets = ["policy"]
|
|
43
|
+
weight_sync_interval = 1
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Vision-Language Supervised Fine-Tuning: Qwen3-4B VL on filtered Pokémon rollouts
|
|
2
|
+
# Update the `data` path once `uvx synth-ai filter` produces your JSONL
|
|
3
|
+
|
|
4
|
+
[algorithm]
|
|
5
|
+
type = "offline"
|
|
6
|
+
method = "sft"
|
|
7
|
+
variety = "fft"
|
|
8
|
+
|
|
9
|
+
[job]
|
|
10
|
+
model = "Qwen/Qwen3-VL-4B-Instruct" # Vision-enabled Qwen3-VL model
|
|
11
|
+
data = "../ft_data/pokemon_vl_high_reward.jsonl"
|
|
12
|
+
poll_seconds = 1800
|
|
13
|
+
|
|
14
|
+
[compute]
|
|
15
|
+
gpu_type = "H100"
|
|
16
|
+
gpu_count = 4
|
|
17
|
+
nodes = 1
|
|
18
|
+
|
|
19
|
+
[data.topology]
|
|
20
|
+
container_count = 4
|
|
21
|
+
|
|
22
|
+
[training]
|
|
23
|
+
mode = "full_finetune"
|
|
24
|
+
use_qlora = false
|
|
25
|
+
|
|
26
|
+
[hyperparameters]
|
|
27
|
+
n_epochs = 2
|
|
28
|
+
world_size = 4
|
|
29
|
+
sequence_length = 4096 # Longer for vision tokens + text
|
|
30
|
+
per_device_batch = 2
|
|
31
|
+
gradient_accumulation_steps = 64
|
|
32
|
+
learning_rate = 8e-6
|
|
33
|
+
warmup_ratio = 0.03
|
|
34
|
+
|
|
35
|
+
[hyperparameters.parallelism]
|
|
36
|
+
use_deepspeed = true
|
|
37
|
+
deepspeed_stage = 3
|
|
38
|
+
fsdp = false
|
|
39
|
+
bf16 = true
|
|
40
|
+
fp16 = false
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Extract images from pokemon_vl trace database or trace JSON file and save to images_gpt5 directory.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
# From trace database:
|
|
6
|
+
python extract_images.py --trace-db traces/v3/pokemon_vl_gpt5nano.db
|
|
7
|
+
|
|
8
|
+
# From trace JSON file:
|
|
9
|
+
python extract_images.py --trace-json trace.json
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import base64
|
|
14
|
+
import json
|
|
15
|
+
import sqlite3
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from synth_ai.tracing_v3.trace_utils import load_session_trace
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_image_urls_from_content(content: Any) -> list[str]:
    """Collect inline ``data:image`` URLs from a message ``content`` value.

    Args:
        content: Either a list of content parts (dicts carrying a ``"type"``
            key of ``"image_url"`` or ``"image"``), or a string holding the
            JSON encoding of such a list. Any other value yields no URLs.

    Returns:
        The ``data:image...`` strings found, in order of appearance. Image
        references that are not inline data URLs (e.g. ``https://``) are
        ignored.
    """
    if isinstance(content, str):
        # The content may be a JSON-serialised part list; text that does not
        # decode as JSON simply carries no images.
        try:
            parsed = json.loads(content)
        except (ValueError, RecursionError):
            return []
        return extract_image_urls_from_content(parsed)

    if not isinstance(content, list):
        return []

    urls: list[str] = []
    for part in content:
        if not isinstance(part, dict):
            continue

        kind = part.get("type")
        if kind == "image_url":
            payload = part.get("image_url")
            # The payload is normally {"url": ...}; tolerate a bare string (or
            # any other value) instead of raising AttributeError.
            candidate = payload.get("url") if isinstance(payload, dict) else payload
        elif kind == "image":
            candidate = part.get("image")
        else:
            continue

        if isinstance(candidate, str) and candidate.startswith("data:image"):
            urls.append(candidate)

    return urls
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _pick_state_fields(source: dict[str, Any], current: dict[str, Any]) -> dict[str, Any]:
    """Return the tracked state fields from *source*, falling back to *current*, then placeholders."""
    placeholders = {
        "position": "?",
        "map_id": "?",
        "player_x": "?",
        "player_y": "?",
        "text_box_active": False,
    }
    return {
        key: source.get(key, current.get(key, default))
        for key, default in placeholders.items()
    }


def _parse_state_summary(content: str) -> "dict[str, Any] | None":
    """Parse the dict literal that follows ``State summary:`` in *content*, or return None."""
    import ast

    _, marker, tail = content.partition("State summary:")
    if not marker:
        return None

    start = tail.find("{")
    if start == -1:
        return None

    # Grow the candidate up to each successive "}" until it evaluates as a
    # literal. This copes with nested dicts and with braces inside quoted
    # strings without needing a hand-written tokenizer.
    end = tail.find("}", start)
    while end != -1:
        try:
            parsed = ast.literal_eval(tail[start:end + 1])
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            end = tail.find("}", end + 1)
            continue
        return parsed if isinstance(parsed, dict) else None
    return None


def extract_state_info_from_message(message: dict[str, Any]) -> dict[str, Any]:
    """Extract state info from message metadata and/or content.

    Two sources are consulted, in this order:

    1. ``message["metadata"]["system_state_before"]["obs"]``.
    2. A Python dict literal following the text ``State summary:`` inside a
       string ``message["content"]`` (only attempted when the content
       mentions ``position``).

    Values from the content summary take precedence over metadata values, but
    a key missing from the summary no longer overwrites a metadata value with
    the ``"?"`` placeholder.

    Args:
        message: A trace message dictionary.

    Returns:
        An empty dict when neither source is available; otherwise a dict with
        the keys ``position``, ``map_id``, ``player_x``, ``player_y`` (each
        defaulting to ``"?"``) and ``text_box_active`` (defaulting to
        ``False``).
    """
    state: dict[str, Any] = {}

    # Metadata may be absent, None, or not a dict; treat all of those as "no state".
    metadata = message.get("metadata")
    if isinstance(metadata, dict):
        state_before = metadata.get("system_state_before")
        if isinstance(state_before, dict):
            obs = state_before.get("obs")
            state = _pick_state_fields(obs if isinstance(obs, dict) else {}, state)

    content = message.get("content", "")
    if isinstance(content, str) and "position" in content:
        summary = _parse_state_summary(content)
        if summary is not None:
            state = _pick_state_fields(summary, state)

    return state
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def extract_images_from_trace_dict(trace: dict[str, Any], output_dir: Path):
    """Extract images from a trace dictionary.

    Messages are read from ``trace["markov_blanket_message_history"]``,
    falling back to ``trace["messages"]`` when that is missing or empty.
    Every ``data:image...`` URL found in a message's content is base64-decoded
    and written to *output_dir* as
    ``step_<NNN>_pos_<map_id>_<x>,<y>_textbox_<True|False>.png``, where
    ``NNN`` counts the images saved so far in this call.

    Args:
        trace: Trace mapping holding the message list.
        output_dir: Directory to write into; created if it does not exist.

    Returns:
        Number of image files written.

    NOTE(review): files are always named ``.png`` regardless of the MIME type
    declared in the data URL.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get messages from trace
    messages = trace.get("markov_blanket_message_history", []) or trace.get("messages", [])
    if not messages:
        print(" No messages found in trace")
        return 0

    print(f" Found {len(messages)} messages")

    image_count = 0
    for msg in messages:
        # Malformed (non-mapping) entries carry no content to inspect.
        if not isinstance(msg, dict):
            continue

        image_urls = extract_image_urls_from_content(msg.get("content", ""))
        if not image_urls:
            continue

        # State info is per message, so build the filename parts once.
        state = extract_state_info_from_message(msg)
        pos_str = f"{state.get('map_id', '?')}_{state.get('player_x', '?')},{state.get('player_y', '?')}"
        textbox_str = "True" if state.get("text_box_active") else "False"

        for img_url in image_urls:
            if not img_url.startswith("data:image"):
                continue

            # Format: data:image/png;base64,<data>
            _, comma, b64_data = img_url.partition(",")
            if not comma:
                continue

            try:
                img_data = base64.b64decode(b64_data)
            except (ValueError, TypeError) as e:
                # binascii.Error (bad padding etc.) is a ValueError subclass.
                print(f" Error decoding image: {e}")
                continue

            filename = f"step_{image_count:03d}_pos_{pos_str}_textbox_{textbox_str}.png"
            try:
                (output_dir / filename).write_bytes(img_data)
            except OSError as e:
                # Keep going: one unwritable file should not abort the export.
                print(f" Error writing image {filename}: {e}")
                continue

            print(f" Saved: {filename}")
            image_count += 1

    return image_count
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def extract_images_from_trace_db(trace_db: str, output_dir: Path, model_filter: str | None = None):
|
|
151
|
+
"""Extract images from trace database and save to output directory."""
|
|
152
|
+
conn = sqlite3.connect(trace_db)
|
|
153
|
+
conn.row_factory = sqlite3.Row
|
|
154
|
+
|
|
155
|
+
# Get all session IDs
|
|
156
|
+
query = "SELECT session_id, metadata FROM session_traces"
|
|
157
|
+
if model_filter:
|
|
158
|
+
query += " WHERE metadata LIKE ?"
|
|
159
|
+
params = (f'%{model_filter}%',)
|
|
160
|
+
else:
|
|
161
|
+
params = ()
|
|
162
|
+
|
|
163
|
+
rows = conn.execute(query, params).fetchall()
|
|
164
|
+
|
|
165
|
+
if not rows:
|
|
166
|
+
print(f"No traces found in {trace_db}")
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
print(f"Found {len(rows)} trace(s)")
|
|
170
|
+
|
|
171
|
+
total_images = 0
|
|
172
|
+
for row in rows:
|
|
173
|
+
session_id = row["session_id"]
|
|
174
|
+
print(f"\nProcessing session: {session_id}")
|
|
175
|
+
|
|
176
|
+
try:
|
|
177
|
+
trace = load_session_trace(conn, session_id)
|
|
178
|
+
except Exception as e:
|
|
179
|
+
print(f" Error loading trace: {e}")
|
|
180
|
+
continue
|
|
181
|
+
|
|
182
|
+
count = extract_images_from_trace_dict(trace, output_dir)
|
|
183
|
+
total_images += count
|
|
184
|
+
|
|
185
|
+
conn.close()
|
|
186
|
+
print(f"\n✓ Extracted {total_images} images to {output_dir}/")
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def extract_images_from_trace_json(trace_json: Path, output_dir: Path):
    """Extract images from trace JSON file.

    Loads the JSON document, unwraps a top-level ``"session_trace"`` key when
    present, and passes the trace to ``extract_images_from_trace_dict``.

    Args:
        trace_json: Path to the JSON trace file.
        output_dir: Directory that receives the extracted images.

    Returns:
        None. The number of extracted images is printed.
    """
    print(f"Loading trace from {trace_json}")

    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding.
    with open(trace_json, encoding="utf-8") as f:
        trace = json.load(f)

    # Handle trace wrapped in "session_trace" key
    if "session_trace" in trace:
        trace = trace["session_trace"]

    count = extract_images_from_trace_dict(trace, output_dir)
    print(f"\n✓ Extracted {count} images to {output_dir}/")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def main():
    """Parse the command line and run the matching image extraction.

    ``--trace-json`` takes precedence over ``--trace-db``; when neither is
    supplied the parser reports a usage error and exits.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Options are declared as data and registered in order, so the --help
    # listing keeps the same sequence.
    option_specs = (
        ("--trace-db", {"help": "Path to trace database"}),
        ("--trace-json", {"type": Path, "help": "Path to trace JSON file"}),
        (
            "--output-dir",
            {
                "default": "examples/blog_posts/pokemon_vl/images_gpt5",
                "help": "Output directory for images",
            },
        ),
        ("--model-filter", {"help": "Filter traces by model name (optional)"}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()
    destination = Path(opts.output_dir)

    if opts.trace_json:
        extract_images_from_trace_json(opts.trace_json, destination)
        return

    if opts.trace_db:
        extract_images_from_trace_db(opts.trace_db, destination, opts.model_filter)
        return

    # error() prints usage and exits, so nothing runs after it.
    cli.error("Must provide either --trace-db or --trace-json")
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
238
|
+
|
|
239
|
+
|