synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +20 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -262
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev5.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,14 +1,13 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
"""Modal task app for Hendrycks MATH single-step RL environment."""
|
|
4
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
5
|
import os
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from functools import lru_cache
|
|
6
8
|
from pathlib import Path
|
|
7
9
|
|
|
8
10
|
from modal import App, Image, Secret, asgi_app
|
|
9
|
-
from functools import lru_cache
|
|
10
|
-
from typing import Iterable
|
|
11
|
-
|
|
12
11
|
from starlette.requests import Request
|
|
13
12
|
|
|
14
13
|
try: # Backward compatibility with older installed SDKs
|
|
@@ -25,7 +24,9 @@ _SYNTH_HOSTED = None
|
|
|
25
24
|
try:
|
|
26
25
|
probe = _HERE
|
|
27
26
|
for _ in range(8):
|
|
28
|
-
candidate = (
|
|
27
|
+
candidate = (
|
|
28
|
+
probe / "backend/app/routes/clustered_training/dev/synth_envs_hosted"
|
|
29
|
+
).resolve()
|
|
29
30
|
if candidate.exists():
|
|
30
31
|
_SYNTH_HOSTED = candidate
|
|
31
32
|
break
|
|
@@ -97,16 +98,17 @@ app = App("hendrycks-math-task-app")
|
|
|
97
98
|
@asgi_app()
|
|
98
99
|
def fastapi_app():
|
|
99
100
|
import httpx
|
|
100
|
-
from fastapi import Body, HTTPException, status
|
|
101
|
-
from fastapi import FastAPI
|
|
101
|
+
from fastapi import Body, FastAPI, HTTPException, status
|
|
102
102
|
from fastapi.middleware.cors import CORSMiddleware
|
|
103
103
|
from fastapi.responses import JSONResponse
|
|
104
|
+
|
|
104
105
|
try:
|
|
105
106
|
from synth_ai.task.auth import (
|
|
106
107
|
is_api_key_header_authorized,
|
|
107
108
|
normalize_environment_api_key,
|
|
108
109
|
)
|
|
109
110
|
except Exception: # pragma: no cover - fallback for older synth-ai builds
|
|
111
|
+
|
|
110
112
|
def _normalize_env_key_fallback() -> str | None:
|
|
111
113
|
key = os.getenv("ENVIRONMENT_API_KEY")
|
|
112
114
|
if key:
|
|
@@ -130,7 +132,7 @@ def fastapi_app():
|
|
|
130
132
|
for value in values:
|
|
131
133
|
if not isinstance(value, str):
|
|
132
134
|
continue
|
|
133
|
-
for chunk in value.split(
|
|
135
|
+
for chunk in value.split(","):
|
|
134
136
|
chunk = chunk.strip()
|
|
135
137
|
if chunk:
|
|
136
138
|
parts.append(chunk)
|
|
@@ -172,19 +174,27 @@ def fastapi_app():
|
|
|
172
174
|
|
|
173
175
|
def _normalize_answer_text(s: str) -> str:
|
|
174
176
|
import re as _re
|
|
177
|
+
|
|
175
178
|
return _re.sub(r"[^0-9A-Za-z.+\-/*=]", "", (s or "").strip()).lower()
|
|
176
179
|
|
|
177
180
|
def _extract_boxed(s: str) -> str:
|
|
178
181
|
import re as _re
|
|
182
|
+
|
|
179
183
|
m = list(_re.finditer(r"\\boxed\{([^}]+)\}", s or ""))
|
|
180
184
|
return m[-1].group(1) if m else ""
|
|
181
185
|
|
|
182
186
|
def _load_hendrycks_problem(seed: int, subject: str | None = None) -> tuple[str, str]:
|
|
183
187
|
subj = subject or os.getenv("HENDRYCKS_MATH_CONFIG", "default")
|
|
184
|
-
ds = _hf_split(
|
|
188
|
+
ds = _hf_split(
|
|
189
|
+
subj, os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE")
|
|
190
|
+
)
|
|
185
191
|
n = len(ds) if hasattr(ds, "__len__") else 0
|
|
186
192
|
if n == 0 and subject not in {"", "default"}:
|
|
187
|
-
ds = _hf_split(
|
|
193
|
+
ds = _hf_split(
|
|
194
|
+
"default",
|
|
195
|
+
os.getenv("HENDRYCKS_MATH_SPLIT", "test"),
|
|
196
|
+
os.getenv("HENDRYCKS_MATH_SLICE"),
|
|
197
|
+
)
|
|
188
198
|
n = len(ds) if hasattr(ds, "__len__") else 0
|
|
189
199
|
if n == 0:
|
|
190
200
|
raise RuntimeError("Hendrycks MATH dataset loaded empty")
|
|
@@ -225,7 +235,11 @@ def fastapi_app():
|
|
|
225
235
|
|
|
226
236
|
def _resolve_env_keys() -> set[str]:
|
|
227
237
|
keys: set[str] = set()
|
|
228
|
-
for alias in (
|
|
238
|
+
for alias in (
|
|
239
|
+
"ENVIRONMENT_API_KEY",
|
|
240
|
+
"dev_environment_api_key",
|
|
241
|
+
"DEV_ENVIRONMENT_API_KEY",
|
|
242
|
+
):
|
|
229
243
|
value = os.environ.get(alias)
|
|
230
244
|
if value:
|
|
231
245
|
os.environ.setdefault("ENVIRONMENT_API_KEY", value)
|
|
@@ -250,8 +264,12 @@ def fastapi_app():
|
|
|
250
264
|
candidates.append(primary.strip())
|
|
251
265
|
secondary = x_api_keys or headers.get("x-api-keys")
|
|
252
266
|
if secondary:
|
|
253
|
-
candidates.extend(
|
|
254
|
-
|
|
267
|
+
candidates.extend(
|
|
268
|
+
[value.strip() for value in secondary.split(",") if value.strip()]
|
|
269
|
+
)
|
|
270
|
+
auth_header = (
|
|
271
|
+
authorization or headers.get("authorization") or headers.get("Authorization")
|
|
272
|
+
)
|
|
255
273
|
if auth_header and auth_header.lower().startswith("bearer "):
|
|
256
274
|
token = auth_header.split(" ", 1)[1].strip()
|
|
257
275
|
if token:
|
|
@@ -274,7 +292,10 @@ def fastapi_app():
|
|
|
274
292
|
async def info():
|
|
275
293
|
return {
|
|
276
294
|
"service": {"base_url": os.getenv("SERVICE_BASE_URL", "")},
|
|
277
|
-
"inference": {
|
|
295
|
+
"inference": {
|
|
296
|
+
"base_url": "",
|
|
297
|
+
"endpoints": {"chat_completions": "/v1/chat/completions"},
|
|
298
|
+
},
|
|
278
299
|
}
|
|
279
300
|
|
|
280
301
|
@app.get("/health")
|
|
@@ -282,7 +303,10 @@ def fastapi_app():
|
|
|
282
303
|
env_keys = _resolve_env_keys()
|
|
283
304
|
env_key = next(iter(env_keys), None)
|
|
284
305
|
if not env_key:
|
|
285
|
-
return JSONResponse(
|
|
306
|
+
return JSONResponse(
|
|
307
|
+
status_code=503,
|
|
308
|
+
content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
|
|
309
|
+
)
|
|
286
310
|
# Authorize using all header variants; avoid typed Header params to prevent 422s
|
|
287
311
|
authorized = is_api_key_header_authorized(request)
|
|
288
312
|
if not authorized:
|
|
@@ -302,7 +326,10 @@ def fastapi_app():
|
|
|
302
326
|
env_keys = _resolve_env_keys()
|
|
303
327
|
env_key = next(iter(env_keys), None)
|
|
304
328
|
if not env_key:
|
|
305
|
-
return JSONResponse(
|
|
329
|
+
return JSONResponse(
|
|
330
|
+
status_code=503,
|
|
331
|
+
content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
|
|
332
|
+
)
|
|
306
333
|
authorized = is_api_key_header_authorized(request)
|
|
307
334
|
if not authorized:
|
|
308
335
|
prefix = _log_env_key_prefix("health/rollout", env_key)
|
|
@@ -321,17 +348,22 @@ def fastapi_app():
|
|
|
321
348
|
async def task_info(seed: int = 0, subject: str = "default"):
|
|
322
349
|
"""Return Hendrycks MATH problem/answer and tool schema for a seed."""
|
|
323
350
|
q, a = _load_hendrycks_problem(int(seed), subject=subject)
|
|
324
|
-
tools = [
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
"
|
|
329
|
-
|
|
330
|
-
"
|
|
351
|
+
tools = [
|
|
352
|
+
{
|
|
353
|
+
"name": "submit_answer",
|
|
354
|
+
"description": "Provide the final numerical or algebraic answer for the current math problem.",
|
|
355
|
+
"parameters": {
|
|
356
|
+
"type": "object",
|
|
357
|
+
"properties": {
|
|
358
|
+
"answer": {
|
|
359
|
+
"type": "string",
|
|
360
|
+
"description": "The proposed final answer",
|
|
361
|
+
},
|
|
362
|
+
},
|
|
363
|
+
"required": ["answer"],
|
|
331
364
|
},
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
}]
|
|
365
|
+
}
|
|
366
|
+
]
|
|
335
367
|
return {
|
|
336
368
|
"seed": int(seed),
|
|
337
369
|
"subject": subject,
|
|
@@ -354,7 +386,7 @@ def fastapi_app():
|
|
|
354
386
|
try:
|
|
355
387
|
hdr = request.headers
|
|
356
388
|
snapshot = {
|
|
357
|
-
"path": str(
|
|
389
|
+
"path": str(request.url.path),
|
|
358
390
|
"have_x_api_key": bool(hdr.get("x-api-key")),
|
|
359
391
|
"have_x_api_keys": bool(hdr.get("x-api-keys")),
|
|
360
392
|
"have_authorization": bool(hdr.get("authorization")),
|
|
@@ -363,7 +395,9 @@ def fastapi_app():
|
|
|
363
395
|
print("[422] validation", snapshot, flush=True)
|
|
364
396
|
except Exception:
|
|
365
397
|
pass
|
|
366
|
-
return JSONResponse(
|
|
398
|
+
return JSONResponse(
|
|
399
|
+
status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]}
|
|
400
|
+
)
|
|
367
401
|
|
|
368
402
|
@api.get("/")
|
|
369
403
|
async def root_probe():
|
|
@@ -376,27 +410,32 @@ def fastapi_app():
|
|
|
376
410
|
env_key = (
|
|
377
411
|
os.environ.get("ENVIRONMENT_API_KEY")
|
|
378
412
|
or os.environ.get("DEV_ENVIRONMENT_API_KEY")
|
|
379
|
-
or os.environ.get("
|
|
413
|
+
or os.environ.get("DEV_ENVIRONMENT_API_KEY")
|
|
380
414
|
)
|
|
381
415
|
if not env_key:
|
|
382
416
|
raise RuntimeError("ENVIRONMENT_API_KEY missing in task app environment")
|
|
383
417
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
418
|
+
openai_remove_fields = (
|
|
419
|
+
"stop_after_tool_calls",
|
|
420
|
+
"thinking_mode",
|
|
421
|
+
"thinking_budget",
|
|
422
|
+
"reasoning",
|
|
423
|
+
)
|
|
424
|
+
openai_remove_sampling_fields = ("temperature", "top_p")
|
|
425
|
+
tool_choice_force = {"type": "function", "function": {"name": "submit_answer"}}
|
|
387
426
|
|
|
388
427
|
def _prepare_openai_payload(model: str | None, payload: dict[str, object]) -> dict[str, object]:
|
|
389
428
|
sanitized = dict(payload)
|
|
390
|
-
for key in
|
|
429
|
+
for key in openai_remove_fields:
|
|
391
430
|
sanitized.pop(key, None)
|
|
392
431
|
if model and "gpt-5" in model:
|
|
393
432
|
if "max_tokens" in sanitized and "max_completion_tokens" not in sanitized:
|
|
394
433
|
sanitized["max_completion_tokens"] = sanitized.pop("max_tokens")
|
|
395
434
|
else:
|
|
396
435
|
sanitized.pop("max_tokens", None)
|
|
397
|
-
for field in
|
|
436
|
+
for field in openai_remove_sampling_fields:
|
|
398
437
|
sanitized.pop(field, None)
|
|
399
|
-
sanitized["tool_choice"] =
|
|
438
|
+
sanitized["tool_choice"] = tool_choice_force
|
|
400
439
|
sanitized["parallel_tool_calls"] = False
|
|
401
440
|
return sanitized
|
|
402
441
|
|
|
@@ -404,12 +443,18 @@ def fastapi_app():
|
|
|
404
443
|
def proxy_chat_completions(request: dict[str, object] = Body(...)):
|
|
405
444
|
key = os.environ.get("OPENAI_API_KEY")
|
|
406
445
|
if not key:
|
|
407
|
-
raise HTTPException(
|
|
446
|
+
raise HTTPException(
|
|
447
|
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="OPENAI_API_KEY missing"
|
|
448
|
+
)
|
|
408
449
|
model = request.get("model") if isinstance(request, dict) else None
|
|
409
|
-
payload = _prepare_openai_payload(
|
|
450
|
+
payload = _prepare_openai_payload(
|
|
451
|
+
model if isinstance(model, str) else None, request if isinstance(request, dict) else {}
|
|
452
|
+
)
|
|
410
453
|
headers = {"Authorization": f"Bearer {key}"}
|
|
411
454
|
with httpx.Client(timeout=httpx.Timeout(180.0), follow_redirects=True) as client:
|
|
412
|
-
resp = client.post(
|
|
455
|
+
resp = client.post(
|
|
456
|
+
"https://api.openai.com/v1/chat/completions", json=payload, headers=headers
|
|
457
|
+
)
|
|
413
458
|
try:
|
|
414
459
|
data = resp.json()
|
|
415
460
|
except Exception:
|
|
@@ -423,8 +468,8 @@ def fastapi_app():
|
|
|
423
468
|
# Minimal math rollout endpoint: alternates agent/env; calls inference_url chat/completions
|
|
424
469
|
@api.post("/rollout")
|
|
425
470
|
def rollout(request: dict[str, object] = Body(...)):
|
|
426
|
-
from typing import Any
|
|
427
471
|
import json as _json
|
|
472
|
+
from typing import Any
|
|
428
473
|
|
|
429
474
|
run_id = str(request.get("run_id"))
|
|
430
475
|
data = request if isinstance(request, dict) else {}
|
|
@@ -442,15 +487,25 @@ def fastapi_app():
|
|
|
442
487
|
env_cfg = (env or {}).get("config") or {}
|
|
443
488
|
# Prefer env.seed; fall back to env.config.seed -> default 0
|
|
444
489
|
try:
|
|
445
|
-
seed_val =
|
|
490
|
+
seed_val = (
|
|
491
|
+
int((env or {}).get("seed"))
|
|
492
|
+
if isinstance(env, dict) and (env or {}).get("seed") is not None
|
|
493
|
+
else 0
|
|
494
|
+
)
|
|
446
495
|
except Exception:
|
|
447
496
|
seed_val = 0
|
|
448
497
|
if seed_val == 0:
|
|
449
498
|
try:
|
|
450
|
-
seed_val =
|
|
499
|
+
seed_val = (
|
|
500
|
+
int(env_cfg.get("seed"))
|
|
501
|
+
if isinstance(env_cfg, dict) and env_cfg.get("seed") is not None
|
|
502
|
+
else 0
|
|
503
|
+
)
|
|
451
504
|
except Exception:
|
|
452
505
|
seed_val = 0
|
|
453
|
-
subject = (env_cfg.get("subject") if isinstance(env_cfg, dict) else None) or os.getenv(
|
|
506
|
+
subject = (env_cfg.get("subject") if isinstance(env_cfg, dict) else None) or os.getenv(
|
|
507
|
+
"HENDRYCKS_MATH_CONFIG", "default"
|
|
508
|
+
)
|
|
454
509
|
# Load real Hendrycks problem text/solution (download if necessary). Crash on failure.
|
|
455
510
|
qh, ah = _load_hendrycks_problem(seed_val, subject=subject)
|
|
456
511
|
question = qh
|
|
@@ -468,7 +523,10 @@ def fastapi_app():
|
|
|
468
523
|
sanitized.pop("max_tokens", None)
|
|
469
524
|
for field in ("temperature", "top_p"):
|
|
470
525
|
sanitized.pop(field, None)
|
|
471
|
-
sanitized["tool_choice"] = {
|
|
526
|
+
sanitized["tool_choice"] = {
|
|
527
|
+
"type": "function",
|
|
528
|
+
"function": {"name": "submit_answer"},
|
|
529
|
+
}
|
|
472
530
|
sanitized["parallel_tool_calls"] = False
|
|
473
531
|
return sanitized
|
|
474
532
|
|
|
@@ -503,25 +561,27 @@ def fastapi_app():
|
|
|
503
561
|
|
|
504
562
|
user_prompt = (
|
|
505
563
|
str(question)
|
|
506
|
-
if isinstance(question,
|
|
564
|
+
if isinstance(question, str | int | float) and str(question).strip()
|
|
507
565
|
else "Solve the problem. Provide answer steps succinctly."
|
|
508
566
|
)
|
|
509
567
|
payload = {
|
|
510
568
|
"model": model,
|
|
511
569
|
"messages": [{"role": "user", "content": user_prompt}],
|
|
512
|
-
"tools": [
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
"
|
|
516
|
-
|
|
517
|
-
"
|
|
518
|
-
|
|
519
|
-
"
|
|
570
|
+
"tools": [
|
|
571
|
+
{
|
|
572
|
+
"type": "function",
|
|
573
|
+
"function": {
|
|
574
|
+
"name": "submit_answer",
|
|
575
|
+
"parameters": {
|
|
576
|
+
"type": "object",
|
|
577
|
+
"properties": {
|
|
578
|
+
"answer": {"type": "string"},
|
|
579
|
+
},
|
|
580
|
+
"required": ["answer"],
|
|
520
581
|
},
|
|
521
|
-
"required": ["answer"],
|
|
522
582
|
},
|
|
523
|
-
}
|
|
524
|
-
|
|
583
|
+
}
|
|
584
|
+
],
|
|
525
585
|
"max_tokens": 256,
|
|
526
586
|
"temperature": 0.2,
|
|
527
587
|
}
|
|
@@ -529,13 +589,13 @@ def fastapi_app():
|
|
|
529
589
|
|
|
530
590
|
try:
|
|
531
591
|
tool_names = []
|
|
532
|
-
for t in
|
|
592
|
+
for t in payload.get("tools") or []:
|
|
533
593
|
if isinstance(t, dict):
|
|
534
594
|
fn = (t.get("function") or {}) if isinstance(t.get("function"), dict) else {}
|
|
535
595
|
name = fn.get("name")
|
|
536
596
|
if isinstance(name, str):
|
|
537
597
|
tool_names.append(name)
|
|
538
|
-
print(
|
|
598
|
+
print("[math] system: <none>", flush=True)
|
|
539
599
|
print(f"[math] user: {user_prompt}", flush=True)
|
|
540
600
|
print(f"[math] tools: {tool_names}", flush=True)
|
|
541
601
|
except Exception:
|
|
@@ -547,7 +607,9 @@ def fastapi_app():
|
|
|
547
607
|
if sk:
|
|
548
608
|
headers["Authorization"] = f"Bearer {sk}"
|
|
549
609
|
with httpx.Client(timeout=httpx.Timeout(180.0), follow_redirects=True) as client:
|
|
550
|
-
resp = client.post(
|
|
610
|
+
resp = client.post(
|
|
611
|
+
f"{inference_url}/v1/chat/completions", json=to_send, headers=headers
|
|
612
|
+
)
|
|
551
613
|
try:
|
|
552
614
|
data = resp.json()
|
|
553
615
|
except Exception:
|
|
@@ -580,14 +642,21 @@ def fastapi_app():
|
|
|
580
642
|
|
|
581
643
|
tool_answer = _parse_tool_answer(data)
|
|
582
644
|
history.append({"answer": tool_answer})
|
|
583
|
-
steps.append(
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
645
|
+
steps.append(
|
|
646
|
+
{
|
|
647
|
+
"obs": {},
|
|
648
|
+
"tool_calls": [
|
|
649
|
+
{
|
|
650
|
+
"tool_name": "submit_answer",
|
|
651
|
+
"arguments": _json.dumps({"answer": tool_answer}),
|
|
652
|
+
}
|
|
653
|
+
],
|
|
654
|
+
"reward": None,
|
|
655
|
+
"done": False,
|
|
656
|
+
"truncated": False,
|
|
657
|
+
"info": None,
|
|
658
|
+
}
|
|
659
|
+
)
|
|
591
660
|
|
|
592
661
|
# Evaluate answer correctness using tool output (or fall back to assistant text)
|
|
593
662
|
reward_val = 0.0
|
|
@@ -605,25 +674,57 @@ def fastapi_app():
|
|
|
605
674
|
except Exception:
|
|
606
675
|
reward_val = 0.0
|
|
607
676
|
|
|
677
|
+
# Immediate, concise rollout logging mirroring RL format
|
|
678
|
+
try:
|
|
679
|
+
preview = tool_answer[:120] + (
|
|
680
|
+
"…" if isinstance(tool_answer, str) and len(tool_answer) > 120 else ""
|
|
681
|
+
)
|
|
682
|
+
components = {
|
|
683
|
+
"env": float(reward_val),
|
|
684
|
+
"rubric_event": 1.0 if bool(tool_answer.strip()) else 0.0,
|
|
685
|
+
"rubric_outcome": 1.0 if float(reward_val) > 0.0 else 0.0,
|
|
686
|
+
}
|
|
687
|
+
print(
|
|
688
|
+
"[MATH_ROLLOUT] run=",
|
|
689
|
+
run_id,
|
|
690
|
+
" seed=",
|
|
691
|
+
seed_val,
|
|
692
|
+
" subject=",
|
|
693
|
+
subject,
|
|
694
|
+
" tool=submit_answer answer=",
|
|
695
|
+
preview,
|
|
696
|
+
" reward=",
|
|
697
|
+
float(reward_val),
|
|
698
|
+
" components=",
|
|
699
|
+
components,
|
|
700
|
+
flush=True,
|
|
701
|
+
)
|
|
702
|
+
except Exception:
|
|
703
|
+
pass
|
|
704
|
+
|
|
608
705
|
total_reward += float(reward_val)
|
|
609
|
-
steps.append(
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
706
|
+
steps.append(
|
|
707
|
+
{
|
|
708
|
+
"obs": {},
|
|
709
|
+
"tool_calls": [],
|
|
710
|
+
"reward": reward_val,
|
|
711
|
+
"done": True,
|
|
712
|
+
"truncated": False,
|
|
713
|
+
"info": None,
|
|
714
|
+
}
|
|
715
|
+
)
|
|
617
716
|
|
|
618
717
|
return {
|
|
619
718
|
"run_id": run_id,
|
|
620
|
-
"trajectories": [
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
719
|
+
"trajectories": [
|
|
720
|
+
{
|
|
721
|
+
"env_id": env_name,
|
|
722
|
+
"policy_id": (policy or {}).get("policy_name") or "math-react",
|
|
723
|
+
"steps": steps,
|
|
724
|
+
"final": {"observation": {}},
|
|
725
|
+
"length": len(steps),
|
|
726
|
+
}
|
|
727
|
+
],
|
|
627
728
|
"branches": {},
|
|
628
729
|
"metrics": {
|
|
629
730
|
"episode_returns": [total_reward],
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
6
6
|
from synth_ai.task.apps.math_single_step import build_config as base_build_config
|
|
7
7
|
|
|
8
|
-
|
|
9
8
|
DEMO_MODAL_CONFIG = ModalDeploymentConfig(
|
|
10
9
|
app_name="hendrycks-math-task-app",
|
|
11
10
|
pip_packages=(
|
|
@@ -36,4 +35,3 @@ register_task_app(
|
|
|
36
35
|
modal=DEMO_MODAL_CONFIG,
|
|
37
36
|
)
|
|
38
37
|
)
|
|
39
|
-
|
|
@@ -191,7 +191,9 @@ class BanditEngine(StatefulEngine, IReproducibleEngine):
|
|
|
191
191
|
step_count=self.step_count,
|
|
192
192
|
max_steps=self.max_steps,
|
|
193
193
|
last_arm=self.last_arm,
|
|
194
|
-
last_reward=float(reward)
|
|
194
|
+
last_reward=float(reward)
|
|
195
|
+
if reward is not None
|
|
196
|
+
else (self.last_reward if self.step_count else None),
|
|
195
197
|
cumulative_reward=float(self.total_reward),
|
|
196
198
|
reward_history=self.reward_history.copy(),
|
|
197
199
|
arm_pull_counts=self.arm_pull_counts.copy(),
|
|
@@ -238,7 +240,9 @@ class BanditEngine(StatefulEngine, IReproducibleEngine):
|
|
|
238
240
|
engine.arm_probabilities = data.get("arm_probabilities", engine.arm_probabilities)
|
|
239
241
|
engine.arm_means = data.get("arm_means", engine.arm_means)
|
|
240
242
|
engine.arm_stds = data.get("arm_stds", engine.arm_stds)
|
|
241
|
-
engine.true_expected_rewards = list(
|
|
243
|
+
engine.true_expected_rewards = list(
|
|
244
|
+
data.get("true_expected_rewards", engine.true_expected_rewards)
|
|
245
|
+
)
|
|
242
246
|
engine.arm_count = len(engine.true_expected_rewards)
|
|
243
247
|
|
|
244
248
|
engine.step_count = int(data.get("step_count", 0))
|
|
@@ -247,7 +251,9 @@ class BanditEngine(StatefulEngine, IReproducibleEngine):
|
|
|
247
251
|
engine.last_arm = data.get("last_arm")
|
|
248
252
|
engine.reward_history = list(data.get("reward_history", []))
|
|
249
253
|
engine.arm_history = list(data.get("arm_history", []))
|
|
250
|
-
engine.arm_pull_counts = list(
|
|
254
|
+
engine.arm_pull_counts = list(
|
|
255
|
+
data.get("arm_pull_counts", [0 for _ in range(engine.arm_count)])
|
|
256
|
+
)
|
|
251
257
|
engine.terminated = bool(data.get("terminated", False))
|
|
252
258
|
engine.status = data.get("status", "in_progress")
|
|
253
259
|
|
|
@@ -287,7 +293,9 @@ class SynthBanditCheckpointObservationCallable(GetObservationCallable):
|
|
|
287
293
|
"arm_count": pub.arm_count,
|
|
288
294
|
"total_reward": priv.total_reward,
|
|
289
295
|
"steps_taken": pub.step_count,
|
|
290
|
-
"best_expected_reward": max(priv.true_expected_rewards)
|
|
296
|
+
"best_expected_reward": max(priv.true_expected_rewards)
|
|
297
|
+
if priv.true_expected_rewards
|
|
298
|
+
else None,
|
|
291
299
|
"terminated": pub.terminated,
|
|
292
300
|
"status": pub.status,
|
|
293
301
|
}
|
|
@@ -156,10 +156,10 @@ async def create_bandit_taskset(
|
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
expected = _expected_rewards(metadata)
|
|
159
|
-
arm_count =
|
|
160
|
-
len(
|
|
161
|
-
|
|
162
|
-
or 0
|
|
159
|
+
arm_count = (
|
|
160
|
+
len(expected)
|
|
161
|
+
if expected
|
|
162
|
+
else (len(metadata.arm_probabilities or []) or len(metadata.arm_means or []) or 0)
|
|
163
163
|
)
|
|
164
164
|
if arm_count == 0:
|
|
165
165
|
arm_count = 1
|
|
@@ -2,11 +2,16 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import base64
|
|
5
6
|
import dataclasses
|
|
6
7
|
import logging
|
|
7
8
|
import time
|
|
9
|
+
from io import BytesIO
|
|
8
10
|
from typing import Any, Dict, List, Optional, Union
|
|
9
11
|
|
|
12
|
+
import numpy as np
|
|
13
|
+
from PIL import Image
|
|
14
|
+
|
|
10
15
|
# Import tracing abstractions
|
|
11
16
|
from synth_ai.tracing_v3.abstractions import (
|
|
12
17
|
RuntimeEvent,
|
|
@@ -43,6 +48,51 @@ from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
|
|
43
48
|
from synth_ai.environments.stateful.core import StatefulEnvironment
|
|
44
49
|
|
|
45
50
|
|
|
51
|
+
def _convert_numpy_to_python(obj: Any) -> Any:
    """Recursively replace NumPy values with built-in Python equivalents.

    NumPy integer, floating and boolean scalars become ``int``/``float``/``bool``.
    ndarrays become nested lists via ``tolist()``. Dicts, lists and tuples are
    walked recursively; tuples come back as lists. Any other value is returned
    unchanged.

    Args:
        obj: Arbitrary (possibly nested) value, e.g. an observation dict.

    Returns:
        An equivalent structure made only of built-in Python types wherever
        NumPy types were found, suitable for serialisers such as ``json``.
    """
    # np.bool_ is not part of the np.integer hierarchy, so it needs its own
    # branch; otherwise NumPy booleans leak through and break json.dumps.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: _convert_numpy_to_python(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_convert_numpy_to_python(item) for item in obj]
    return obj
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _encode_image_to_base64(image_array: Any) -> dict[str, Any] | None:
|
|
66
|
+
if not isinstance(image_array, np.ndarray):
|
|
67
|
+
return None
|
|
68
|
+
if image_array.ndim != 3 or image_array.shape[-1] not in (1, 3, 4):
|
|
69
|
+
return None
|
|
70
|
+
try:
|
|
71
|
+
array_uint8 = (
|
|
72
|
+
image_array.astype("uint8")
|
|
73
|
+
if image_array.dtype != np.uint8
|
|
74
|
+
else image_array # pragma: no cover - fast path
|
|
75
|
+
)
|
|
76
|
+
mode = "L" if array_uint8.shape[-1] == 1 else "RGB"
|
|
77
|
+
if array_uint8.shape[-1] == 4:
|
|
78
|
+
mode = "RGBA"
|
|
79
|
+
img = Image.fromarray(array_uint8, mode=mode)
|
|
80
|
+
buffer = BytesIO()
|
|
81
|
+
img.save(buffer, format="PNG")
|
|
82
|
+
encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
83
|
+
width = int(array_uint8.shape[1])
|
|
84
|
+
height = int(array_uint8.shape[0])
|
|
85
|
+
return {
|
|
86
|
+
"format": "png",
|
|
87
|
+
"width": width,
|
|
88
|
+
"height": height,
|
|
89
|
+
"data": encoded,
|
|
90
|
+
"data_url": f"data:image/png;base64,{encoded}",
|
|
91
|
+
}
|
|
92
|
+
except Exception:
|
|
93
|
+
return None
|
|
94
|
+
|
|
95
|
+
|
|
46
96
|
# --- Tool Definition ---
|
|
47
97
|
class CrafterActionInput(BaseModel):
|
|
48
98
|
action: int = Field(..., description="Integer action for the Crafter environment.")
|
|
@@ -362,7 +412,8 @@ class CrafterClassicEnvironment(StatefulEnvironment, ReproducibleEnvironment[Cra
|
|
|
362
412
|
state_before = {"private_state": priv, "public_state": pub}
|
|
363
413
|
|
|
364
414
|
active_obs_cb = obs_cb or SynthCrafterObservationCallable()
|
|
365
|
-
|
|
415
|
+
raw_observation = await active_obs_cb.get_observation(pub, priv)
|
|
416
|
+
observation = self._prepare_observation(raw_observation)
|
|
366
417
|
if extra_obs and isinstance(observation, dict):
|
|
367
418
|
observation.update(extra_obs)
|
|
368
419
|
|
|
@@ -385,6 +436,30 @@ class CrafterClassicEnvironment(StatefulEnvironment, ReproducibleEnvironment[Cra
|
|
|
385
436
|
|
|
386
437
|
return observation
|
|
387
438
|
|
|
439
|
+
def _prepare_observation(self, observation: Any) -> dict[str, Any]:
    """Turn a raw observation into a plain-Python dict.

    When ``observation`` is a dict, its ``observation_image`` entry is
    stripped from the converted copy. If that image can be encoded, it is
    re-attached as five fields:

    - ``observation_image_base64``
    - ``observation_image_format``
    - ``observation_image_width``
    - ``observation_image_height``
    - ``observation_image_data_url``

    NumPy values are converted to built-in types. A falsy conversion result
    becomes ``{}``, and any other non-dict result is wrapped as
    ``{"value": ...}``.
    """
    encoded: dict[str, Any] | None = None
    payload_source: Any = observation
    if isinstance(observation, dict):
        # Encode the raw frame first, then leave it out of the copy that is converted.
        encoded = _encode_image_to_base64(observation.get("observation_image"))
        payload_source = {
            key: value for key, value in observation.items() if key != "observation_image"
        }

    converted = _convert_numpy_to_python(payload_source) or {}
    result: dict[str, Any] = converted if isinstance(converted, dict) else {"value": converted}

    if encoded:
        image_fields = (
            ("observation_image_base64", "data"),
            ("observation_image_format", "format"),
            ("observation_image_width", "width"),
            ("observation_image_height", "height"),
            ("observation_image_data_url", "data_url"),
        )
        for target_key, source_key in image_fields:
            result[target_key] = encoded[source_key]

    return result
|
|
462
|
+
|
|
388
463
|
# ────────────────────────────────────────────────────────────────────
|
|
389
464
|
# ReproducibleEnvironment plumbing
|
|
390
465
|
# ────────────────────────────────────────────────────────────────────
|