synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic; see the registry's advisory page for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +8 -11
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/run_eval.py +36 -37
- examples/rl/run_rl_and_save.py +5 -5
- examples/rl/task_app/math_single_step.py +65 -43
- examples/rl/task_app/math_task_app.py +3 -3
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +5 -5
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +78 -21
- examples/warming_up_to_rl/groq_test.py +4 -4
- examples/warming_up_to_rl/manage_secrets.py +13 -18
- examples/warming_up_to_rl/run_eval.py +42 -44
- examples/warming_up_to_rl/run_fft_and_save.py +11 -16
- examples/warming_up_to_rl/run_local_rollout.py +1 -3
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
- examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
- examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
- examples/warming_up_to_rl/run_rl_and_save.py +5 -6
- examples/warming_up_to_rl/run_rollout_remote.py +8 -10
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +26 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +128 -21
- synth_ai/api/train/cli.py +80 -64
- synth_ai/api/train/config_finder.py +7 -2
- synth_ai/api/train/env_resolver.py +1 -1
- synth_ai/api/train/pollers.py +2 -1
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +1 -2
- synth_ai/api/train/utils.py +13 -44
- synth_ai/cli/__init__.py +8 -0
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +1 -2
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +2 -1
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +2 -1
- synth_ai/cli/root.py +11 -13
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +529 -179
- synth_ai/cli/traces.py +6 -4
- synth_ai/cli/watch.py +12 -18
- synth_ai/demo_registry.py +1 -1
- synth_ai/demos/core/cli.py +36 -43
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +17 -25
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +2 -5
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +4 -7
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/handshake.py +9 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +18 -10
- synth_ai/inference/client.py +15 -5
- synth_ai/jobs/client.py +78 -83
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +91 -24
- synth_ai/learning/config.py +2 -38
- synth_ai/learning/ft_client.py +4 -59
- synth_ai/learning/health.py +5 -6
- synth_ai/learning/jobs.py +31 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -4
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -8
- synth_ai/{rl → learning/rl}/env_keys.py +39 -15
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -281
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -24
- synth_ai/learning/validators.py +25 -28
- synth_ai/lm/__init__.py +21 -47
- synth_ai/main.py +4 -0
- synth_ai/task/__init__.py +25 -27
- synth_ai/task/apps/__init__.py +7 -8
- synth_ai/task/auth.py +8 -8
- synth_ai/task/client.py +14 -14
- synth_ai/task/contracts.py +36 -35
- synth_ai/task/datasets.py +6 -5
- synth_ai/task/errors.py +10 -10
- synth_ai/task/health.py +17 -9
- synth_ai/task/json.py +58 -23
- synth_ai/task/proxy.py +13 -9
- synth_ai/task/rubrics.py +16 -15
- synth_ai/task/server.py +12 -12
- synth_ai/task/tracing_utils.py +4 -4
- synth_ai/task/vendors.py +5 -6
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/decorators.py +18 -16
- synth_ai/tracing_v3/hooks.py +5 -5
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/session_tracer.py +40 -14
- synth_ai/tracing_v3/storage/base.py +85 -0
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -7
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -4
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/main.py +6 -6
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
- synth_ai/{lm → v0/lm}/overrides.py +2 -2
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev8.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/RECORD +268 -238
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -20
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1038
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -152
- examples/rl_old/task_app.py +0 -1131
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/experimental/synth_oss.py +0 -445
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -249
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -838
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""Utilities for validating and constructing SFT job payloads."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from dataclasses import dataclass, field, fields
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from synth_ai.api.models.supported import (
|
|
10
|
+
UnsupportedModelError,
|
|
11
|
+
normalize_model_identifier,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Keys that can express training duration; `from_mapping` requires at least one.
_STEP_KEYS = ("n_epochs", "total_steps", "train_steps", "steps")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _ensure_positive_int(value: Any, *, key: str) -> int:
|
|
18
|
+
if isinstance(value, bool):
|
|
19
|
+
raise ValueError(f"hyperparameters.{key} must be an integer greater than zero")
|
|
20
|
+
try:
|
|
21
|
+
ivalue = int(value)
|
|
22
|
+
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
|
|
23
|
+
raise ValueError(f"hyperparameters.{key} must be an integer greater than zero") from exc
|
|
24
|
+
if ivalue <= 0:
|
|
25
|
+
raise ValueError(f"hyperparameters.{key} must be an integer greater than zero")
|
|
26
|
+
return ivalue
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _ensure_non_negative_float(value: Any, *, key: str) -> float:
|
|
30
|
+
if isinstance(value, bool):
|
|
31
|
+
raise ValueError(f"hyperparameters.{key} must be a float greater than or equal to zero")
|
|
32
|
+
try:
|
|
33
|
+
fvalue = float(value)
|
|
34
|
+
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
|
|
35
|
+
raise ValueError(
|
|
36
|
+
f"hyperparameters.{key} must be a float greater than or equal to zero"
|
|
37
|
+
) from exc
|
|
38
|
+
if fvalue < 0:
|
|
39
|
+
raise ValueError(f"hyperparameters.{key} must be a float greater than or equal to zero")
|
|
40
|
+
return fvalue
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _ensure_positive_float(value: Any, *, key: str) -> float:
    """Coerce *value* to a ``float`` strictly greater than zero."""
    result = _ensure_non_negative_float(value, key=key)
    # 0.0 is the only non-negative value that fails the strict check.
    if not result:
        raise ValueError(f"hyperparameters.{key} must be greater than zero")
    return result
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(slots=True)
class SFTTrainingHyperparameters:
    """Typed representation of SFT training hyperparameters.

    Known keys are validated and stored as attributes; unrecognized keys are
    preserved verbatim in ``extras`` so they round-trip through :meth:`to_dict`.
    """

    n_epochs: int | None = None
    total_steps: int | None = None
    train_steps: int | None = None
    steps: int | None = None
    batch_size: int | None = None
    global_batch: int | None = None
    per_device_batch: int | None = None
    gradient_accumulation_steps: int | None = None
    sequence_length: int | None = None
    learning_rate: float | None = None
    warmup_ratio: float | None = None
    train_kind: str | None = None
    extras: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any] | None) -> SFTTrainingHyperparameters:
        """Build a validated instance from a raw mapping.

        Raises:
            ValueError: if *data* is empty or ``None``, if any known key fails
                validation, or if none of the step-count keys in ``_STEP_KEYS``
                is present.
        """
        if data is None:
            raise ValueError("hyperparameters must not be empty")
        normalized: dict[str, Any] = dict(data)
        if not normalized:
            raise ValueError("hyperparameters must not be empty")

        kwargs: dict[str, Any] = {}

        # Previously there were two byte-identical helpers (pop_int and
        # pop_optional_int); a single validate-and-remove helper suffices.
        def pop_int(name: str) -> int | None:
            if name not in normalized:
                return None
            return _ensure_positive_int(normalized.pop(name), key=name)

        def pop_positive_float(name: str) -> float | None:
            if name not in normalized:
                return None
            return _ensure_positive_float(normalized.pop(name), key=name)

        def pop_non_negative_float(name: str) -> float | None:
            if name not in normalized:
                return None
            return _ensure_non_negative_float(normalized.pop(name), key=name)

        # At least one step-derived key is required.
        step_values = {name: pop_int(name) for name in _STEP_KEYS}
        if not any(step_values.values()):
            keys = ", ".join(_STEP_KEYS)
            raise ValueError(f"hyperparameters must include at least one of: {keys}")
        kwargs.update(step_values)

        for name in (
            "batch_size",
            "global_batch",
            "per_device_batch",
            "gradient_accumulation_steps",
            "sequence_length",
        ):
            kwargs[name] = pop_int(name)
        kwargs["learning_rate"] = pop_positive_float("learning_rate")
        kwargs["warmup_ratio"] = pop_non_negative_float("warmup_ratio")

        # The key is always assigned above, so only the None check is needed
        # (the previous `"warmup_ratio" in kwargs` membership test was dead).
        ratio = kwargs["warmup_ratio"]
        if ratio is not None and ratio > 1:
            raise ValueError("hyperparameters.warmup_ratio must be between 0 and 1 inclusive")

        if "train_kind" in normalized:
            value = normalized.pop("train_kind")
            if not isinstance(value, str):
                raise ValueError("hyperparameters.train_kind must be a string")
            kwargs["train_kind"] = value

        # Whatever was not consumed above is carried through untouched.
        return cls(extras=normalized, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, omitting ``None`` fields; ``extras`` wins on key collisions."""
        result: dict[str, Any] = {}
        for field_info in fields(self):
            if field_info.name == "extras":
                continue
            value = getattr(self, field_info.name)
            if value is not None:
                result[field_info.name] = value
        result.update(self.extras)
        return result
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _coerce_mapping(value: Mapping[str, Any] | None, *, name: str) -> dict[str, Any]:
|
|
149
|
+
if value is None:
|
|
150
|
+
return {}
|
|
151
|
+
if not isinstance(value, Mapping):
|
|
152
|
+
raise ValueError(f"{name} must be a mapping")
|
|
153
|
+
return dict(value)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@dataclass(slots=True)
class SFTJobConfig:
    """Structured representation of an SFT training job request."""

    model: str
    hyperparameters: Mapping[str, Any] | SFTTrainingHyperparameters
    training_file: str | None = None
    metadata: Mapping[str, Any] | None = None
    training_type: str | None = "sft_offline"
    validation_file: str | None = None
    suffix: str | None = None
    integrations: Mapping[str, Any] | None = None

    def to_payload(
        self,
        *,
        training_file_field: str = "training_file_id",
        require_training_file: bool = True,
        include_training_file_when_none: bool = False,
        allow_finetuned_prefixes: bool = False,
    ) -> dict[str, Any]:
        """Validate this config and render the request payload.

        Optional fields are included only when they carry a non-empty value.
        Raises ``ValueError`` when the training file is required but missing,
        or when any nested mapping / hyperparameter fails validation.
        """
        model = normalize_model_identifier(
            self.model, allow_finetuned_prefixes=allow_finetuned_prefixes
        )
        if isinstance(self.hyperparameters, SFTTrainingHyperparameters):
            hyper_config = self.hyperparameters
        else:
            hyper_config = SFTTrainingHyperparameters.from_mapping(
                _coerce_mapping(self.hyperparameters, name="hyperparameters")
            )

        payload: dict[str, Any] = {
            "model": model,
            "hyperparameters": hyper_config.to_dict(),
        }

        # Optional string field: include only when non-blank after stripping.
        training_type = (self.training_type or "").strip()
        if training_type:
            payload["training_type"] = training_type

        # Optional mapping fields: include only when non-empty.
        for key, raw in (("metadata", self.metadata), ("integrations", self.integrations)):
            coerced = _coerce_mapping(raw, name=key)
            if coerced:
                payload[key] = coerced

        suffix = (self.suffix or "").strip()
        if suffix:
            payload["suffix"] = suffix

        validation_file = (self.validation_file or "").strip()
        if validation_file:
            payload["validation_file"] = validation_file

        if training_file_field:
            training_file = (self.training_file or "").strip()
            if training_file:
                payload[training_file_field] = training_file
            elif require_training_file:
                raise ValueError("training file identifier is required for SFT jobs")
            elif include_training_file_when_none:
                payload[training_file_field] = None

        return payload
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def prepare_sft_job_payload(
    *,
    model: str,
    hyperparameters: Mapping[str, Any] | SFTTrainingHyperparameters | None,
    training_file: str | None = None,
    metadata: Mapping[str, Any] | None = None,
    training_type: str | None = "sft_offline",
    validation_file: str | None = None,
    suffix: str | None = None,
    integrations: Mapping[str, Any] | None = None,
    training_file_field: str = "training_file_id",
    require_training_file: bool = True,
    include_training_file_when_none: bool = False,
    allow_finetuned_prefixes: bool = False,
) -> dict[str, Any]:
    """Validate inputs and return an SFT job payload suitable for API calls."""
    # Coerce raw mappings (or None) into the typed hyperparameter container.
    if not isinstance(hyperparameters, SFTTrainingHyperparameters):
        hyperparameters = SFTTrainingHyperparameters.from_mapping(hyperparameters or {})

    return SFTJobConfig(
        model=model,
        training_file=training_file,
        hyperparameters=hyperparameters,
        metadata=metadata,
        training_type=training_type,
        validation_file=validation_file,
        suffix=suffix,
        integrations=integrations,
    ).to_payload(
        training_file_field=training_file_field,
        require_training_file=require_training_file,
        include_training_file_when_none=include_training_file_when_none,
        allow_finetuned_prefixes=allow_finetuned_prefixes,
    )
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# Public API of this module.
__all__ = [
    "SFTTrainingHyperparameters",
    "SFTJobConfig",
    "prepare_sft_job_payload",
    "UnsupportedModelError",
]
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
# Message content may be plain text, a structured part, a list of parts, or absent.
SFTMessageContent = str | dict[str, Any] | list[Any] | None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SFTDataError(ValueError):
    """Raised when a JSONL record cannot be coerced into an SFTExample.

    Subclasses ``ValueError`` so callers may catch either type.
    """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(slots=True)
class SFTToolDefinition:
    """Normalized view of one tool (function) definition from a JSONL record."""

    name: str  # tool name; non-empty string
    description: str | None  # optional human-readable description
    parameters: dict[str, Any] | None  # optional parameters object (presumably JSON schema -- verify)
    raw: dict[str, Any] = field(default_factory=dict)  # original record payload, unmodified
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(slots=True)
class SFTToolCall:
    """Normalized assistant tool invocation parsed from a JSONL record."""

    name: str  # function name being invoked
    arguments: Any  # arguments, JSON-decoded when they were a decodable string
    call_id: str | None = None  # provider-assigned call id, stringified
    type: str | None = None  # call type as given, stringified
    raw: dict[str, Any] = field(default_factory=dict)  # original payload as given
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(slots=True)
class SFTMessage:
    """One chat message within an SFT training example."""

    role: str  # message role; non-empty string
    content: SFTMessageContent  # text, structured part(s), or None
    tool_calls: list[SFTToolCall] = field(default_factory=list)  # assistant tool invocations
    tool_call_id: str | None = None  # presumably links a tool message to its call -- verify against parser
    name: str | None = None  # optional participant/tool name
    extra: dict[str, Any] = field(default_factory=dict)  # unrecognized keys preserved
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(slots=True)
class SFTExample:
    """A single supervised fine-tuning example: a conversation plus tool context."""

    messages: list[SFTMessage]  # ordered conversation turns
    tools: list[SFTToolDefinition] = field(default_factory=list)  # tools available to the model
    tool_choice: Any | None = None  # tool-choice directive, stored as-is
    metadata: dict[str, Any] = field(default_factory=dict)  # per-example metadata
    extra: dict[str, Any] = field(default_factory=dict)  # unrecognized record keys
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _parse_tool_arguments(value: Any) -> Any:
|
|
53
|
+
if isinstance(value, str):
|
|
54
|
+
try:
|
|
55
|
+
return json.loads(value)
|
|
56
|
+
except json.JSONDecodeError:
|
|
57
|
+
return value
|
|
58
|
+
return value
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _coerce_tool_definition(raw: Any, *, index: int) -> SFTToolDefinition:
    """Validate one entry of a record's ``tools`` list and normalize it.

    Raises:
        SFTDataError: if the entry is not an object, lacks a non-empty name,
            or has a non-string description / non-object parameters.
    """
    if not isinstance(raw, dict):
        raise SFTDataError(f"tool {index} is not an object")

    name = raw.get("name")
    description = raw.get("description")
    parameters = raw.get("parameters")

    if not (isinstance(name, str) and name.strip()):
        raise SFTDataError(f"tool {index} missing name")
    if not isinstance(description, (str, type(None))):
        raise SFTDataError(f"tool {index} description must be a string if present")
    if not isinstance(parameters, (dict, type(None))):
        raise SFTDataError(f"tool {index} parameters must be an object if present")

    return SFTToolDefinition(
        name=name, description=description, parameters=parameters, raw=dict(raw)
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _coerce_tool_call(raw: Any, *, index: int) -> SFTToolCall:
    """Normalize one tool-call payload into an SFTToolCall.

    Accepts both the nested OpenAI shape ``{"function": {"name", "arguments"}}``
    and the flat shape ``{"name", "arguments"}``. The nested name wins; when it
    is absent, both name AND arguments are taken from the top level. ``id`` and
    ``type`` are stringified when present, left as None otherwise.
    """
    if not isinstance(raw, dict):
        raise SFTDataError(f"tool_call {index} is not an object")

    name: str | None = None
    arguments: Any = None

    function_block = raw.get("function")
    if isinstance(function_block, dict):
        candidate = function_block.get("name")
        name = candidate if isinstance(candidate, str) else None
        arguments = function_block.get("arguments")
    if name is None:
        # Flat fallback: only adopted when the top-level name is a string,
        # and then the top-level arguments replace any nested ones.
        top_level_name = raw.get("name")
        if isinstance(top_level_name, str):
            name = top_level_name
            arguments = raw.get("arguments")

    if not isinstance(name, str) or not name.strip():
        raise SFTDataError(f"tool_call {index} missing function name")

    raw_id = raw.get("id")
    raw_type = raw.get("type")

    return SFTToolCall(
        name=name,
        arguments=_parse_tool_arguments(arguments),
        call_id=None if raw_id is None else str(raw_id),
        type=None if raw_type is None else str(raw_type),
        raw=dict(raw),
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _coerce_message(raw: Any, *, index: int) -> SFTMessage:
    """Validate one raw chat message and convert it to an SFTMessage.

    Raises SFTDataError when the role is missing/blank, the content has an
    unsupported type, ``tool_calls`` is not a list/tuple, or ``name`` is a
    non-string. Unknown keys are preserved in ``extra``.
    """
    if not isinstance(raw, dict):
        raise SFTDataError(f"message {index} is not an object")

    role = raw.get("role")
    if not isinstance(role, str) or not role.strip():
        raise SFTDataError(f"message {index} has invalid role")

    content = raw.get("content")
    if content is not None and not isinstance(content, str | list | dict):
        raise SFTDataError(f"message {index} has unsupported content type {type(content).__name__}")

    calls: list[SFTToolCall] = []
    raw_calls = raw.get("tool_calls")
    if raw_calls is not None:
        if not isinstance(raw_calls, list | tuple):
            raise SFTDataError(f"message {index} tool_calls must be a list")
        calls = [_coerce_tool_call(item, index=i) for i, item in enumerate(raw_calls)]

    # tool_call_id is coerced to str rather than rejected — IDs often arrive
    # as ints from loosely-typed exporters.
    call_ref = raw.get("tool_call_id")
    if call_ref is not None and not isinstance(call_ref, str):
        call_ref = str(call_ref)

    name = raw.get("name")
    if name is not None and not isinstance(name, str):
        raise SFTDataError(f"message {index} name must be a string if present")

    known_keys = {"role", "content", "tool_calls", "tool_call_id", "name"}
    leftover = {key: value for key, value in raw.items() if key not in known_keys}

    return SFTMessage(
        role=role,
        content=content,
        tool_calls=calls,
        tool_call_id=call_ref,
        name=name,
        extra=leftover,
    )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def coerce_example(raw: Any, *, min_messages: int = 1) -> SFTExample:
    """Validate a decoded JSON record and convert it into an SFTExample.

    Args:
        raw: the decoded JSON object for one training example.
        min_messages: minimum number of chat turns required in ``messages``.

    Raises:
        SFTDataError: on any structural problem (non-object record, missing or
            too-short ``messages``, malformed ``tools``/``metadata``).
    """
    if not isinstance(raw, dict):
        raise SFTDataError("record is not an object")

    messages_raw = raw.get("messages")
    # str/bytes are Sequences too; reject them here so a record like
    # {"messages": "hi"} fails with a clear structural error instead of a
    # confusing per-character "message 0 is not an object" downstream.
    if not isinstance(messages_raw, Sequence) or isinstance(messages_raw, (str, bytes)):
        raise SFTDataError("missing messages[] list")
    if len(messages_raw) < min_messages:
        raise SFTDataError(f"missing messages[] with at least {min_messages} turns")

    messages = [_coerce_message(msg, index=i) for i, msg in enumerate(messages_raw)]

    tools: list[SFTToolDefinition] = []
    if "tools" in raw and raw["tools"] is not None:
        tools_raw = raw["tools"]
        # Same str/bytes guard as for messages above.
        if not isinstance(tools_raw, Sequence) or isinstance(tools_raw, (str, bytes)):
            raise SFTDataError("tools must be provided as a list when present")
        for tool_index, tool in enumerate(tools_raw):
            tools.append(_coerce_tool_definition(tool, index=tool_index))

    # tool_choice is passed through untouched; its shape is provider-specific.
    tool_choice = raw.get("tool_choice")

    metadata_field = raw.get("metadata")
    metadata: dict[str, Any] = {}
    if metadata_field is not None:
        if not isinstance(metadata_field, dict):
            raise SFTDataError("metadata must be an object if present")
        metadata = dict(metadata_field)

    # Preserve unrecognized top-level keys so nothing is lost on round-trip.
    extra = {
        key: value
        for key, value in raw.items()
        if key not in {"messages", "tools", "tool_choice", "metadata"}
    }

    return SFTExample(
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        metadata=metadata,
        extra=extra,
    )
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def parse_jsonl_line(line: str, *, min_messages: int = 1) -> SFTExample:
    """Decode one JSONL line and validate it as an SFT example.

    Raises json.JSONDecodeError for malformed JSON and SFTDataError for a
    structurally invalid record.
    """
    return coerce_example(json.loads(line), min_messages=min_messages)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def iter_sft_examples(
    source: Iterable[str], *, min_messages: int = 1, skip_empty: bool = True
) -> Iterator[SFTExample]:
    """Lazily parse SFT examples from an iterable of JSONL lines.

    Whitespace-only lines are silently skipped when ``skip_empty`` is true;
    with ``skip_empty=False`` they are handed to the JSON parser as-is.
    """
    for candidate in source:
        if not skip_empty or candidate.strip():
            yield parse_jsonl_line(candidate, min_messages=min_messages)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def collect_sft_jsonl_errors(
    path: Path,
    *,
    min_messages: int = 1,
    max_lines: int | None = None,
    max_errors: int | None = None,
) -> list[str]:
    """Scan a JSONL file and return human-readable validation error strings.

    Args:
        path: JSONL file to read (UTF-8).
        min_messages: minimum message count required per example.
        max_lines: stop after checking this many non-blank lines (None = all).
        max_errors: stop collecting once this many errors are found (None = all).

    Returns:
        One "Line N: ..." string per problem; empty list means the file is valid.
    """
    errors: list[str] = []
    lines_checked = 0

    with path.open("r", encoding="utf-8") as fh:
        # lineno counts physical lines (including blanks) so messages point at
        # the real file line; lines_checked counts only non-blank lines.
        for lineno, raw_line in enumerate(fh, start=1):
            if max_lines is not None and lines_checked >= max_lines:
                break
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines are legal separators, not examples.
                continue
            lines_checked += 1
            try:
                parse_jsonl_line(stripped, min_messages=min_messages)
            except json.JSONDecodeError as exc:
                errors.append(f"Line {lineno}: invalid JSON ({exc.msg})")
            except SFTDataError as exc:
                errors.append(f"Line {lineno}: {exc}")
            if max_errors is not None and len(errors) >= max_errors:
                break
    # An empty (or all-blank) file is itself an error — but respect max_errors;
    # the second clause only matters for a non-positive max_errors cap.
    if lines_checked == 0 and (max_errors is None or len(errors) < max_errors):
        errors.append("File contains no SFT examples")
    return errors
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def validate_jsonl_or_raise(
    path: Path,
    *,
    min_messages: int = 1,
    max_lines: int | None = None,
    max_errors: int | None = None,
    error_factory: type[Exception] = ValueError,
) -> None:
    """Validate an SFT JSONL file; raise with a bulleted report on failure.

    Raises FileNotFoundError when *path* does not exist, and an
    ``error_factory`` exception listing every collected issue otherwise.
    Returns None when the file is valid.
    """
    if not path.exists():
        raise FileNotFoundError(str(path))

    issues = collect_sft_jsonl_errors(
        path,
        min_messages=min_messages,
        max_lines=max_lines,
        max_errors=max_errors,
    )
    if not issues:
        return

    hit_cap = max_errors is not None and len(issues) >= max_errors
    suffix = f" (showing first {max_errors} issues)" if hit_cap else ""
    bullet_list = "\n - ".join(issues)
    raise error_factory(f"{path}: Dataset validation failed{suffix}:\n - {bullet_list}")
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def load_jsonl(path: Path, *, min_messages: int = 1) -> list[SFTExample]:
    """Eagerly load every SFT example from *path*.

    Raises FileNotFoundError when the file is missing; propagates parse and
    validation errors from iter_sft_examples.
    """
    if not path.exists():
        raise FileNotFoundError(str(path))
    with path.open("r", encoding="utf-8") as handle:
        return list(iter_sft_examples(handle, min_messages=min_messages))
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
# Public API, alphabetized (classes first by ASCII order, then functions).
# Fix: "coerce_example" sorts before "collect_sft_jsonl_errors".
__all__ = [
    "SFTDataError",
    "SFTExample",
    "SFTMessage",
    "SFTToolCall",
    "SFTToolDefinition",
    "coerce_example",
    "collect_sft_jsonl_errors",
    "iter_sft_examples",
    "load_jsonl",
    "parse_jsonl_line",
    "validate_jsonl_or_raise",
]
|
synth_ai/learning/sse.py
CHANGED
|
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
import time
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from contextlib import suppress
|
|
6
7
|
|
|
7
8
|
import aiohttp
|
|
8
9
|
|
|
@@ -18,7 +19,7 @@ async def stream_events(
|
|
|
18
19
|
job_id: str,
|
|
19
20
|
*,
|
|
20
21
|
seconds: int = 60,
|
|
21
|
-
on_event:
|
|
22
|
+
on_event: Callable[[dict], None] | None = None,
|
|
22
23
|
) -> None:
|
|
23
24
|
if seconds <= 0:
|
|
24
25
|
return
|
|
@@ -29,28 +30,28 @@ async def stream_events(
|
|
|
29
30
|
]
|
|
30
31
|
for url in candidates:
|
|
31
32
|
try:
|
|
32
|
-
async with
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
async with (
|
|
34
|
+
aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session,
|
|
35
|
+
session.get(url, headers=headers) as resp,
|
|
36
|
+
):
|
|
37
|
+
if resp.status != 200:
|
|
38
|
+
continue
|
|
39
|
+
start_t = time.time()
|
|
40
|
+
async for raw in resp.content:
|
|
41
|
+
line = raw.decode(errors="ignore").strip()
|
|
42
|
+
if not line or line.startswith(":"):
|
|
35
43
|
continue
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
obj
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
if on_event:
|
|
49
|
-
try:
|
|
50
|
-
on_event(obj)
|
|
51
|
-
except Exception:
|
|
52
|
-
pass
|
|
53
|
-
if (time.time() - start_t) >= seconds:
|
|
54
|
-
return
|
|
44
|
+
if not line.startswith("data:"):
|
|
45
|
+
continue
|
|
46
|
+
data = line[5:].strip()
|
|
47
|
+
try:
|
|
48
|
+
obj = json.loads(data)
|
|
49
|
+
except Exception:
|
|
50
|
+
continue
|
|
51
|
+
if on_event:
|
|
52
|
+
with suppress(Exception):
|
|
53
|
+
on_event(obj)
|
|
54
|
+
if (time.time() - start_t) >= seconds:
|
|
55
|
+
return
|
|
55
56
|
except Exception:
|
|
56
57
|
continue
|