synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/__init__.py +1 -0
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/RECORD +294 -258
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
synth_ai/learning/ft_client.py
CHANGED
@@ -1,59 +1,7 @@
-
-
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-from ..http import AsyncHttpClient, HTTPError
-
-
-class FtClient:
-    def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
-        self._base_url = base_url.rstrip("/")
-        self._api_key = api_key
-        self._timeout = timeout
+"""Backward-compatible shim for FtClient (moved to synth_ai.learning.sft.client)."""
 
-
-        p = Path(path)
-        content = p.read_bytes()
-        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
-            data = {"purpose": purpose}
-            files = {"file": (p.name, content, _infer_content_type(p.name))}
-            js = await http.post_multipart("/api/learning/files", data=data, files=files)
-        if not isinstance(js, dict) or "id" not in js:
-            raise HTTPError(status=500, url="/api/learning/files", message="invalid_upload_response", body_snippet=str(js)[:200])
-        return str(js["id"])
-
-    async def create_sft_job(
-        self,
-        *,
-        model: str,
-        training_file_id: str,
-        hyperparameters: Dict[str, Any],
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Any]:
-        body = {
-            "training_type": "sft_offline",
-            "model": model,
-            "training_file_id": training_file_id,
-            "hyperparameters": dict(hyperparameters or {}),
-            "metadata": dict(metadata or {}),
-        }
-        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
-            return await http.post_json("/api/learning/jobs", json=body)
-
-    async def start_job(self, job_id: str) -> Dict[str, Any]:
-        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
-            return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})
-
-
-def _infer_content_type(filename: str) -> str:
-    name = filename.lower()
-    if name.endswith(".jsonl"):
-        return "application/jsonl"
-    if name.endswith(".json"):
-        return "application/json"
-    if name.endswith(".txt"):
-        return "text/plain"
-    return "application/octet-stream"
+from __future__ import annotations
 
+from .sft.client import FtClient
 
+__all__ = ["FtClient"]
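The rewritten module is now a thin shim: the real FtClient lives in synth_ai.learning.sft.client, and the old import path only re-exports it. A minimal sketch of what that guarantees for existing callers (nothing here is new API, just the re-export shown above):

# Old and new import paths resolve to the same class object,
# because ft_client.py now simply re-exports from .sft.client.
from synth_ai.learning.ft_client import FtClient as LegacyFtClient
from synth_ai.learning.sft.client import FtClient

assert LegacyFtClient is FtClient  # the shim re-exports the class, it does not copy it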
synth_ai/learning/health.py
CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
-from typing import Any
-import aiohttp
+from typing import Any
 
 from ..http import AsyncHttpClient
 
@@ -11,20 +10,28 @@ def _api_base(b: str) -> str:
     return b if b.endswith("/api") else f"{b}/api"
 
 
-async def backend_health(base_url: str, api_key: str) ->
+async def backend_health(base_url: str, api_key: str) -> dict[str, Any]:
     async with AsyncHttpClient(base_url, api_key, timeout=15.0) as http:
         js = await http.get(f"{_api_base(base_url)}/health")
         return {"ok": True, "raw": js}
 
 
-async def task_app_health(task_app_url: str) ->
+async def task_app_health(task_app_url: str) -> dict[str, Any]:
     # Delegate to central task module for consistency
     from synth_ai.task.health import task_app_health as _th
 
     return await _th(task_app_url)
 
 
-async def pricing_preflight(
+async def pricing_preflight(
+    base_url: str,
+    api_key: str,
+    *,
+    job_type: str,
+    gpu_type: str,
+    estimated_seconds: float,
+    container_count: int,
+) -> dict[str, Any]:
     body = {
         "job_type": job_type,
         "gpu_type": gpu_type,
@@ -36,8 +43,7 @@ async def pricing_preflight(base_url: str, api_key: str, *, job_type: str, gpu_t
     return js if isinstance(js, dict) else {"raw": js}
 
 
-async def balance_autumn_normalized(base_url: str, api_key: str) ->
+async def balance_autumn_normalized(base_url: str, api_key: str) -> dict[str, Any]:
     async with AsyncHttpClient(base_url, api_key, timeout=30.0) as http:
         js = await http.get(f"{_api_base(base_url)}/v1/balance/autumn-normalized")
         return js if isinstance(js, dict) else {"raw": js}
-
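The health helpers keep their behaviour but now return dict[str, Any] and drop the unused aiohttp import, and pricing_preflight gains an explicit keyword-only signature. A usage sketch under those signatures; the backend URL, API key, and job parameters below are placeholders, not values from the package:

import asyncio

from synth_ai.learning.health import backend_health, pricing_preflight

async def main() -> None:
    base_url = "https://backend.example.com"  # placeholder endpoint
    api_key = "sk-..."                        # placeholder credential
    print(await backend_health(base_url, api_key))
    # pricing_preflight is keyword-only after the credentials, per the new signature.
    print(
        await pricing_preflight(
            base_url,
            api_key,
            job_type="sft_offline",    # illustrative values
            gpu_type="H100",
            estimated_seconds=1800.0,
            container_count=1,
        )
    )

asyncio.run(main())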
synth_ai/learning/jobs.py
CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List, Optional
 import time
+from collections.abc import Callable
+from contextlib import suppress
+from typing import Any
 
-from .constants import TERMINAL_EVENT_FAILURE, TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES
 from ..http import AsyncHttpClient, sleep
+from .constants import TERMINAL_EVENT_FAILURE, TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES
 
 
 def _api_base(b: str) -> str:
@@ -17,7 +19,7 @@ class JobsApiResolver:
         self._base = _api_base(base_url)
         self._strict = strict
 
-    def status_urls(self, job_id: str) ->
+    def status_urls(self, job_id: str) -> list[str]:
         if self._strict:
             return [f"{self._base}/learning/jobs/{job_id}"]
         return [
@@ -26,7 +28,7 @@ class JobsApiResolver:
             f"{self._base}/orchestration/jobs/{job_id}",
         ]
 
-    def events_urls(self, job_id: str, since: int) ->
+    def events_urls(self, job_id: str, since: int) -> list[str]:
         if self._strict:
             return [f"{self._base}/learning/jobs/{job_id}/events?since_seq={since}&limit=200"]
         return [
@@ -40,7 +42,15 @@ class JobsApiResolver:
 
 
 class JobHandle:
-    def __init__(
+    def __init__(
+        self,
+        base_url: str,
+        api_key: str,
+        job_id: str,
+        *,
+        strict: bool = True,
+        timeout: float = 600.0,
+    ) -> None:
         self.base_url = base_url.rstrip("/")
         self.api_key = api_key
         self.job_id = job_id
@@ -54,23 +64,23 @@ class JobHandle:
         max_seconds: float | None = None,
         empty_polls_threshold: int = 5,
         startup_deadline_s: int = 45,
-        on_event:
-        on_metric:
-    ) ->
-        last_seq_by_stream:
-        events_job_id:
-        last_status:
-        last_step_by_name:
+        on_event: Callable[[dict[str, Any]], None] | None = None,
+        on_metric: Callable[[dict[str, Any]], None] | None = None,
+    ) -> dict[str, Any]:
+        last_seq_by_stream: dict[str, int] = {}
+        events_job_id: str | None = None
+        last_status: str | None = None
+        last_step_by_name: dict[str, int] = {}
         empty_polls = 0
         saw_any_event = False
         start_t = time.time()
         resolver = JobsApiResolver(self.base_url, strict=self.strict)
-        detected_fine_tuned_model:
+        detected_fine_tuned_model: str | None = None
 
         async with AsyncHttpClient(self.base_url, self.api_key, timeout=self.timeout) as http:
             while True:
                 # Status
-                status_data:
+                status_data: dict[str, Any] | None = None
                 for su in resolver.status_urls(self.job_id):
                     try:
                         status_data = await http.get(su)
@@ -91,10 +101,8 @@ class JobHandle:
                 if status and status != last_status:
                     last_status = status
                     if on_event:
-
+                        with suppress(Exception):
                             on_event({"type": "job.status", "message": status})
-                        except Exception:
-                            pass
 
                 # Events
                 stream_ids = [self.job_id]
@@ -102,7 +110,7 @@ class JobHandle:
                     stream_ids.append(events_job_id)
                 total_events_this_cycle = 0
                 terminal_event_seen = False
-                terminal_event_status:
+                terminal_event_status: str | None = None
                 for ev_id in stream_ids:
                     since = last_seq_by_stream.get(ev_id, 0)
                     for eu in resolver.events_urls(ev_id, since):
@@ -110,11 +118,8 @@ class JobHandle:
                             ev_js = await http.get(eu)
                         except Exception:
                             continue
-
-
-                        if not isinstance(events, list):
-                            events = []
-                        except Exception:
+                        events = (ev_js or {}).get("events") or (ev_js or {}).get("data") or []
+                        if not isinstance(events, list):
                             events = []
                         total_events_this_cycle += len(events)
                         if events:
@@ -125,20 +130,16 @@ class JobHandle:
                                 continue
                             last_seq_by_stream[ev_id] = seq_val
                             if on_event:
-
+                                with suppress(Exception):
                                     on_event(e)
-                                except Exception:
-                                    pass
                             et = str(e.get("type") or e.get("event_type") or "").lower()
                             # Capture fine_tuned_model from event data when available
                             if not detected_fine_tuned_model:
-
-
-                                    ftm = data_obj.get("fine_tuned_model")
+                                data_obj = e.get("data") or {}
+                                if isinstance(data_obj, dict):
+                                    ftm = data_obj.get("fine_tuned_model")
                                     if isinstance(ftm, str) and ftm:
                                         detected_fine_tuned_model = ftm
-                                except Exception:
-                                    pass
                             if et in TERMINAL_EVENT_SUCCESS:
                                 terminal_event_seen = True
                                 terminal_event_status = "succeeded"
@@ -158,10 +159,8 @@ class JobHandle:
                             continue
                         last_step_by_name[name] = step
                         if on_metric:
-
+                            with suppress(Exception):
                                 on_metric(p)
-                            except Exception:
-                                pass
                 except Exception:
                     pass
 
@@ -169,20 +168,17 @@ class JobHandle:
                 if terminal_event_seen or (status and status in TERMINAL_STATUSES):
                     # Best-effort enrichment of final result with fine_tuned_model
                     result_status = terminal_event_status or status or "completed"
-                    final_res:
+                    final_res: dict[str, Any] = {"status": result_status, "job_id": self.job_id}
                     if not detected_fine_tuned_model:
                         # Briefly try to re-fetch status to see if fine_tuned_model is persisted
                         try:
                             for su in resolver.status_urls(self.job_id):
-
-
-
-
-
-                                    break
-                                except Exception:
-                                    continue
+                                final_status = await http.get(su)
+                                if isinstance(final_status, dict):
+                                    ftm2 = final_status.get("fine_tuned_model")
+                                    if isinstance(ftm2, str) and ftm2:
+                                        detected_fine_tuned_model = ftm2
+                                        break
                         except Exception:
                             pass
                     if detected_fine_tuned_model:
@@ -200,6 +196,6 @@ class JobHandle:
                 )
                 await sleep(interval_seconds)
                 if max_seconds is not None and (time.time() - start_t) >= max_seconds:
-                    raise TimeoutError(
-
-
+                    raise TimeoutError(
+                        f"Polling timed out after {max_seconds}s for job {self.job_id}"
+                    )
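JobHandle now takes its connection details as explicit keyword arguments and reports progress through optional on_event/on_metric callbacks, swallowing callback errors via contextlib.suppress. A sketch of driving it; the polling coroutine's name is not visible in this hunk, so poll_until_terminal below is an assumption (its parameters are the ones shown above), and the URL, key, and job id are placeholders:

import asyncio

from synth_ai.learning.jobs import JobHandle

def log_event(event: dict) -> None:
    # Called for status changes and raw backend events.
    print("event:", event.get("type"))

def log_metric(point: dict) -> None:
    # Called once per new (name, step) metric point.
    print("metric:", point.get("name"), point.get("step"))

async def main() -> None:
    handle = JobHandle(
        "https://backend.example.com",  # placeholder backend URL
        "sk-...",                        # placeholder API key
        "job-123",                       # placeholder job id
        strict=True,
        timeout=600.0,
    )
    # Assumed method name; the diff shows its parameters but not its name.
    result = await handle.poll_until_terminal(
        on_event=log_event,
        on_metric=log_metric,
        max_seconds=3600,
    )
    print(result)

asyncio.run(main())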
synth_ai/{rl → learning/rl}/__init__.py
RENAMED
@@ -1,18 +1,28 @@
+from __future__ import annotations
+
+from .client import RlClient
+from .config import RLJobConfig
 from .contracts import (
     RolloutEnvSpec,
+    RolloutMetrics,
     RolloutPolicySpec,
     RolloutRecordConfig,
-    RolloutSafetyConfig,
     RolloutRequest,
+    RolloutResponse,
+    RolloutSafetyConfig,
     RolloutStep,
     RolloutTrajectory,
-    RolloutMetrics,
-    RolloutResponse,
 )
-from .env_keys import
+from .env_keys import (
+    MAX_ENVIRONMENT_API_KEY_BYTES,
+    encrypt_for_backend,
+    setup_environment_api_key,
+)
 from .secrets import mint_environment_api_key
 
 __all__ = [
+    "RlClient",
+    "RLJobConfig",
     "RolloutEnvSpec",
     "RolloutPolicySpec",
     "RolloutRecordConfig",
@@ -27,4 +37,3 @@ __all__ = [
     "mint_environment_api_key",
     "MAX_ENVIRONMENT_API_KEY_BYTES",
 ]
-
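With the package moved under synth_ai.learning.rl, the __init__ above re-exports the new client and config types alongside the env-key helpers and rollout contracts, so one import covers the common entry points:

# All of these names are re-exported by synth_ai.learning.rl per the diff above.
from synth_ai.learning.rl import (
    RLJobConfig,
    RlClient,
    RolloutRequest,
    mint_environment_api_key,
)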
synth_ai/learning/rl/client.py
ADDED
@@ -0,0 +1,267 @@
+from __future__ import annotations
+
+import time
+from collections.abc import Callable
+from contextlib import suppress
+from typing import Any
+
+from synth_ai.api.models.supported import (
+    UnsupportedModelError,
+    normalize_model_identifier,
+)
+
+from ...http import AsyncHttpClient, HTTPError, sleep
+
+
+def _api_base(b: str) -> str:
+    b = (b or "").rstrip("/")
+    return b if b.endswith("/api") else f"{b}/api"
+
+
+class RlClient:
+    """Lightweight RL client for provider-agnostic job control."""
+
+    def __init__(self, base_url: str, api_key: str, *, timeout: float = 600.0) -> None:
+        self._base_url = base_url.rstrip("/")
+        self._api_key = api_key
+        self._timeout = timeout
+
+    async def resolve_trainer_start_url(self, trainer_id: str) -> str:
+        path = f"/api/rl/services/{trainer_id}"
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            js = await http.get(path)
+            if not isinstance(js, dict):
+                raise HTTPError(
+                    status=500,
+                    url=path,
+                    message="invalid_service_response",
+                    body_snippet=str(js)[:200],
+                )
+            start_url = js.get("training_start_url")
+            if not isinstance(start_url, str) or not start_url:
+                raise HTTPError(
+                    status=500,
+                    url=path,
+                    message="missing_training_start_url",
+                    body_snippet=str(js)[:200],
+                )
+            return start_url
+
+    async def create_job(
+        self,
+        *,
+        model: str,
+        task_app_url: str,
+        trainer: dict[str, Any],
+        trainer_id: str | None = None,
+        job_config_id: str | None = None,
+        inline_config: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        try:
+            normalized_model = normalize_model_identifier(model)
+        except UnsupportedModelError as exc:
+            raise ValueError(str(exc)) from exc
+
+        body = {
+            "job_type": "rl",
+            "data": {
+                "model": normalized_model,
+                "endpoint_base_url": task_app_url,
+                **({"job_config_id": job_config_id} if job_config_id else {}),
+                **({"config": inline_config} if inline_config else {}),
+                "trainer": {
+                    "batch_size": int(trainer.get("batch_size", 1)),
+                    "group_size": max(2, int(trainer.get("group_size", 2))),
+                },
+            },
+        }
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
+            js = await http.post_json(f"{_api_base(self._base_url)}/rl/jobs", json=body)
+            if not isinstance(js, dict):
+                raise HTTPError(
+                    status=500,
+                    url="/api/rl/jobs",
+                    message="invalid_create_response",
+                    body_snippet=str(js)[:200],
+                )
+            return js
+
+    async def start_job_if_supported(self, job_id: str) -> dict[str, Any] | None:
+        path = f"{_api_base(self._base_url)}/rl/jobs/{job_id}/start"
+        try:
+            async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+                return await http.post_json(path, json={})
+        except HTTPError as he:  # noqa: PERF203
+            if he.status == 404:
+                return None
+            raise
+
+    async def get_job(self, job_id: str) -> dict[str, Any]:
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            return await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}")
+
+    async def get_events(
+        self, job_id: str, *, since_seq: int = 0, limit: int = 200
+    ) -> list[dict[str, Any]]:
+        params = {"since_seq": since_seq, "limit": limit}
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            try:
+                js = await http.get(
+                    f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events", params=params
+                )
+            except HTTPError as he:
+                with suppress(Exception):
+                    print(
+                        f"[poll] events HTTPError status={he.status} url={he.url} since_seq={since_seq} body={(he.body_snippet or '')[:200]}"
+                    )
+                raise
+            if isinstance(js, dict):
+                evs = js.get("events") or js.get("data")
+                if isinstance(evs, list):
+                    return evs
+            return []
+
+    async def get_metrics(
+        self, job_id: str, *, after_step: int = -1, limit: int = 200
+    ) -> list[dict[str, Any]]:
+        params = {"after_step": after_step, "limit": limit}
+        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
+            js = await http.get(
+                f"{_api_base(self._base_url)}/learning/jobs/{job_id}/metrics", params=params
+            )
+            if isinstance(js, dict) and isinstance(js.get("points"), list):
+                return js["points"]
+            return []
+
+    async def poll_until_terminal(
+        self,
+        job_id: str,
+        *,
+        interval_seconds: float = 2.0,
+        max_seconds: float | None = None,
+        empty_polls_threshold: int = 5,
+        startup_deadline_s: int = 45,
+        on_event: Callable[[dict[str, Any]], None] | None = None,
+        on_metric: Callable[[dict[str, Any]], None] | None = None,
+    ) -> dict[str, Any]:
+        last_seq_by_stream: dict[str, int] = {}
+        events_job_id: str | None = None
+        last_status: str | None = None
+        last_step_by_name: dict[str, int] = {}
+        empty_polls = 0
+        saw_any_event = False
+        start_t = time.time()
+        terminal = {"succeeded", "failed", "cancelled", "canceled", "error", "completed"}
+
+        while True:
+            status_data: dict[str, Any] | None = None
+            try:
+                status_data = await self.get_job(job_id)
+            except Exception:
+                status_data = None
+            if status_data is None:
+                with suppress(Exception):
+                    print(f"[poll] get_job returned None base={self._base_url} job_id={job_id}")
+            status = str((status_data or {}).get("status") or "").lower()
+            if status_data:
+                linked = status_data.get("linked_job_id")
+                if isinstance(linked, str) and linked and linked != events_job_id:
+                    events_job_id = linked
+                    with suppress(Exception):
+                        print(f"[poll] discovered linked_job_id stream={events_job_id}")
+            if status and status != last_status:
+                last_status = status
+                if on_event:
+                    with suppress(Exception):
+                        on_event({"type": "rl.status", "message": status})
+
+            stream_ids = [job_id]
+            if events_job_id and events_job_id not in stream_ids:
+                stream_ids.append(events_job_id)
+            with suppress(Exception):
+                print(
+                    f"[poll] streams={stream_ids} intervals={interval_seconds}s since_map={last_seq_by_stream} empty_polls={empty_polls}"
+                )
+            total_events_this_cycle = 0
+            terminal_event_seen = False
+            terminal_event_status: str | None = None
+            for ev_id in stream_ids:
+                since = last_seq_by_stream.get(ev_id, 0)
+                try:
+                    events = await self.get_events(ev_id, since_seq=since, limit=200)
+                except HTTPError as he:
+                    with suppress(Exception):
+                        print(
+                            f"[poll] get_events error status={he.status} url={he.url} since={since} body={(he.body_snippet or '')[:200]}"
+                        )
+                    events = []
+                except Exception as e:
+                    with suppress(Exception):
+                        print(
+                            f"[poll] get_events unexpected error ev_id={ev_id} since={since} err={type(e).__name__}: {e}"
+                        )
+                    events = []
+                total_events_this_cycle += len(events)
+                if events:
+                    saw_any_event = True
+                for e in events:
+                    seq_val = int(e.get("seq") or 0)
+                    if seq_val <= last_seq_by_stream.get(ev_id, 0):
+                        continue
+                    last_seq_by_stream[ev_id] = seq_val
+                    if on_event:
+                        with suppress(Exception):
+                            on_event(e)
+                    et = str(e.get("type") or e.get("event_type") or "").lower()
+                    if et in ("rl.job.completed", "workflow.completed", "rl.train.completed"):
+                        terminal_event_seen = True
+                        terminal_event_status = "succeeded"
+                    elif et in ("rl.job.failed", "workflow.failed"):
+                        terminal_event_seen = True
+                        terminal_event_status = "failed"
+
+            try:
+                after = max(last_step_by_name.values()) if last_step_by_name else -1
+                points = await self.get_metrics(job_id, after_step=after, limit=200)
+                for p in points:
+                    name = str(p.get("name") or "")
+                    step = int(p.get("step") or -1)
+                    if step <= last_step_by_name.get(name, -1):
+                        continue
+                    last_step_by_name[name] = step
+                    if on_metric:
+                        with suppress(Exception):
+                            on_metric(p)
+            except Exception:
+                pass
+
+            if terminal_event_seen:
+                return {"status": terminal_event_status or status or "completed", "job_id": job_id}
+            if status and status in terminal:
+                return {"status": status, "job_id": job_id}
+
+            if total_events_this_cycle == 0:
+                empty_polls += 1
+            else:
+                empty_polls = 0
+            if empty_polls >= max(1, int(empty_polls_threshold)):
+                with suppress(Exception):
+                    print(
+                        f"[poll] threshold hit: empty_polls={empty_polls} >= {empty_polls_threshold} streams={stream_ids} last_seq_map={last_seq_by_stream}"
+                    )
+                raise AssertionError(
+                    f"No new events detected for {empty_polls_threshold} consecutive polls. Check event ingestion."
+                )
+
+            if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
+                with suppress(Exception):
+                    print(
+                        f"[poll] startup window exceeded: {startup_deadline_s}s base={self._base_url} job={job_id} streams={stream_ids} last_seq_map={last_seq_by_stream}"
+                    )
+                raise AssertionError(
+                    f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming."
+                )
+
+            await sleep(interval_seconds)
+            if max_seconds is not None and (time.time() - start_t) >= max_seconds:
+                raise TimeoutError(f"Polling timed out after {max_seconds}s for job {job_id}")
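The new RlClient bundles job creation (with model-identifier validation), an optional explicit start call, and event/metric polling until a terminal state. A usage sketch under the signatures above; the backend URL, API key, model name, and task app URL are placeholders, and the exact shape of the create_job response is not pinned down by this diff:

import asyncio

from synth_ai.learning.rl import RlClient

async def main() -> None:
    client = RlClient("https://backend.example.com", "sk-...")  # placeholders

    created = await client.create_job(
        model="Qwen/Qwen3-4B",                        # placeholder; must be accepted by normalize_model_identifier
        task_app_url="https://task-app.example.com",  # placeholder task app endpoint
        trainer={"batch_size": 1, "group_size": 2},
    )
    job_id = created.get("job_id") or created.get("id")  # response key is an assumption

    # Returns None when the backend has no /start route (HTTP 404).
    await client.start_job_if_supported(str(job_id))

    result = await client.poll_until_terminal(
        str(job_id),
        on_event=lambda e: print("event:", e.get("type")),
        on_metric=lambda p: print("metric:", p.get("name"), p.get("step")),
    )
    print(result)  # {"status": ..., "job_id": ...}

asyncio.run(main())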
synth_ai/learning/rl/config.py
ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
+def _ensure_positive(value: Any, *, name: str) -> int:
+    try:
+        ivalue = int(value)
+    except (TypeError, ValueError) as exc:
+        raise ValueError(f"{name} must be an integer") from exc
+    if ivalue < 1:
+        raise ValueError(f"{name} must be >= 1")
+    return ivalue
+
+
+@dataclass(slots=True)
+class RLJobConfig:
+    model: str
+    task_app_url: str
+    trainer_id: str
+    batch_size: int = 1
+    group_size: int = 2
+    job_config_id: str | None = None
+    inline_config: dict[str, Any] | None = None
+
+    def trainer_dict(self) -> dict[str, Any]:
+        return {
+            "batch_size": _ensure_positive(self.batch_size, name="trainer.batch_size"),
+            "group_size": _ensure_positive(self.group_size, name="trainer.group_size"),
+        }
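RLJobConfig is a small dataclass whose trainer_dict() runs both trainer integers through _ensure_positive before they are sent to the backend. For example (the field values below are illustrative):

from synth_ai.learning.rl import RLJobConfig

cfg = RLJobConfig(
    model="Qwen/Qwen3-4B",                        # illustrative model id
    task_app_url="https://task-app.example.com",  # placeholder endpoint
    trainer_id="trainer-abc",                     # placeholder trainer id
    batch_size=4,
    group_size=8,
)
print(cfg.trainer_dict())  # {'batch_size': 4, 'group_size': 8}

# Non-positive or non-integer values are rejected by _ensure_positive:
try:
    RLJobConfig(model="m", task_app_url="u", trainer_id="t", batch_size=0).trainer_dict()
except ValueError as exc:
    print(exc)  # trainer.batch_size must be >= 1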
synth_ai/{rl → learning/rl}/contracts.py
RENAMED
@@ -1,20 +1,17 @@
-
+"""Compatibility re-export for rollout contracts used by RL tooling."""
 
-
-Compatibility layer: re-export Task App rollout contracts from synth_ai.task.contracts
-so existing imports continue to work while consolidating under synth_ai.task.
-"""
+from __future__ import annotations
 
 from synth_ai.task.contracts import (
     RolloutEnvSpec,
+    RolloutMetrics,
     RolloutPolicySpec,
     RolloutRecordConfig,
-    RolloutSafetyConfig,
     RolloutRequest,
+    RolloutResponse,
+    RolloutSafetyConfig,
     RolloutStep,
     RolloutTrajectory,
-    RolloutMetrics,
-    RolloutResponse,
 )
 
 __all__ = [
@@ -28,5 +25,3 @@ __all__ = [
    "RolloutMetrics",
    "RolloutResponse",
 ]
-
-