synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/RECORD +269 -233
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -1,42 +1,123 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
"""Task App configuration for the GRPO Crafter example."""
|
|
4
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
5
6
|
import os
|
|
6
7
|
import sys
|
|
8
|
+
from collections.abc import Iterable, Sequence
|
|
7
9
|
from dataclasses import dataclass
|
|
8
10
|
from pathlib import Path
|
|
9
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
10
12
|
|
|
11
|
-
from synth_ai.task.
|
|
13
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
14
|
+
from synth_ai.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
|
|
12
15
|
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
16
|
+
from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
13
17
|
from synth_ai.task.rubrics import load_rubric
|
|
14
18
|
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
15
|
-
from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
16
|
-
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
17
19
|
from synth_ai.task.tracing_utils import (
|
|
18
20
|
build_tracer_factory,
|
|
19
21
|
resolve_sft_output_dir,
|
|
20
22
|
resolve_tracing_db_url,
|
|
21
23
|
tracing_env_enabled,
|
|
22
24
|
)
|
|
23
|
-
|
|
24
25
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
25
26
|
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
DEFAULT_ALIAS_OPS: list[str] = ["agent", "env"] * 10
|
|
30
|
+
DEFAULT_ALIAS_STEP_REWARDS: dict[str, Any] = {
|
|
31
|
+
"enabled": True,
|
|
32
|
+
"mode": "decision_stepwise",
|
|
33
|
+
"indicator_lambda": 1.0,
|
|
34
|
+
"step_beta": 0.0,
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
_HERE = Path(__file__).resolve()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _resolve_repo_root() -> Path:
    """Return the most plausible Synth AI repo root (best effort).

    Locations are probed in this order:

    1. the ``SYNTH_AI_REPO_ROOT`` environment variable, when set;
    2. ``/opt/synth_ai_repo``;
    3. this module's directory, followed by each of its ancestors.

    The first location that exists and contains ``pyproject.toml``,
    ``uv.lock``, or a ``synth_ai`` directory is returned in resolved form.
    When nothing matches, the fourth ancestor of this file is returned, or
    the file's own directory if the path is too shallow for that.
    """

    search_order: list[Path] = []
    override = os.getenv("SYNTH_AI_REPO_ROOT")
    if override:
        search_order.append(Path(override).expanduser())
    search_order.append(Path("/opt/synth_ai_repo"))
    search_order.append(_HERE.parent)
    search_order.extend(_HERE.parents)

    for entry in search_order:
        try:
            root = entry.resolve()
        except Exception:
            # Unresolvable locations are simply not candidates.
            continue
        if not root.exists():
            continue
        has_project_file = (root / "pyproject.toml").exists() or (root / "uv.lock").exists()
        if has_project_file or (root / "synth_ai").is_dir():
            return root

    # No marker found anywhere: fall back to a fixed ancestor of this file.
    ancestors = _HERE.parents
    return ancestors[3] if len(ancestors) > 3 else _HERE.parent
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _resolve_task_app_root(repo_root: Path) -> Path:
    """Find the directory that holds the task app sources.

    Checked in order:

    1. ``<repo_root>/examples/warming_up_to_rl/task_app``;
    2. this module's own directory, if it contains ``synth_envs_hosted``;
    3. each ancestor of this file that contains ``synth_envs_hosted``;
    4. ``/opt/synth_ai_repo/examples/warming_up_to_rl/task_app``.

    If none of these exist, this module's resolved directory is returned.
    """

    hosted_dirname = "synth_envs_hosted"

    in_repo = (repo_root / "examples" / "warming_up_to_rl" / "task_app").resolve()
    if in_repo.is_dir():
        return in_repo

    module_dir = _HERE.parent.resolve()
    if (module_dir / hosted_dirname).is_dir():
        return module_dir

    for ancestor in _HERE.parents:
        ancestor_dir = ancestor.resolve()
        if (ancestor_dir / hosted_dirname).is_dir():
            return ancestor_dir

    mounted = Path("/opt/synth_ai_repo/examples/warming_up_to_rl/task_app")
    return mounted.resolve() if mounted.is_dir() else module_dir
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Resolve the directories this module depends on, once, at import time.
REPO_ROOT = _resolve_repo_root()
TASK_APP_ROOT = _resolve_task_app_root(REPO_ROOT)
SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()

EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()

# Put every existing root at the front of sys.path. Each insert goes to
# index 0, so roots later in this tuple end up ahead of earlier ones.
for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
    try:
        resolved = path.resolve()
    except Exception:
        resolved = path
    if not resolved.exists():
        continue
    path_str = str(resolved)
    if path_str in sys.path:
        continue
    sys.path.insert(0, path_str)

# Fallback: explicitly add Modal mount path for 'examples' if REPO_ROOT detection fails
try:
    _hard_examples = Path("/opt/synth_ai_repo/examples")
    if _hard_examples.exists():
        _hard_examples_str = str(_hard_examples.resolve())
        if _hard_examples_str not in sys.path:
            sys.path.insert(0, _hard_examples_str)
except Exception:
    # Best effort only: a failure here must not break module import.
    pass
|
|
35
116
|
|
|
36
117
|
HAS_HOSTED = True
|
|
37
118
|
try:
|
|
38
119
|
import crafter # type: ignore
|
|
39
|
-
import crafter.constants as
|
|
120
|
+
import crafter.constants as crafter_constants # type: ignore
|
|
40
121
|
from synth_ai.environments.examples.crafter_classic.taskset import TRAIT_BOUNDS
|
|
41
122
|
from synth_envs_hosted.branching import router as branching_router # type: ignore
|
|
42
123
|
from synth_envs_hosted.environment_routes import router as environment_router # type: ignore
|
|
@@ -44,11 +125,23 @@ try:
|
|
|
44
125
|
from synth_envs_hosted.policy_routes import router as policy_router # type: ignore
|
|
45
126
|
from synth_envs_hosted.rollout import ( # type: ignore
|
|
46
127
|
RolloutEnvSpec as LegacyRolloutEnvSpec,
|
|
128
|
+
)
|
|
129
|
+
from synth_envs_hosted.rollout import (
|
|
47
130
|
RolloutPolicySpec as LegacyRolloutPolicySpec,
|
|
131
|
+
)
|
|
132
|
+
from synth_envs_hosted.rollout import (
|
|
48
133
|
RolloutRecordConfig as LegacyRolloutRecordConfig,
|
|
134
|
+
)
|
|
135
|
+
from synth_envs_hosted.rollout import (
|
|
49
136
|
RolloutRequest as LegacyRolloutRequest,
|
|
137
|
+
)
|
|
138
|
+
from synth_envs_hosted.rollout import (
|
|
50
139
|
RolloutResponse as LegacyRolloutResponse,
|
|
140
|
+
)
|
|
141
|
+
from synth_envs_hosted.rollout import (
|
|
51
142
|
RolloutSafetyConfig as LegacyRolloutSafetyConfig,
|
|
143
|
+
)
|
|
144
|
+
from synth_envs_hosted.rollout import (
|
|
52
145
|
execute_rollout as legacy_execute_rollout,
|
|
53
146
|
)
|
|
54
147
|
except Exception as exc: # pragma: no cover - import-time validation
|
|
@@ -121,16 +214,16 @@ class CrafterDataset:
|
|
|
121
214
|
area_env = env_value("CRAFTER_AREA", "64,64")
|
|
122
215
|
self.area = tuple(int(x) for x in str(area_env).split(","))
|
|
123
216
|
self.length = int(env_value("CRAFTER_EPISODE_LENGTH", 10000))
|
|
124
|
-
self._cache:
|
|
217
|
+
self._cache: dict[int, dict[str, Any]] = {}
|
|
125
218
|
|
|
126
|
-
def config_for_seed(self, seed: int) ->
|
|
219
|
+
def config_for_seed(self, seed: int) -> dict[str, Any]:
|
|
127
220
|
return {
|
|
128
221
|
"seed": int(seed),
|
|
129
222
|
"area": list(self.area),
|
|
130
223
|
"length": self.length,
|
|
131
224
|
}
|
|
132
225
|
|
|
133
|
-
def describe_seed(self, seed: int) ->
|
|
226
|
+
def describe_seed(self, seed: int) -> dict[str, Any]:
|
|
134
227
|
seed = int(seed)
|
|
135
228
|
if seed in self._cache:
|
|
136
229
|
return self._cache[seed]
|
|
@@ -156,7 +249,7 @@ class CrafterDataset:
|
|
|
156
249
|
self._cache[seed] = summary
|
|
157
250
|
return summary
|
|
158
251
|
|
|
159
|
-
def _difficulty(self, traits:
|
|
252
|
+
def _difficulty(self, traits: dict[str, int]) -> str:
|
|
160
253
|
for difficulty, bounds in TRAIT_BOUNDS.items():
|
|
161
254
|
if traits.get("trees", 0) >= bounds.get("min_trees", 0) and traits.get(
|
|
162
255
|
"hostiles", 0
|
|
@@ -165,14 +258,14 @@ class CrafterDataset:
|
|
|
165
258
|
return "custom"
|
|
166
259
|
|
|
167
260
|
@property
|
|
168
|
-
def seed_range(self) ->
|
|
261
|
+
def seed_range(self) -> list[int]:
|
|
169
262
|
return [self.seed_min, self.seed_max]
|
|
170
263
|
|
|
171
264
|
|
|
172
|
-
def _compute_world_traits(env:
|
|
265
|
+
def _compute_world_traits(env: crafter.Env, radius: int = 10) -> dict[str, int]:
|
|
173
266
|
# Local copy to avoid import-time issues; mirrors synth_ai.environments.examples.crafter_classic.taskset.world_traits
|
|
174
|
-
from crafter import objects as _objects # type: ignore
|
|
175
267
|
import numpy as _np # type: ignore
|
|
268
|
+
from crafter import objects as _objects # type: ignore
|
|
176
269
|
|
|
177
270
|
player = getattr(env, "_player", None)
|
|
178
271
|
if player is None:
|
|
@@ -185,7 +278,7 @@ def _compute_world_traits(env: "crafter.Env", radius: int = 10) -> Dict[str, int
|
|
|
185
278
|
if obj is None or obj is player:
|
|
186
279
|
continue
|
|
187
280
|
try:
|
|
188
|
-
if _np.abs(
|
|
281
|
+
if _np.abs(obj.pos - pos).sum() > radius:
|
|
189
282
|
continue
|
|
190
283
|
except Exception:
|
|
191
284
|
continue
|
|
@@ -193,14 +286,12 @@ def _compute_world_traits(env: "crafter.Env", radius: int = 10) -> Dict[str, int
|
|
|
193
286
|
counts["trees"] += 1
|
|
194
287
|
elif isinstance(obj, _objects.Cow):
|
|
195
288
|
counts["cows"] += 1
|
|
196
|
-
elif isinstance(obj,
|
|
289
|
+
elif isinstance(obj, _objects.Zombie | _objects.Skeleton):
|
|
197
290
|
counts["hostiles"] += 1
|
|
198
291
|
return counts
|
|
199
292
|
|
|
200
293
|
|
|
201
294
|
def env_value(key: str, default: Any) -> Any:
|
|
202
|
-
import os
|
|
203
|
-
|
|
204
295
|
return os.getenv(key, default)
|
|
205
296
|
|
|
206
297
|
|
|
@@ -217,8 +308,8 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
|
|
|
217
308
|
environments=["crafter"],
|
|
218
309
|
action_space={
|
|
219
310
|
"type": "discrete",
|
|
220
|
-
"size": len(
|
|
221
|
-
"actions": list(
|
|
311
|
+
"size": len(crafter_constants.actions),
|
|
312
|
+
"actions": list(crafter_constants.actions),
|
|
222
313
|
},
|
|
223
314
|
observation={
|
|
224
315
|
"summary": "RGB frame plus inventory, achievements, and semantic map patches.",
|
|
@@ -289,7 +380,7 @@ EVENTS_RUBRIC = load_rubric(
|
|
|
289
380
|
)
|
|
290
381
|
|
|
291
382
|
|
|
292
|
-
def describe_taskset(dataset: CrafterDataset) ->
|
|
383
|
+
def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
|
|
293
384
|
return {
|
|
294
385
|
**DATASET_SPEC.model_dump(),
|
|
295
386
|
"seed_range": dataset.seed_range,
|
|
@@ -351,6 +442,82 @@ def _normalise_op(op_value: Any, index: int) -> str:
|
|
|
351
442
|
raise ValueError(f"Unsupported op type '{candidate}' at index {index}")
|
|
352
443
|
|
|
353
444
|
|
|
445
|
+
def _coerce_math_to_crafter(request: RolloutRequest) -> RolloutRequest:
    """Rewrite a rollout request that uses legacy "math" names to target crafter.

    An env or policy name/id is treated as legacy when, after trimming and
    lower-casing, it starts with ``"math"``. For such a request:

    * ``env_name`` becomes ``"crafter"`` and ``policy_name`` becomes
      ``"crafter-react"``; legacy ``env_id`` / ``policy_id`` are cleared;
    * missing env config keys (``difficulty``, ``step_rewards``,
      ``env_params``) and policy config keys (``max_llm_calls``,
      ``max_completion_tokens``, ``temperature``, ``step_rewards``) are
      filled with defaults, without overriding values already present;
    * ``ops`` is replaced by a copy of ``DEFAULT_ALIAS_OPS`` when it is empty
      or shorter than that default.

    A request with no legacy names is returned unchanged (same object).
    Otherwise a copy is returned and the remap is reported on stdout and via
    the module logger.
    """

    def _needs_crafter(name: str | None) -> bool:
        # Empty or missing names never count as legacy aliases.
        return bool(name) and str(name).strip().lower().startswith("math")

    env_patch: dict[str, Any] = {}
    policy_patch: dict[str, Any] = {}

    if _needs_crafter(request.env.env_name):
        env_patch["env_name"] = "crafter"
    if _needs_crafter(request.env.env_id):
        env_patch["env_id"] = None
    if _needs_crafter(request.policy.policy_name):
        policy_patch["policy_name"] = "crafter-react"
    if _needs_crafter(request.policy.policy_id):
        policy_patch["policy_id"] = None

    # Nothing referred to the legacy names: hand back the request untouched.
    if not env_patch and not policy_patch:
        return request

    new_env = request.env.model_copy(update=env_patch) if env_patch else request.env
    new_policy = (
        request.policy.model_copy(update=policy_patch) if policy_patch else request.policy
    )

    # Fill in env defaults; keys already supplied by the caller are kept.
    env_config = dict(new_env.config or {})
    env_defaults: dict[str, Any] = {
        "difficulty": "normal",
        "step_rewards": dict(DEFAULT_ALIAS_STEP_REWARDS),
        "env_params": {"max_steps_per_episode": 200},
    }
    for key, fallback in env_defaults.items():
        env_config.setdefault(key, fallback)
    new_env = new_env.model_copy(update={"config": env_config})

    # Same treatment for the policy config.
    policy_config = dict(new_policy.config or {})
    policy_defaults: dict[str, Any] = {
        "max_llm_calls": 10,
        "max_completion_tokens": 1024,
        "temperature": 0.2,
        "step_rewards": dict(DEFAULT_ALIAS_STEP_REWARDS),
    }
    for key, fallback in policy_defaults.items():
        policy_config.setdefault(key, fallback)
    new_policy = new_policy.model_copy(update={"config": policy_config})

    # An absent or too-short op list is replaced by the default schedule.
    ops = request.ops
    if not (ops and len(ops) >= len(DEFAULT_ALIAS_OPS)):
        ops = list(DEFAULT_ALIAS_OPS)

    coerced = request.model_copy(update={"env": new_env, "policy": new_policy, "ops": ops})

    # Reporting is best effort: a failure to print or log must not abort the
    # rollout, and the two channels are guarded independently.
    try:
        print(
            "[rollout] remapped math request -> crafter "
            f"(env={request.env.env_name!r}→{coerced.env.env_name!r}, "
            f"policy={request.policy.policy_name!r}→{coerced.policy.policy_name!r})",
            flush=True,
        )
    except Exception:
        pass
    try:
        logger.info(
            "ROLLOUT_ALIAS: remapped math env/policy to crafter (env=%s→%s, policy=%s→%s)",
            request.env.env_name,
            coerced.env.env_name,
            request.policy.policy_name,
            coerced.policy.policy_name,
        )
    except Exception:
        pass

    return coerced
|
|
519
|
+
|
|
520
|
+
|
|
354
521
|
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
|
|
355
522
|
# If hosted env service code is not bundled, return a no-op rollout response compatible with contracts
|
|
356
523
|
if not HAS_HOSTED:
|
|
@@ -370,19 +537,49 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
370
537
|
trace=None,
|
|
371
538
|
)
|
|
372
539
|
|
|
373
|
-
|
|
540
|
+
request = _coerce_math_to_crafter(request)
|
|
541
|
+
|
|
542
|
+
policy_cfg = dict(request.policy.config or {})
|
|
543
|
+
try:
|
|
544
|
+
max_llm_calls = int(policy_cfg.get("max_llm_calls") or 10)
|
|
545
|
+
except Exception:
|
|
546
|
+
max_llm_calls = 10
|
|
547
|
+
policy_cfg.setdefault("max_llm_calls", max_llm_calls)
|
|
548
|
+
policy_cfg.setdefault("max_tokens", 512)
|
|
549
|
+
policy_cfg.setdefault("max_completion_tokens", 512)
|
|
550
|
+
policy_cfg.setdefault("temperature", 0.2)
|
|
551
|
+
policy_cfg.setdefault("top_p", 0.95)
|
|
552
|
+
|
|
553
|
+
env_cfg = dict(request.env.config or {})
|
|
554
|
+
env_params = dict(env_cfg.get("env_params") or {})
|
|
555
|
+
try:
|
|
556
|
+
max_steps_episode = int(env_params.get("max_steps_per_episode") or max_llm_calls)
|
|
557
|
+
except Exception:
|
|
558
|
+
max_steps_episode = max_llm_calls
|
|
559
|
+
desired_steps = max(max_llm_calls, max_steps_episode)
|
|
560
|
+
env_params["max_steps_per_episode"] = int(desired_steps)
|
|
561
|
+
env_cfg["env_params"] = env_params
|
|
562
|
+
|
|
563
|
+
updated_policy = request.policy.model_copy(update={"config": policy_cfg})
|
|
564
|
+
updated_env = request.env.model_copy(update={"config": env_cfg})
|
|
565
|
+
request = request.model_copy(update={"policy": updated_policy, "env": updated_env})
|
|
566
|
+
|
|
567
|
+
converted_ops: list[str] = [_normalise_op(op, idx) for idx, op in enumerate(request.ops)]
|
|
568
|
+
max_ops_allowed = max_llm_calls * 2 if max_llm_calls > 0 else len(converted_ops)
|
|
569
|
+
if max_ops_allowed and len(converted_ops) > max_ops_allowed:
|
|
570
|
+
converted_ops = converted_ops[:max_ops_allowed]
|
|
374
571
|
legacy_request = LegacyRolloutRequest(
|
|
375
572
|
run_id=request.run_id,
|
|
376
573
|
env=LegacyRolloutEnvSpec(
|
|
377
574
|
env_id=request.env.env_id,
|
|
378
575
|
env_name=request.env.env_name,
|
|
379
|
-
config=
|
|
576
|
+
config=env_cfg,
|
|
380
577
|
seed=request.env.seed,
|
|
381
578
|
),
|
|
382
579
|
policy=LegacyRolloutPolicySpec(
|
|
383
580
|
policy_id=request.policy.policy_id,
|
|
384
581
|
policy_name=request.policy.policy_name,
|
|
385
|
-
config=
|
|
582
|
+
config=policy_cfg,
|
|
386
583
|
),
|
|
387
584
|
ops=converted_ops,
|
|
388
585
|
record=LegacyRolloutRecordConfig(**request.record.model_dump()),
|
|
@@ -418,7 +615,7 @@ def build_config() -> TaskAppConfig:
|
|
|
418
615
|
)
|
|
419
616
|
sft_output_dir = resolve_sft_output_dir()
|
|
420
617
|
|
|
421
|
-
app_state:
|
|
618
|
+
app_state: dict[str, Any] = {
|
|
422
619
|
"task_app": hosted_task_app,
|
|
423
620
|
"allowed_environments": ["crafter"],
|
|
424
621
|
"tracing_enabled": tracing_enabled,
|
|
@@ -436,7 +633,7 @@ def build_config() -> TaskAppConfig:
|
|
|
436
633
|
if sft_output_dir:
|
|
437
634
|
print(f"[task:sft] writing JSONL to {sft_output_dir}", flush=True)
|
|
438
635
|
|
|
439
|
-
def _describe_taskset() ->
|
|
636
|
+
def _describe_taskset() -> dict[str, Any]:
|
|
440
637
|
return describe_taskset(dataset)
|
|
441
638
|
|
|
442
639
|
def _provide_instances(seeds: Sequence[int]):
|
|
@@ -489,10 +686,12 @@ register_task_app(
|
|
|
489
686
|
"crafter",
|
|
490
687
|
),
|
|
491
688
|
extra_local_dirs=(
|
|
689
|
+
# Mount repo root so local modules resolve when deployed on Modal
|
|
690
|
+
(str(REPO_ROOT), "/opt/synth_ai_repo"),
|
|
492
691
|
(str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
|
|
493
692
|
(str(TASK_APP_ROOT), "/opt/synth_ai_repo/examples/warming_up_to_rl/task_app"),
|
|
494
693
|
),
|
|
495
|
-
secret_names=("
|
|
694
|
+
secret_names=("groq-api-key", "openai-api-key"),
|
|
496
695
|
memory=16384,
|
|
497
696
|
cpu=4.0,
|
|
498
697
|
max_containers=10,
|
|
@@ -14,12 +14,11 @@ from pathlib import Path
|
|
|
14
14
|
from fastapi.exceptions import RequestValidationError
|
|
15
15
|
from fastapi.responses import JSONResponse
|
|
16
16
|
from starlette.requests import Request
|
|
17
|
-
|
|
18
17
|
from synth_ai.task.apps import ModalDeploymentConfig, registry
|
|
19
|
-
from .grpo_crafter import build_config
|
|
20
18
|
from synth_ai.task.auth import is_api_key_header_authorized, normalize_environment_api_key
|
|
21
19
|
from synth_ai.task.server import TaskAppConfig, create_task_app, run_task_app
|
|
22
20
|
|
|
21
|
+
from .grpo_crafter import build_config
|
|
23
22
|
|
|
24
23
|
APP_ID = "grpo-crafter"
|
|
25
24
|
|
|
@@ -104,7 +103,7 @@ def fastapi_app():
|
|
|
104
103
|
try:
|
|
105
104
|
hdr = request.headers
|
|
106
105
|
snapshot = {
|
|
107
|
-
"path": str(
|
|
106
|
+
"path": str(request.url.path),
|
|
108
107
|
"have_x_api_key": bool(hdr.get("x-api-key")),
|
|
109
108
|
"have_x_api_keys": bool(hdr.get("x-api-keys")),
|
|
110
109
|
"have_authorization": bool(hdr.get("authorization")),
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from typing import Dict, List, Optional
|
|
5
4
|
|
|
6
5
|
from fastapi import APIRouter, HTTPException
|
|
7
6
|
from pydantic import BaseModel
|
|
8
7
|
|
|
9
8
|
from .registry import registry
|
|
10
|
-
from .storage.volume import storage
|
|
11
9
|
|
|
12
10
|
logger = logging.getLogger(__name__)
|
|
13
11
|
|
|
@@ -15,15 +13,15 @@ router = APIRouter()
|
|
|
15
13
|
|
|
16
14
|
|
|
17
15
|
class BranchRequest(BaseModel):
|
|
18
|
-
env_ids:
|
|
19
|
-
policy_ids:
|
|
16
|
+
env_ids: list[str] | None = None
|
|
17
|
+
policy_ids: list[str] | None = None
|
|
20
18
|
num_children: int = 1
|
|
21
19
|
max_branches: int = 10
|
|
22
20
|
|
|
23
21
|
|
|
24
22
|
class BranchResponse(BaseModel):
|
|
25
|
-
env_branches:
|
|
26
|
-
policy_branches:
|
|
23
|
+
env_branches: dict[str, list[str]]
|
|
24
|
+
policy_branches: dict[str, list[str]]
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
@router.post("/branch", response_model=BranchResponse)
|
|
@@ -53,8 +51,8 @@ async def create_branches(request: BranchRequest) -> BranchResponse:
|
|
|
53
51
|
for child_idx in range(request.num_children):
|
|
54
52
|
# Create snapshot of parent
|
|
55
53
|
from .environment_routes import (
|
|
56
|
-
snapshot_environment,
|
|
57
54
|
EnvSnapshotRequest,
|
|
55
|
+
snapshot_environment,
|
|
58
56
|
)
|
|
59
57
|
|
|
60
58
|
snapshot_response = await snapshot_environment(
|
|
@@ -63,8 +61,8 @@ async def create_branches(request: BranchRequest) -> BranchResponse:
|
|
|
63
61
|
|
|
64
62
|
# Restore to new environment with modified seed
|
|
65
63
|
from .environment_routes import (
|
|
66
|
-
restore_environment,
|
|
67
64
|
EnvRestoreRequest,
|
|
65
|
+
restore_environment,
|
|
68
66
|
)
|
|
69
67
|
|
|
70
68
|
restore_response = await restore_environment(
|
|
@@ -100,14 +98,14 @@ async def create_branches(request: BranchRequest) -> BranchResponse:
|
|
|
100
98
|
|
|
101
99
|
for child_idx in range(request.num_children):
|
|
102
100
|
# Create snapshot of parent
|
|
103
|
-
from .policy_routes import
|
|
101
|
+
from .policy_routes import PolicySnapshotRequest, snapshot_policy
|
|
104
102
|
|
|
105
103
|
snapshot_response = await snapshot_policy(
|
|
106
104
|
PolicySnapshotRequest(policy_id=policy_id)
|
|
107
105
|
)
|
|
108
106
|
|
|
109
107
|
# Restore to new policy
|
|
110
|
-
from .policy_routes import
|
|
108
|
+
from .policy_routes import PolicyRestoreRequest, restore_policy
|
|
111
109
|
|
|
112
110
|
restore_response = await restore_policy(
|
|
113
111
|
PolicyRestoreRequest(snapshot_id=snapshot_response.snapshot_id)
|
|
@@ -142,4 +140,4 @@ async def create_branches(request: BranchRequest) -> BranchResponse:
|
|
|
142
140
|
|
|
143
141
|
except Exception as e:
|
|
144
142
|
logger.error(f"Failed to create branches: {e}")
|
|
145
|
-
raise HTTPException(status_code=500, detail=str(e))
|
|
143
|
+
raise HTTPException(status_code=500, detail=str(e)) from e
|