synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +6 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
synth_ai/jobs/client.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
|
+
from synth_ai.api.models.supported import normalize_model_identifier
|
|
5
6
|
from synth_ai.http import AsyncHttpClient
|
|
7
|
+
from synth_ai.learning.sft.config import prepare_sft_job_payload
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class FilesApi:
|
|
@@ -15,9 +17,9 @@ class FilesApi:
|
|
|
15
17
|
filename: str,
|
|
16
18
|
content: bytes,
|
|
17
19
|
purpose: str,
|
|
18
|
-
content_type:
|
|
19
|
-
idempotency_key:
|
|
20
|
-
) ->
|
|
20
|
+
content_type: str | None = None,
|
|
21
|
+
idempotency_key: str | None = None,
|
|
22
|
+
) -> dict[str, Any]:
|
|
21
23
|
data = {"purpose": purpose}
|
|
22
24
|
files = {"file": (filename, content, content_type)}
|
|
23
25
|
headers = {"Idempotency-Key": idempotency_key} if idempotency_key else None
|
|
@@ -26,9 +28,9 @@ class FilesApi:
|
|
|
26
28
|
)
|
|
27
29
|
|
|
28
30
|
async def list(
|
|
29
|
-
self, *, purpose:
|
|
30
|
-
) ->
|
|
31
|
-
params:
|
|
31
|
+
self, *, purpose: str | None = None, after: str | None = None, limit: int = 20
|
|
32
|
+
) -> dict[str, Any]:
|
|
33
|
+
params: dict[str, Any] = {}
|
|
32
34
|
if purpose is not None:
|
|
33
35
|
params["purpose"] = purpose
|
|
34
36
|
if after is not None:
|
|
@@ -36,16 +38,16 @@ class FilesApi:
|
|
|
36
38
|
params["limit"] = limit
|
|
37
39
|
return await self._http.get("/api/files", params=params)
|
|
38
40
|
|
|
39
|
-
async def retrieve(self, file_id: str) ->
|
|
41
|
+
async def retrieve(self, file_id: str) -> dict[str, Any]:
|
|
40
42
|
return await self._http.get(f"/api/files/{file_id}")
|
|
41
43
|
|
|
42
44
|
async def delete(self, file_id: str) -> Any:
|
|
43
45
|
return await self._http.delete(f"/api/files/{file_id}")
|
|
44
46
|
|
|
45
47
|
async def list_jobs(
|
|
46
|
-
self, file_id: str, *, after:
|
|
47
|
-
) ->
|
|
48
|
-
params:
|
|
48
|
+
self, file_id: str, *, after: str | None = None, limit: int = 20
|
|
49
|
+
) -> dict[str, Any]:
|
|
50
|
+
params: dict[str, Any] = {"limit": limit}
|
|
49
51
|
if after is not None:
|
|
50
52
|
params["after"] = after
|
|
51
53
|
return await self._http.get(f"/api/files/{file_id}/jobs", params=params)
|
|
@@ -60,42 +62,40 @@ class SftJobsApi:
|
|
|
60
62
|
*,
|
|
61
63
|
training_file: str,
|
|
62
64
|
model: str,
|
|
63
|
-
validation_file:
|
|
64
|
-
hyperparameters:
|
|
65
|
-
suffix:
|
|
66
|
-
integrations:
|
|
67
|
-
metadata:
|
|
68
|
-
idempotency_key:
|
|
69
|
-
) ->
|
|
70
|
-
payload
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
if metadata is not None:
|
|
83
|
-
payload["metadata"] = metadata
|
|
65
|
+
validation_file: str | None = None,
|
|
66
|
+
hyperparameters: dict[str, Any] | None = None,
|
|
67
|
+
suffix: str | None = None,
|
|
68
|
+
integrations: dict[str, Any] | None = None,
|
|
69
|
+
metadata: dict[str, Any] | None = None,
|
|
70
|
+
idempotency_key: str | None = None,
|
|
71
|
+
) -> dict[str, Any]:
|
|
72
|
+
payload = prepare_sft_job_payload(
|
|
73
|
+
model=model,
|
|
74
|
+
training_file=training_file,
|
|
75
|
+
hyperparameters=hyperparameters,
|
|
76
|
+
metadata=metadata,
|
|
77
|
+
training_type=None,
|
|
78
|
+
validation_file=validation_file,
|
|
79
|
+
suffix=suffix,
|
|
80
|
+
integrations=integrations,
|
|
81
|
+
training_file_field="training_file",
|
|
82
|
+
require_training_file=True,
|
|
83
|
+
)
|
|
84
84
|
headers = {"Idempotency-Key": idempotency_key} if idempotency_key else None
|
|
85
85
|
return await self._http.post_json("/api/sft/jobs", json=payload, headers=headers)
|
|
86
86
|
|
|
87
87
|
async def list(
|
|
88
88
|
self,
|
|
89
89
|
*,
|
|
90
|
-
status:
|
|
91
|
-
model:
|
|
92
|
-
file_id:
|
|
93
|
-
created_after:
|
|
94
|
-
created_before:
|
|
95
|
-
after:
|
|
90
|
+
status: str | None = None,
|
|
91
|
+
model: str | None = None,
|
|
92
|
+
file_id: str | None = None,
|
|
93
|
+
created_after: int | None = None,
|
|
94
|
+
created_before: int | None = None,
|
|
95
|
+
after: str | None = None,
|
|
96
96
|
limit: int = 20,
|
|
97
|
-
) ->
|
|
98
|
-
params:
|
|
97
|
+
) -> dict[str, Any]:
|
|
98
|
+
params: dict[str, Any] = {"limit": limit}
|
|
99
99
|
if status is not None:
|
|
100
100
|
params["status"] = status
|
|
101
101
|
if model is not None:
|
|
@@ -110,22 +110,22 @@ class SftJobsApi:
|
|
|
110
110
|
params["after"] = after
|
|
111
111
|
return await self._http.get("/api/sft/jobs", params=params)
|
|
112
112
|
|
|
113
|
-
async def retrieve(self, job_id: str) ->
|
|
113
|
+
async def retrieve(self, job_id: str) -> dict[str, Any]:
|
|
114
114
|
return await self._http.get(f"/api/sft/jobs/{job_id}")
|
|
115
115
|
|
|
116
|
-
async def cancel(self, job_id: str) ->
|
|
116
|
+
async def cancel(self, job_id: str) -> dict[str, Any]:
|
|
117
117
|
return await self._http.post_json(f"/api/sft/jobs/{job_id}/cancel", json={})
|
|
118
118
|
|
|
119
119
|
async def list_events(
|
|
120
120
|
self, job_id: str, *, since_seq: int = 0, limit: int = 200
|
|
121
|
-
) ->
|
|
121
|
+
) -> dict[str, Any]:
|
|
122
122
|
params = {"since_seq": since_seq, "limit": limit}
|
|
123
123
|
return await self._http.get(f"/api/sft/jobs/{job_id}/events", params=params)
|
|
124
124
|
|
|
125
125
|
async def checkpoints(
|
|
126
|
-
self, job_id: str, *, after:
|
|
127
|
-
) ->
|
|
128
|
-
params:
|
|
126
|
+
self, job_id: str, *, after: str | None = None, limit: int = 10
|
|
127
|
+
) -> dict[str, Any]:
|
|
128
|
+
params: dict[str, Any] = {"limit": limit}
|
|
129
129
|
if after is not None:
|
|
130
130
|
params["after"] = after
|
|
131
131
|
return await self._http.get(f"/api/sft/jobs/{job_id}/checkpoints", params=params)
|
|
@@ -141,14 +141,14 @@ class RlJobsApi:
|
|
|
141
141
|
model: str,
|
|
142
142
|
endpoint_base_url: str,
|
|
143
143
|
trainer_id: str,
|
|
144
|
-
trainer:
|
|
145
|
-
job_config_id:
|
|
146
|
-
config:
|
|
147
|
-
metadata:
|
|
148
|
-
idempotency_key:
|
|
149
|
-
) ->
|
|
150
|
-
payload:
|
|
151
|
-
"model": model,
|
|
144
|
+
trainer: dict[str, Any] | None = None,
|
|
145
|
+
job_config_id: str | None = None,
|
|
146
|
+
config: dict[str, Any] | None = None,
|
|
147
|
+
metadata: dict[str, Any] | None = None,
|
|
148
|
+
idempotency_key: str | None = None,
|
|
149
|
+
) -> dict[str, Any]:
|
|
150
|
+
payload: dict[str, Any] = {
|
|
151
|
+
"model": normalize_model_identifier(model),
|
|
152
152
|
"endpoint_base_url": endpoint_base_url,
|
|
153
153
|
"trainer_id": trainer_id,
|
|
154
154
|
}
|
|
@@ -166,14 +166,14 @@ class RlJobsApi:
|
|
|
166
166
|
async def list(
|
|
167
167
|
self,
|
|
168
168
|
*,
|
|
169
|
-
status:
|
|
170
|
-
model:
|
|
171
|
-
created_after:
|
|
172
|
-
created_before:
|
|
173
|
-
after:
|
|
169
|
+
status: str | None = None,
|
|
170
|
+
model: str | None = None,
|
|
171
|
+
created_after: int | None = None,
|
|
172
|
+
created_before: int | None = None,
|
|
173
|
+
after: str | None = None,
|
|
174
174
|
limit: int = 20,
|
|
175
|
-
) ->
|
|
176
|
-
params:
|
|
175
|
+
) -> dict[str, Any]:
|
|
176
|
+
params: dict[str, Any] = {"limit": limit}
|
|
177
177
|
if status is not None:
|
|
178
178
|
params["status"] = status
|
|
179
179
|
if model is not None:
|
|
@@ -186,21 +186,21 @@ class RlJobsApi:
|
|
|
186
186
|
params["after"] = after
|
|
187
187
|
return await self._http.get("/api/rl/jobs", params=params)
|
|
188
188
|
|
|
189
|
-
async def retrieve(self, job_id: str) ->
|
|
189
|
+
async def retrieve(self, job_id: str) -> dict[str, Any]:
|
|
190
190
|
return await self._http.get(f"/api/rl/jobs/{job_id}")
|
|
191
191
|
|
|
192
|
-
async def cancel(self, job_id: str) ->
|
|
192
|
+
async def cancel(self, job_id: str) -> dict[str, Any]:
|
|
193
193
|
return await self._http.post_json(f"/api/rl/jobs/{job_id}/cancel", json={})
|
|
194
194
|
|
|
195
195
|
async def list_events(
|
|
196
196
|
self, job_id: str, *, since_seq: int = 0, limit: int = 200
|
|
197
|
-
) ->
|
|
197
|
+
) -> dict[str, Any]:
|
|
198
198
|
params = {"since_seq": since_seq, "limit": limit}
|
|
199
199
|
return await self._http.get(f"/api/rl/jobs/{job_id}/events", params=params)
|
|
200
200
|
|
|
201
201
|
async def metrics(
|
|
202
202
|
self, job_id: str, *, after_step: int = -1, limit: int = 200
|
|
203
|
-
) ->
|
|
203
|
+
) -> dict[str, Any]:
|
|
204
204
|
params = {"after_step": after_step, "limit": limit}
|
|
205
205
|
return await self._http.get(f"/api/rl/jobs/{job_id}/metrics", params=params)
|
|
206
206
|
|
|
@@ -212,13 +212,13 @@ class ModelsApi:
|
|
|
212
212
|
async def list(
|
|
213
213
|
self,
|
|
214
214
|
*,
|
|
215
|
-
source:
|
|
216
|
-
base_model:
|
|
217
|
-
status:
|
|
218
|
-
after:
|
|
215
|
+
source: str | None = None,
|
|
216
|
+
base_model: str | None = None,
|
|
217
|
+
status: str | None = None,
|
|
218
|
+
after: str | None = None,
|
|
219
219
|
limit: int = 20,
|
|
220
|
-
) ->
|
|
221
|
-
params:
|
|
220
|
+
) -> dict[str, Any]:
|
|
221
|
+
params: dict[str, Any] = {"limit": limit}
|
|
222
222
|
if source is not None:
|
|
223
223
|
params["source"] = source
|
|
224
224
|
if base_model is not None:
|
|
@@ -229,35 +229,30 @@ class ModelsApi:
|
|
|
229
229
|
params["after"] = after
|
|
230
230
|
return await self._http.get("/api/models", params=params)
|
|
231
231
|
|
|
232
|
-
async def retrieve(self, model_id: str) ->
|
|
232
|
+
async def retrieve(self, model_id: str) -> dict[str, Any]:
|
|
233
233
|
return await self._http.get(f"/api/models/{model_id}")
|
|
234
234
|
|
|
235
235
|
async def delete(self, model_id: str) -> Any:
|
|
236
236
|
return await self._http.delete(f"/api/models/{model_id}")
|
|
237
237
|
|
|
238
238
|
async def list_jobs(
|
|
239
|
-
self, model_id: str, *, after:
|
|
240
|
-
) ->
|
|
241
|
-
params:
|
|
239
|
+
self, model_id: str, *, after: str | None = None, limit: int = 20
|
|
240
|
+
) -> dict[str, Any]:
|
|
241
|
+
params: dict[str, Any] = {"limit": limit}
|
|
242
242
|
if after is not None:
|
|
243
243
|
params["after"] = after
|
|
244
244
|
return await self._http.get(f"/api/models/{model_id}/jobs", params=params)
|
|
245
245
|
|
|
246
246
|
|
|
247
247
|
class JobsClient:
|
|
248
|
-
"""High-level client aggregating job APIs.
|
|
249
|
-
|
|
250
|
-
Usage:
|
|
251
|
-
async with JobsClient(base_url, api_key) as c:
|
|
252
|
-
await c.files.list()
|
|
253
|
-
"""
|
|
248
|
+
"""High-level client aggregating job APIs."""
|
|
254
249
|
|
|
255
250
|
def __init__(
|
|
256
251
|
self,
|
|
257
252
|
base_url: str,
|
|
258
253
|
api_key: str,
|
|
259
254
|
timeout: float = 30.0,
|
|
260
|
-
http:
|
|
255
|
+
http: AsyncHttpClient | None = None,
|
|
261
256
|
) -> None:
|
|
262
257
|
self._base_url = base_url
|
|
263
258
|
self._api_key = api_key
|
|
@@ -268,7 +263,7 @@ class JobsClient:
|
|
|
268
263
|
self.rl = RlJobsApi(self._http)
|
|
269
264
|
self.models = ModelsApi(self._http)
|
|
270
265
|
|
|
271
|
-
async def __aenter__(self) ->
|
|
266
|
+
async def __aenter__(self) -> JobsClient:
|
|
272
267
|
await self._http.__aenter__()
|
|
273
268
|
return self
|
|
274
269
|
|
synth_ai/learning/__init__.py
CHANGED
|
@@ -1,16 +1,51 @@
|
|
|
1
|
+
from synth_ai.task import task_app_health, validate_task_app_url
|
|
2
|
+
|
|
1
3
|
from .client import LearningClient
|
|
2
|
-
from .
|
|
3
|
-
from .ft_client import FtClient
|
|
4
|
-
from .validators import validate_training_jsonl, validate_trainer_cfg_rl
|
|
5
|
-
from synth_ai.task import validate_task_app_url, task_app_health
|
|
6
|
-
from .health import backend_health, pricing_preflight, balance_autumn_normalized
|
|
7
|
-
from .sse import stream_events as stream_job_events
|
|
4
|
+
from .health import backend_health, balance_autumn_normalized, pricing_preflight
|
|
8
5
|
from .jobs import JobHandle, JobsApiResolver
|
|
6
|
+
from .rl import (
|
|
7
|
+
MAX_ENVIRONMENT_API_KEY_BYTES,
|
|
8
|
+
RlClient,
|
|
9
|
+
RLJobConfig,
|
|
10
|
+
RolloutEnvSpec,
|
|
11
|
+
RolloutMetrics,
|
|
12
|
+
RolloutPolicySpec,
|
|
13
|
+
RolloutRecordConfig,
|
|
14
|
+
RolloutRequest,
|
|
15
|
+
RolloutResponse,
|
|
16
|
+
RolloutSafetyConfig,
|
|
17
|
+
RolloutStep,
|
|
18
|
+
RolloutTrajectory,
|
|
19
|
+
encrypt_for_backend,
|
|
20
|
+
mint_environment_api_key,
|
|
21
|
+
setup_environment_api_key,
|
|
22
|
+
)
|
|
23
|
+
from .sft import FtClient
|
|
24
|
+
from .sft.config import SFTJobConfig, prepare_sft_job_payload
|
|
25
|
+
from .sse import stream_events as stream_job_events
|
|
26
|
+
from .validators import validate_trainer_cfg_rl, validate_training_jsonl
|
|
9
27
|
|
|
10
28
|
__all__ = [
|
|
11
29
|
"LearningClient",
|
|
12
30
|
"RlClient",
|
|
31
|
+
"RLJobConfig",
|
|
13
32
|
"FtClient",
|
|
33
|
+
"SFTJobConfig",
|
|
34
|
+
"prepare_sft_job_payload",
|
|
35
|
+
"RolloutEnvSpec",
|
|
36
|
+
"RolloutPolicySpec",
|
|
37
|
+
"RolloutRecordConfig",
|
|
38
|
+
"RolloutSafetyConfig",
|
|
39
|
+
"RolloutRequest",
|
|
40
|
+
"RolloutStep",
|
|
41
|
+
"RolloutTrajectory",
|
|
42
|
+
"RolloutMetrics",
|
|
43
|
+
"RolloutResponse",
|
|
44
|
+
"mint_environment_api_key",
|
|
45
|
+
"encrypt_for_backend",
|
|
46
|
+
"setup_environment_api_key",
|
|
47
|
+
"MAX_ENVIRONMENT_API_KEY_BYTES",
|
|
48
|
+
# convenience re-export for typing
|
|
14
49
|
"validate_training_jsonl",
|
|
15
50
|
"validate_trainer_cfg_rl",
|
|
16
51
|
"validate_task_app_url",
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# class LearningModality(str, enum.Enum):
|
|
2
|
+
# """Modality of learning."""
|
|
3
|
+
|
|
4
|
+
# online_on_policy = "online_on_policy"
|
|
5
|
+
# online_off_policy = "online_off_policy"
|
|
6
|
+
# offline = "offline"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# class LearningAlgorithm(str, enum.Enum):
|
|
10
|
+
# """Algorithm of learning."""
|
|
11
|
+
|
|
12
|
+
# gspo = "gspo"
|
|
13
|
+
# reinforce = "reinforce"
|
|
14
|
+
# sft = "sft"
|
synth_ai/learning/client.py
CHANGED
|
@@ -1,7 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from contextlib import suppress
|
|
3
5
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, TypedDict
|
|
7
|
+
|
|
8
|
+
from synth_ai.api.models.supported import (
|
|
9
|
+
UnsupportedModelError,
|
|
10
|
+
normalize_model_identifier,
|
|
11
|
+
)
|
|
12
|
+
from synth_ai.learning.sft.config import prepare_sft_job_payload
|
|
5
13
|
|
|
6
14
|
from ..http import AsyncHttpClient, HTTPError, sleep
|
|
7
15
|
|
|
@@ -34,30 +42,56 @@ class LearningClient:
|
|
|
34
42
|
training_type: str,
|
|
35
43
|
model: str,
|
|
36
44
|
training_file_id: str,
|
|
37
|
-
hyperparameters:
|
|
38
|
-
metadata:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"
|
|
44
|
-
"
|
|
45
|
-
"
|
|
46
|
-
|
|
45
|
+
hyperparameters: dict[str, Any] | None = None,
|
|
46
|
+
metadata: dict[str, Any] | None = None,
|
|
47
|
+
validation_file: str | None = None,
|
|
48
|
+
) -> dict[str, Any]:
|
|
49
|
+
lower_type = (training_type or "").strip().lower()
|
|
50
|
+
require_base = (
|
|
51
|
+
lower_type.startswith("sft")
|
|
52
|
+
or lower_type.startswith("fft")
|
|
53
|
+
or lower_type.startswith("qft")
|
|
54
|
+
)
|
|
55
|
+
try:
|
|
56
|
+
normalized_model = normalize_model_identifier(
|
|
57
|
+
model, allow_finetuned_prefixes=not require_base
|
|
58
|
+
)
|
|
59
|
+
except UnsupportedModelError as exc:
|
|
60
|
+
raise ValueError(str(exc)) from exc
|
|
61
|
+
|
|
62
|
+
if lower_type.startswith("sft") or lower_type in {"fft", "qft"}:
|
|
63
|
+
body = prepare_sft_job_payload(
|
|
64
|
+
model=model,
|
|
65
|
+
training_file=training_file_id,
|
|
66
|
+
hyperparameters=hyperparameters,
|
|
67
|
+
metadata=metadata,
|
|
68
|
+
training_type=training_type or "sft_offline",
|
|
69
|
+
validation_file=validation_file,
|
|
70
|
+
training_file_field="training_file_id",
|
|
71
|
+
require_training_file=True,
|
|
72
|
+
)
|
|
73
|
+
else:
|
|
74
|
+
body = {
|
|
75
|
+
"training_type": training_type,
|
|
76
|
+
"model": normalized_model,
|
|
77
|
+
"training_file_id": training_file_id,
|
|
78
|
+
"hyperparameters": hyperparameters or {},
|
|
79
|
+
"metadata": metadata or {},
|
|
80
|
+
}
|
|
47
81
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
48
82
|
return await http.post_json("/api/learning/jobs", json=body)
|
|
49
83
|
|
|
50
|
-
async def start_job(self, job_id: str) ->
|
|
84
|
+
async def start_job(self, job_id: str) -> dict[str, Any]:
|
|
51
85
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
52
86
|
return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})
|
|
53
87
|
|
|
54
|
-
async def get_job(self, job_id: str) ->
|
|
88
|
+
async def get_job(self, job_id: str) -> dict[str, Any]:
|
|
55
89
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
56
90
|
return await http.get(f"/api/learning/jobs/{job_id}")
|
|
57
91
|
|
|
58
92
|
async def get_events(
|
|
59
93
|
self, job_id: str, *, since_seq: int = 0, limit: int = 200
|
|
60
|
-
) ->
|
|
94
|
+
) -> list[dict[str, Any]]:
|
|
61
95
|
params = {"since_seq": since_seq, "limit": limit}
|
|
62
96
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
63
97
|
js = await http.get(f"/api/learning/jobs/{job_id}/events", params=params)
|
|
@@ -73,8 +107,8 @@ class LearningClient:
|
|
|
73
107
|
after_step: int | None = None,
|
|
74
108
|
limit: int = 500,
|
|
75
109
|
run_id: str | None = None,
|
|
76
|
-
) ->
|
|
77
|
-
params:
|
|
110
|
+
) -> list[dict[str, Any]]:
|
|
111
|
+
params: dict[str, Any] = {"limit": limit}
|
|
78
112
|
if name is not None:
|
|
79
113
|
params["name"] = name
|
|
80
114
|
if after_step is not None:
|
|
@@ -87,7 +121,7 @@ class LearningClient:
|
|
|
87
121
|
return js["points"]
|
|
88
122
|
return []
|
|
89
123
|
|
|
90
|
-
async def get_timeline(self, job_id: str, *, limit: int = 200) ->
|
|
124
|
+
async def get_timeline(self, job_id: str, *, limit: int = 200) -> list[dict[str, Any]]:
|
|
91
125
|
params = {"limit": limit}
|
|
92
126
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
93
127
|
js = await http.get(f"/api/learning/jobs/{job_id}/timeline", params=params)
|
|
@@ -101,8 +135,8 @@ class LearningClient:
|
|
|
101
135
|
*,
|
|
102
136
|
interval_seconds: float = 2.0,
|
|
103
137
|
max_seconds: float | None = 3600,
|
|
104
|
-
on_event: Callable[[
|
|
105
|
-
) ->
|
|
138
|
+
on_event: Callable[[dict[str, Any]], None] | None = None,
|
|
139
|
+
) -> dict[str, Any]:
|
|
106
140
|
last_seq = 0
|
|
107
141
|
elapsed = 0.0
|
|
108
142
|
while True:
|
|
@@ -112,10 +146,8 @@ class LearningClient:
|
|
|
112
146
|
if isinstance(e, dict) and isinstance(e.get("seq"), int):
|
|
113
147
|
last_seq = max(last_seq, int(e["seq"]))
|
|
114
148
|
if on_event:
|
|
115
|
-
|
|
149
|
+
with suppress(Exception):
|
|
116
150
|
on_event(e)
|
|
117
|
-
except Exception:
|
|
118
|
-
pass
|
|
119
151
|
|
|
120
152
|
# Status
|
|
121
153
|
job = await self.get_job(job_id)
|
|
@@ -132,7 +164,7 @@ class LearningClient:
|
|
|
132
164
|
# --- Optional diagnostics ---
|
|
133
165
|
async def pricing_preflight(
|
|
134
166
|
self, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int
|
|
135
|
-
) ->
|
|
167
|
+
) -> dict[str, Any]:
|
|
136
168
|
body = {
|
|
137
169
|
"job_type": job_type,
|
|
138
170
|
"gpu_type": gpu_type,
|
|
@@ -150,7 +182,7 @@ class LearningClient:
|
|
|
150
182
|
)
|
|
151
183
|
return js
|
|
152
184
|
|
|
153
|
-
async def balance_autumn_normalized(self) ->
|
|
185
|
+
async def balance_autumn_normalized(self) -> dict[str, Any]:
|
|
154
186
|
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
155
187
|
js = await http.get("/api/v1/balance/autumn-normalized")
|
|
156
188
|
if not isinstance(js, dict):
|
|
@@ -163,6 +195,41 @@ class LearningClient:
|
|
|
163
195
|
return js
|
|
164
196
|
|
|
165
197
|
|
|
198
|
+
class FineTunedModelInfo(TypedDict, total=False):
|
|
199
|
+
id: str
|
|
200
|
+
base_model: str | None
|
|
201
|
+
created_at: int | None
|
|
202
|
+
job_id: str | None
|
|
203
|
+
status: str | None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class LearningClient(LearningClient): # type: ignore[misc]
|
|
207
|
+
async def list_fine_tuned_models(self) -> list[FineTunedModelInfo]:
|
|
208
|
+
"""Return completed fine‑tuned models for the caller's organization.
|
|
209
|
+
|
|
210
|
+
Calls backend route `/api/learning/models` and returns a compact list.
|
|
211
|
+
"""
|
|
212
|
+
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
213
|
+
js = await http.get("/api/learning/models")
|
|
214
|
+
if isinstance(js, dict) and isinstance(js.get("data"), list):
|
|
215
|
+
out: list[FineTunedModelInfo] = []
|
|
216
|
+
for item in js["data"]:
|
|
217
|
+
if not isinstance(item, dict):
|
|
218
|
+
continue
|
|
219
|
+
rec: FineTunedModelInfo = {
|
|
220
|
+
"id": str(item.get("id")),
|
|
221
|
+
"base_model": item.get("base_model"),
|
|
222
|
+
"created_at": item.get("created_at"),
|
|
223
|
+
"job_id": item.get("job_id"),
|
|
224
|
+
"status": item.get("status"),
|
|
225
|
+
}
|
|
226
|
+
if rec.get("id"):
|
|
227
|
+
out.append(rec)
|
|
228
|
+
return out
|
|
229
|
+
# Fallback: empty list on unexpected shape
|
|
230
|
+
return []
|
|
231
|
+
|
|
232
|
+
|
|
166
233
|
def _infer_content_type(filename: str) -> str:
|
|
167
234
|
name = filename.lower()
|
|
168
235
|
if name.endswith(".jsonl"):
|
synth_ai/learning/config.py
CHANGED
|
@@ -1,41 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from typing import Any, Dict, Optional
|
|
3
|
+
from .rl.config import RLJobConfig
|
|
5
4
|
|
|
6
|
-
|
|
7
|
-
@dataclass
|
|
8
|
-
class FTJobConfig:
|
|
9
|
-
model: str
|
|
10
|
-
training_file_id: str
|
|
11
|
-
n_epochs: int = 1
|
|
12
|
-
batch_size: int = 1
|
|
13
|
-
upload_to_wasabi: bool = True
|
|
14
|
-
|
|
15
|
-
def hyperparameters(self) -> Dict[str, Any]:
|
|
16
|
-
if self.n_epochs < 1:
|
|
17
|
-
raise ValueError("n_epochs must be >= 1")
|
|
18
|
-
if self.batch_size < 1:
|
|
19
|
-
raise ValueError("batch_size must be >= 1")
|
|
20
|
-
return {"n_epochs": int(self.n_epochs), "batch_size": int(self.batch_size)}
|
|
21
|
-
|
|
22
|
-
def metadata(self) -> Dict[str, Any]: # type: ignore[override]
|
|
23
|
-
return {"upload_to_wasabi": bool(self.upload_to_wasabi)}
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@dataclass
|
|
27
|
-
class RLJobConfig:
|
|
28
|
-
model: str
|
|
29
|
-
task_app_url: str
|
|
30
|
-
trainer_id: str
|
|
31
|
-
batch_size: int = 1
|
|
32
|
-
group_size: int = 2
|
|
33
|
-
job_config_id: Optional[str] = None
|
|
34
|
-
inline_config: Optional[Dict[str, Any]] = None
|
|
35
|
-
|
|
36
|
-
def trainer_dict(self) -> Dict[str, Any]:
|
|
37
|
-
if self.batch_size < 1:
|
|
38
|
-
raise ValueError("batch_size must be >= 1")
|
|
39
|
-
if self.group_size < 2:
|
|
40
|
-
raise ValueError("group_size must be >= 2")
|
|
41
|
-
return {"batch_size": int(self.batch_size), "group_size": int(self.group_size)}
|
|
5
|
+
__all__ = ["RLJobConfig"]
|
synth_ai/learning/ft_client.py
CHANGED
|
@@ -1,62 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
from typing import Any, Dict, Optional
|
|
5
|
-
|
|
6
|
-
from ..http import AsyncHttpClient, HTTPError
|
|
7
|
-
|
|
1
|
+
"""Backward-compatible shim for FtClient (moved to synth_ai.learning.sft.client)."""
|
|
8
2
|
|
|
9
|
-
|
|
10
|
-
def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
|
|
11
|
-
self._base_url = base_url.rstrip("/")
|
|
12
|
-
self._api_key = api_key
|
|
13
|
-
self._timeout = timeout
|
|
14
|
-
|
|
15
|
-
async def upload_training_file(self, path: str | Path, *, purpose: str = "fine-tune") -> str:
|
|
16
|
-
p = Path(path)
|
|
17
|
-
content = p.read_bytes()
|
|
18
|
-
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
19
|
-
data = {"purpose": purpose}
|
|
20
|
-
files = {"file": (p.name, content, _infer_content_type(p.name))}
|
|
21
|
-
js = await http.post_multipart("/api/learning/files", data=data, files=files)
|
|
22
|
-
if not isinstance(js, dict) or "id" not in js:
|
|
23
|
-
raise HTTPError(
|
|
24
|
-
status=500,
|
|
25
|
-
url="/api/learning/files",
|
|
26
|
-
message="invalid_upload_response",
|
|
27
|
-
body_snippet=str(js)[:200],
|
|
28
|
-
)
|
|
29
|
-
return str(js["id"])
|
|
30
|
-
|
|
31
|
-
async def create_sft_job(
|
|
32
|
-
self,
|
|
33
|
-
*,
|
|
34
|
-
model: str,
|
|
35
|
-
training_file_id: str,
|
|
36
|
-
hyperparameters: Dict[str, Any],
|
|
37
|
-
metadata: Optional[Dict[str, Any]] = None,
|
|
38
|
-
) -> Dict[str, Any]:
|
|
39
|
-
body = {
|
|
40
|
-
"training_type": "sft_offline",
|
|
41
|
-
"model": model,
|
|
42
|
-
"training_file_id": training_file_id,
|
|
43
|
-
"hyperparameters": dict(hyperparameters or {}),
|
|
44
|
-
"metadata": dict(metadata or {}),
|
|
45
|
-
}
|
|
46
|
-
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
47
|
-
return await http.post_json("/api/learning/jobs", json=body)
|
|
48
|
-
|
|
49
|
-
async def start_job(self, job_id: str) -> Dict[str, Any]:
|
|
50
|
-
async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
|
|
51
|
-
return await http.post_json(f"/api/learning/jobs/{job_id}/start", json={})
|
|
3
|
+
from __future__ import annotations
|
|
52
4
|
|
|
5
|
+
from .sft.client import FtClient
|
|
53
6
|
|
|
54
|
-
|
|
55
|
-
name = filename.lower()
|
|
56
|
-
if name.endswith(".jsonl"):
|
|
57
|
-
return "application/jsonl"
|
|
58
|
-
if name.endswith(".json"):
|
|
59
|
-
return "application/json"
|
|
60
|
-
if name.endswith(".txt"):
|
|
61
|
-
return "text/plain"
|
|
62
|
-
return "application/octet-stream"
|
|
7
|
+
__all__ = ["FtClient"]
|