synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev9__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (327)
  1. examples/__init__.py +16 -0
  2. examples/crafter_debug_render.py +8 -11
  3. examples/qwen_coder/README.md +102 -0
  4. examples/qwen_coder/_shared.py +113 -0
  5. examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
  6. examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
  7. examples/qwen_coder/configs/coder_lora_small.toml +58 -0
  8. examples/qwen_coder/generate_dataset.py +98 -0
  9. examples/qwen_coder/infer_ft_smoke.py +64 -0
  10. examples/qwen_coder/infer_prod_proxy.py +73 -0
  11. examples/qwen_coder/infer_via_synth.py +87 -0
  12. examples/qwen_coder/scripts/infer_coder.sh +18 -0
  13. examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
  14. examples/qwen_coder/sft_full_17b.py +103 -0
  15. examples/qwen_coder/sft_lora_30b.py +110 -0
  16. examples/qwen_coder/subset_jsonl.py +38 -0
  17. examples/qwen_coder/validate_jsonl.py +59 -0
  18. examples/rl/run_eval.py +36 -37
  19. examples/rl/run_rl_and_save.py +5 -5
  20. examples/rl/task_app/math_single_step.py +65 -43
  21. examples/rl/task_app/math_task_app.py +3 -3
  22. examples/sft/README.md +139 -0
  23. examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
  24. examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
  25. examples/sft/evaluate.py +117 -0
  26. examples/sft/export_dataset.py +117 -0
  27. examples/sft/generate_traces.py +162 -0
  28. examples/swe/__init__.py +12 -0
  29. examples/swe/task_app/README.md +105 -0
  30. examples/swe/task_app/__init__.py +2 -0
  31. examples/swe/task_app/grpo_swe_mini.py +571 -0
  32. examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
  33. examples/swe/task_app/hosted/README.md +173 -0
  34. examples/swe/task_app/hosted/__init__.py +5 -0
  35. examples/swe/task_app/hosted/branching.py +143 -0
  36. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  37. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  38. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  39. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  40. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  41. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  42. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  43. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  44. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  45. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  46. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
  47. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  48. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  49. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  50. examples/swe/task_app/hosted/hosted_app.py +204 -0
  51. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  52. examples/swe/task_app/hosted/inference/openai_client.py +618 -0
  53. examples/swe/task_app/hosted/main.py +100 -0
  54. examples/swe/task_app/hosted/policy_routes.py +1079 -0
  55. examples/swe/task_app/hosted/registry.py +195 -0
  56. examples/swe/task_app/hosted/rollout.py +1869 -0
  57. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  58. examples/swe/task_app/hosted/storage/volume.py +211 -0
  59. examples/swe/task_app/hosted/test_agents.py +161 -0
  60. examples/swe/task_app/hosted/test_service.py +137 -0
  61. examples/swe/task_app/hosted/utils.py +62 -0
  62. examples/vlm/README.md +68 -0
  63. examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
  64. examples/vlm/crafter_image_only_agent.py +207 -0
  65. examples/vlm/crafter_openai_vlm_agent.py +277 -0
  66. examples/vlm/filter_image_rows.py +63 -0
  67. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  68. examples/warming_up_to_rl/analyze_trace_db.py +5 -5
  69. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
  70. examples/warming_up_to_rl/export_trace_sft.py +78 -21
  71. examples/warming_up_to_rl/groq_test.py +4 -4
  72. examples/warming_up_to_rl/manage_secrets.py +13 -18
  73. examples/warming_up_to_rl/run_eval.py +42 -44
  74. examples/warming_up_to_rl/run_fft_and_save.py +11 -16
  75. examples/warming_up_to_rl/run_local_rollout.py +1 -3
  76. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
  77. examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
  78. examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
  79. examples/warming_up_to_rl/run_rl_and_save.py +5 -6
  80. examples/warming_up_to_rl/run_rollout_remote.py +8 -10
  81. examples/warming_up_to_rl/task_app/README.md +6 -2
  82. examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
  83. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
  98. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
  99. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
  100. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
  101. synth/__init__.py +14 -0
  102. synth_ai/__init__.py +26 -4
  103. synth_ai/api/models/supported.py +376 -0
  104. synth_ai/api/train/builders.py +128 -21
  105. synth_ai/api/train/cli.py +80 -64
  106. synth_ai/api/train/config_finder.py +7 -2
  107. synth_ai/api/train/env_resolver.py +1 -1
  108. synth_ai/api/train/pollers.py +2 -1
  109. synth_ai/api/train/supported_algos.py +139 -0
  110. synth_ai/api/train/task_app.py +1 -2
  111. synth_ai/api/train/utils.py +13 -44
  112. synth_ai/cli/__init__.py +8 -0
  113. synth_ai/cli/_modal_wrapper.py +28 -0
  114. synth_ai/cli/_typer_patch.py +49 -0
  115. synth_ai/cli/balance.py +1 -2
  116. synth_ai/cli/calc.py +1 -1
  117. synth_ai/cli/demo.py +2 -1
  118. synth_ai/cli/recent.py +2 -2
  119. synth_ai/cli/rl_demo.py +2 -1
  120. synth_ai/cli/root.py +11 -13
  121. synth_ai/cli/status.py +2 -2
  122. synth_ai/cli/task_apps.py +529 -179
  123. synth_ai/cli/traces.py +6 -4
  124. synth_ai/cli/watch.py +12 -18
  125. synth_ai/demo_registry.py +1 -1
  126. synth_ai/demos/core/cli.py +36 -43
  127. synth_ai/demos/demo_task_apps/__init__.py +3 -3
  128. synth_ai/demos/demo_task_apps/core.py +17 -25
  129. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
  130. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  131. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
  132. synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
  133. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  134. synth_ai/environments/examples/crafter_classic/environment.py +76 -1
  135. synth_ai/environments/reproducibility/tree.py +2 -5
  136. synth_ai/environments/service/app.py +11 -12
  137. synth_ai/environments/service/core_routes.py +4 -7
  138. synth_ai/environments/stateful/engine.py +1 -1
  139. synth_ai/environments/tasks/core.py +1 -0
  140. synth_ai/environments/tasks/filters.py +5 -6
  141. synth_ai/environments/tasks/utils.py +4 -5
  142. synth_ai/handshake.py +9 -9
  143. synth_ai/http.py +1 -1
  144. synth_ai/http_client.py +18 -10
  145. synth_ai/inference/client.py +15 -5
  146. synth_ai/jobs/client.py +78 -83
  147. synth_ai/learning/__init__.py +41 -6
  148. synth_ai/learning/algorithms.py +14 -0
  149. synth_ai/learning/client.py +91 -24
  150. synth_ai/learning/config.py +2 -38
  151. synth_ai/learning/ft_client.py +4 -59
  152. synth_ai/learning/health.py +5 -6
  153. synth_ai/learning/jobs.py +31 -47
  154. synth_ai/{rl → learning/rl}/__init__.py +14 -4
  155. synth_ai/learning/rl/client.py +267 -0
  156. synth_ai/learning/rl/config.py +31 -0
  157. synth_ai/{rl → learning/rl}/contracts.py +5 -8
  158. synth_ai/{rl → learning/rl}/env_keys.py +39 -15
  159. synth_ai/learning/rl/secrets.py +13 -0
  160. synth_ai/learning/rl_client.py +2 -281
  161. synth_ai/learning/sft/__init__.py +29 -0
  162. synth_ai/learning/sft/client.py +68 -0
  163. synth_ai/learning/sft/config.py +270 -0
  164. synth_ai/learning/sft/data.py +295 -0
  165. synth_ai/learning/sse.py +25 -24
  166. synth_ai/learning/validators.py +25 -28
  167. synth_ai/lm/__init__.py +21 -47
  168. synth_ai/main.py +6 -0
  169. synth_ai/task/__init__.py +25 -27
  170. synth_ai/task/apps/__init__.py +7 -8
  171. synth_ai/task/auth.py +8 -8
  172. synth_ai/task/client.py +14 -14
  173. synth_ai/task/contracts.py +36 -35
  174. synth_ai/task/datasets.py +6 -5
  175. synth_ai/task/errors.py +10 -10
  176. synth_ai/task/health.py +17 -9
  177. synth_ai/task/json.py +58 -23
  178. synth_ai/task/proxy.py +13 -9
  179. synth_ai/task/rubrics.py +16 -15
  180. synth_ai/task/server.py +12 -12
  181. synth_ai/task/tracing_utils.py +4 -4
  182. synth_ai/task/vendors.py +5 -6
  183. synth_ai/tracing_v3/__init__.py +2 -0
  184. synth_ai/tracing_v3/abstractions.py +21 -4
  185. synth_ai/tracing_v3/decorators.py +18 -16
  186. synth_ai/tracing_v3/hooks.py +5 -5
  187. synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
  188. synth_ai/tracing_v3/session_tracer.py +40 -14
  189. synth_ai/tracing_v3/storage/base.py +85 -0
  190. synth_ai/tracing_v3/storage/config.py +21 -8
  191. synth_ai/tracing_v3/storage/factory.py +10 -7
  192. synth_ai/tracing_v3/storage/utils.py +4 -2
  193. synth_ai/tracing_v3/turso/daemon.py +7 -2
  194. synth_ai/tracing_v3/turso/models.py +2 -2
  195. synth_ai/tracing_v3/turso/native_manager.py +1173 -0
  196. synth_ai/tracing_v3/utils.py +4 -4
  197. synth_ai/v0/api/__init__.py +8 -0
  198. synth_ai/v0/api/models/__init__.py +8 -0
  199. synth_ai/v0/api/models/supported.py +8 -0
  200. synth_ai/v0/config/__init__.py +15 -0
  201. synth_ai/v0/config/base_url.py +12 -0
  202. synth_ai/v0/lm/__init__.py +51 -0
  203. synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
  204. synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
  205. synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
  206. synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
  207. synth_ai/{lm → v0/lm}/config.py +6 -1
  208. synth_ai/{lm → v0/lm}/core/all.py +9 -9
  209. synth_ai/{lm → v0/lm}/core/main.py +6 -6
  210. synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
  211. synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
  212. synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
  213. synth_ai/{lm → v0/lm}/overrides.py +2 -2
  214. synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
  215. synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
  216. synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
  217. synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
  218. synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
  219. synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
  220. synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
  221. synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
  222. synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
  223. synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
  224. synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
  225. synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
  226. synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
  227. synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
  228. synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
  229. synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
  230. synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
  231. synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
  232. synth_ai/v0/tracing_v3/__init__.py +10 -0
  233. synth_ai/v0/tracing_v3/abstractions.py +3 -0
  234. synth_ai/v0/tracing_v3/decorators.py +3 -0
  235. synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
  236. synth_ai/v0/tracing_v3/session_tracer.py +3 -0
  237. synth_ai-0.2.9.dev9.dist-info/METADATA +191 -0
  238. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/RECORD +268 -238
  239. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/top_level.txt +1 -0
  240. examples/common_old/backend.py +0 -20
  241. examples/evals_old/README.md +0 -98
  242. examples/evals_old/__init__.py +0 -6
  243. examples/evals_old/compare_models.py +0 -1038
  244. examples/evals_old/example_log.md +0 -145
  245. examples/evals_old/run_demo.sh +0 -126
  246. examples/evals_old/trace_analysis.py +0 -270
  247. examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
  248. examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
  249. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
  250. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
  251. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
  252. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
  253. examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
  254. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
  255. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
  256. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
  257. examples/finetuning_old/synth_qwen_v1/README.md +0 -68
  258. examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
  259. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
  260. examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
  261. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
  262. examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
  263. examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
  264. examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
  265. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
  266. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
  267. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
  268. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
  269. examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
  270. examples/finetuning_old/synth_qwen_v1/util.py +0 -152
  271. examples/rl_old/task_app.py +0 -1131
  272. examples/warming_up_to_rl/old/event_rewards.md +0 -234
  273. examples/warming_up_to_rl/old/notes.md +0 -73
  274. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
  275. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
  276. synth_ai/experimental/synth_oss.py +0 -445
  277. synth_ai/learning/filtering.py +0 -0
  278. synth_ai/learning/offline/dpo.py +0 -0
  279. synth_ai/learning/offline/providers.py +0 -7
  280. synth_ai/learning/offline/sft.py +0 -0
  281. synth_ai/learning/offline/shared.py +0 -0
  282. synth_ai/learning/online/grpo.py +0 -0
  283. synth_ai/learning/online/irft.py +0 -0
  284. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  285. synth_ai/learning/prompts/gepa.py +0 -0
  286. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
  287. synth_ai/learning/prompts/mipro.py +0 -289
  288. synth_ai/learning/prompts/random_search.py +0 -249
  289. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  290. synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
  291. synth_ai/rl/secrets.py +0 -19
  292. synth_ai/scripts/verify_rewards.py +0 -100
  293. synth_ai/tracing/__init__.py +0 -30
  294. synth_ai/tracing_v1/__init__.py +0 -33
  295. synth_ai/tracing_v3/turso/__init__.py +0 -25
  296. synth_ai/tracing_v3/turso/manager.py +0 -838
  297. synth_ai/zyk/__init__.py +0 -30
  298. synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
  299. /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
  300. /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
  301. /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
  302. /synth_ai/{lm → v0/lm}/constants.py +0 -0
  303. /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
  304. /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
  305. /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
  306. /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
  307. /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
  308. /synth_ai/{lm → v0/lm}/injection.py +0 -0
  309. /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
  310. /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
  311. /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
  312. /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
  313. /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
  314. /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
  315. /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
  316. /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
  317. /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
  318. /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
  319. /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
  320. /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
  321. /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
  322. /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
  323. /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
  324. /synth_ai/{lm → v0/lm}/warmup.py +0 -0
  325. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/WHEEL +0 -0
  326. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/entry_points.txt +0 -0
  327. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev9.dist-info}/licenses/LICENSE +0 -0
@@ -5,8 +5,19 @@ from pathlib import Path
5
5
  from typing import Any
6
6
 
7
7
  import click
8
+ from synth_ai.api.models.supported import (
9
+ UnsupportedModelError,
10
+ ensure_allowed_model,
11
+ normalize_model_identifier,
12
+ )
13
+ from synth_ai.learning.sft.config import prepare_sft_job_payload
8
14
 
9
- from .utils import ensure_api_base, load_toml, TrainError
15
+ from .supported_algos import (
16
+ AlgorithmValidationError,
17
+ ensure_model_supported_for_algorithm,
18
+ validate_algorithm_config,
19
+ )
20
+ from .utils import TrainError, ensure_api_base, load_toml
10
21
 
11
22
 
12
23
  @dataclass(slots=True)
@@ -29,29 +40,78 @@ def build_rl_payload(
29
40
  task_url: str,
30
41
  overrides: dict[str, Any],
31
42
  idempotency: str | None,
43
+ allow_experimental: bool | None = None,
32
44
  ) -> RLBuildResult:
33
45
  data = load_toml(config_path)
46
+ try:
47
+ spec = validate_algorithm_config(data.get("algorithm"), expected_family="rl")
48
+ except AlgorithmValidationError as exc:
49
+ raise click.ClickException(str(exc)) from exc
34
50
  services = data.get("services") if isinstance(data.get("services"), dict) else {}
35
51
  model_cfg = data.get("model") if isinstance(data.get("model"), dict) else {}
36
52
 
37
53
  final_task_url = (
38
- overrides.get("task_url") or task_url or services.get("task_url") or ""
54
+ overrides.get("task_url")
55
+ or task_url
56
+ or (services.get("task_url") if isinstance(services, dict) else None)
57
+ or ""
39
58
  ).strip()
40
59
  if not final_task_url:
41
60
  raise click.ClickException(
42
61
  "Task app URL required (provide --task-url or set services.task_url in TOML)"
43
62
  )
44
63
 
45
- model_source = (model_cfg.get("source") or "").strip()
46
- model_base = (model_cfg.get("base") or "").strip()
64
+ raw_source = model_cfg.get("source") if isinstance(model_cfg, dict) else ""
65
+ model_source = str(raw_source or "").strip()
66
+ raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
67
+ model_base = str(raw_base or "").strip()
47
68
  override_model = (overrides.get("model") or "").strip()
48
69
  if override_model:
49
70
  model_source = override_model
50
71
  model_base = ""
51
72
  if bool(model_source) == bool(model_base):
73
+ details = (
74
+ f"Config: {config_path}\n"
75
+ f"[model].source={model_source!r} | [model].base={model_base!r}"
76
+ )
77
+ hint = (
78
+ "Set exactly one: [model].base for a base model (e.g. 'Qwen/Qwen3-1.7B') "
79
+ "or [model].source for a fine-tuned model id. Also remove any conflicting "
80
+ "'[policy].model' entries."
81
+ )
52
82
  raise click.ClickException(
53
- "Model section must specify exactly one of [model].source or [model].base"
83
+ "Invalid model config: exactly one of [model].source or [model].base is required.\n"
84
+ + details
85
+ + "\nHint: "
86
+ + hint
87
+ )
88
+
89
+ try:
90
+ if model_source:
91
+ model_source = normalize_model_identifier(model_source)
92
+ if model_base:
93
+ model_base = normalize_model_identifier(model_base, allow_finetuned_prefixes=False)
94
+ except UnsupportedModelError as exc:
95
+ raise click.ClickException(str(exc)) from exc
96
+
97
+ base_model_for_training: str | None = None
98
+ if model_source:
99
+ base_model_for_training = ensure_allowed_model(
100
+ model_source,
101
+ allow_finetuned_prefixes=True,
102
+ allow_experimental=allow_experimental,
103
+ )
104
+ elif model_base:
105
+ base_model_for_training = ensure_allowed_model(
106
+ model_base,
107
+ allow_finetuned_prefixes=False,
108
+ allow_experimental=allow_experimental,
54
109
  )
110
+ if base_model_for_training:
111
+ try:
112
+ ensure_model_supported_for_algorithm(base_model_for_training, spec)
113
+ except AlgorithmValidationError as exc:
114
+ raise click.ClickException(str(exc)) from exc
55
115
 
56
116
  # Force TOML services.task_url to the effective endpoint to avoid split URLs
57
117
  try:
@@ -87,15 +147,24 @@ def build_sft_payload(
87
147
  *,
88
148
  config_path: Path,
89
149
  dataset_override: Path | None,
150
+ allow_experimental: bool | None,
90
151
  ) -> SFTBuildResult:
91
152
  data = load_toml(config_path)
153
+ try:
154
+ spec = validate_algorithm_config(data.get("algorithm"), expected_family="sft")
155
+ except AlgorithmValidationError as exc:
156
+ raise TrainError(str(exc)) from exc
92
157
  job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
93
158
  data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
94
159
  hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
95
160
  train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
96
161
  compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
97
162
 
98
- raw_dataset = dataset_override or job_cfg.get("data") or job_cfg.get("data_path")
163
+ raw_dataset = (
164
+ dataset_override
165
+ or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
166
+ or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
167
+ )
99
168
  if not raw_dataset:
100
169
  raise TrainError("Dataset not specified; pass --dataset or set [job].data")
101
170
  dataset_path = Path(raw_dataset)
@@ -108,7 +177,9 @@ def build_sft_payload(
108
177
 
109
178
  validation_path = (
110
179
  data_cfg.get("validation_path")
111
- if isinstance(data_cfg.get("validation_path"), str)
180
+ if isinstance(data_cfg, dict)
181
+ else None
182
+ if isinstance(data_cfg, dict) and isinstance(data_cfg.get("validation_path"), str)
112
183
  else None
113
184
  )
114
185
  validation_file = None
@@ -122,7 +193,7 @@ def build_sft_payload(
122
193
  validation_file = vpath
123
194
 
124
195
  hp_block: dict[str, Any] = {
125
- "n_epochs": int(hp_cfg.get("n_epochs", 1)),
196
+ "n_epochs": int(hp_cfg.get("n_epochs", 1) if isinstance(hp_cfg, dict) else 1),
126
197
  }
127
198
  for key in (
128
199
  "batch_size",
@@ -134,27 +205,35 @@ def build_sft_payload(
134
205
  "warmup_ratio",
135
206
  "train_kind",
136
207
  ):
137
- if key in hp_cfg:
208
+ if isinstance(hp_cfg, dict) and key in hp_cfg:
138
209
  hp_block[key] = hp_cfg[key]
139
- if isinstance(hp_cfg.get("parallelism"), dict):
210
+ if isinstance(hp_cfg, dict) and isinstance(hp_cfg.get("parallelism"), dict):
140
211
  hp_block["parallelism"] = hp_cfg["parallelism"]
141
212
 
142
213
  compute_block = {
143
- k: compute_cfg[k] for k in ("gpu_type", "gpu_count", "nodes") if k in compute_cfg
214
+ k: compute_cfg[k]
215
+ for k in ("gpu_type", "gpu_count", "nodes")
216
+ if isinstance(compute_cfg, dict) and k in compute_cfg
144
217
  }
145
218
 
146
219
  effective = {
147
220
  "compute": compute_block,
148
221
  "data": {
149
222
  "topology": data_cfg.get("topology", {})
150
- if isinstance(data_cfg.get("topology"), dict)
223
+ if isinstance(data_cfg, dict) and isinstance(data_cfg.get("topology"), dict)
151
224
  else {}
152
225
  },
153
- "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
226
+ "training": {
227
+ k: v
228
+ for k, v in (train_cfg.items() if isinstance(train_cfg, dict) else [])
229
+ if k in ("mode", "use_qlora")
230
+ },
154
231
  }
155
232
 
156
233
  validation_cfg = (
157
- train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
234
+ train_cfg.get("validation")
235
+ if isinstance(train_cfg, dict) and isinstance(train_cfg.get("validation"), dict)
236
+ else None
158
237
  )
159
238
  if isinstance(validation_cfg, dict):
160
239
  hp_block.update(
@@ -170,13 +249,41 @@ def build_sft_payload(
170
249
  "enabled": bool(validation_cfg.get("enabled", True))
171
250
  }
172
251
 
173
- payload = {
174
- "model": job_cfg.get("model") or data.get("model"),
175
- "training_file_id": None, # populated after upload
176
- "training_type": "sft_offline",
177
- "hyperparameters": hp_block,
178
- "metadata": {"effective_config": effective},
179
- }
252
+ raw_model = str(
253
+ job_cfg.get("model") if isinstance(job_cfg, dict) else None or data.get("model") or ""
254
+ ).strip()
255
+ if not raw_model:
256
+ raise TrainError("Model not specified; set [job].model or [model].base in the config")
257
+
258
+ try:
259
+ base_model = ensure_allowed_model(
260
+ raw_model,
261
+ allow_finetuned_prefixes=False,
262
+ allow_experimental=allow_experimental,
263
+ )
264
+ except UnsupportedModelError as exc:
265
+ raise TrainError(str(exc)) from exc
266
+ try:
267
+ ensure_model_supported_for_algorithm(base_model, spec)
268
+ except AlgorithmValidationError as exc:
269
+ raise TrainError(str(exc)) from exc
270
+
271
+ try:
272
+ payload = prepare_sft_job_payload(
273
+ model=raw_model,
274
+ training_file=None,
275
+ hyperparameters=hp_block,
276
+ metadata={"effective_config": effective},
277
+ training_type="sft_offline",
278
+ training_file_field="training_file_id",
279
+ require_training_file=False,
280
+ include_training_file_when_none=True,
281
+ allow_finetuned_prefixes=False,
282
+ )
283
+ except UnsupportedModelError as exc:
284
+ raise TrainError(str(exc)) from exc
285
+ except ValueError as exc:
286
+ raise TrainError(str(exc)) from exc
180
287
 
181
288
  return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
182
289
 
synth_ai/api/train/cli.py CHANGED
@@ -2,21 +2,22 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  from pathlib import Path
5
- from typing import Any, Dict
5
+ from typing import Any
6
6
 
7
7
  import click
8
+ from synth_ai.config.base_url import get_backend_from_env
8
9
 
9
- from .builders import RLBuildResult, SFTBuildResult, build_rl_payload, build_sft_payload
10
+ from .builders import build_rl_payload, build_sft_payload
10
11
  from .config_finder import discover_configs, prompt_for_config
11
12
  from .env_resolver import KeySpec, resolve_env
12
13
  from .pollers import RLJobPoller, SFTJobPoller
13
14
  from .task_app import check_task_app_health
14
15
  from .utils import (
15
- TrainError,
16
16
  REPO_ROOT,
17
+ TrainError,
17
18
  ensure_api_base,
18
- http_post,
19
19
  http_get,
20
+ http_post,
20
21
  limit_jsonl_examples,
21
22
  mask_value,
22
23
  post_multipart,
@@ -24,7 +25,6 @@ from .utils import (
24
25
  sleep,
25
26
  validate_sft_jsonl,
26
27
  )
27
- from synth_ai.config.base_url import get_backend_from_env
28
28
 
29
29
 
30
30
  def _discover_dataset_candidates(config_path: Path, limit: int = 50) -> list[Path]:
@@ -130,8 +130,23 @@ def _default_backend() -> str:
130
130
  )
131
131
  @click.option("--backend", default=_default_backend, help="Backend base URL")
132
132
  @click.option("--model", default=None, help="Override model identifier")
133
+ @click.option(
134
+ "--allow-experimental",
135
+ "allow_experimental",
136
+ is_flag=True,
137
+ flag_value=True,
138
+ default=None,
139
+ help="Allow experimental models (overrides SDK_EXPERIMENTAL env)",
140
+ )
141
+ @click.option(
142
+ "--no-allow-experimental",
143
+ "allow_experimental",
144
+ is_flag=True,
145
+ flag_value=False,
146
+ help="Disallow experimental models (overrides SDK_EXPERIMENTAL env)",
147
+ )
133
148
  @click.option("--idempotency", default=None, help="Idempotency-Key header for job creation")
134
- @click.option("--dry-run", is_flag=True, help="Preview payload without submitting")
149
+ @click.option("--dry-run", is_flag=True, hidden=True, help="Deprecated: no-op")
135
150
  @click.option("--poll/--no-poll", default=True, help="Poll job status until terminal state")
136
151
  @click.option(
137
152
  "--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out"
@@ -152,6 +167,7 @@ def train_command(
152
167
  dataset_path: str | None,
153
168
  backend: str,
154
169
  model: str | None,
170
+ allow_experimental: bool | None,
155
171
  idempotency: str | None,
156
172
  dry_run: bool,
157
173
  poll: bool,
@@ -165,7 +181,9 @@ def train_command(
165
181
  list(config_paths), requested_type=train_type if train_type != "auto" else None
166
182
  )
167
183
  selection = prompt_for_config(
168
- candidates, requested_type=train_type if train_type != "auto" else None
184
+ candidates,
185
+ requested_type=train_type if train_type != "auto" else None,
186
+ allow_autoselect=bool(config_paths),
169
187
  )
170
188
 
171
189
  effective_type = train_type if train_type != "auto" else selection.train_type
@@ -243,6 +261,7 @@ def train_command(
243
261
  task_url_override=task_url,
244
262
  model_override=model,
245
263
  idempotency=idempotency,
264
+ allow_experimental=allow_experimental,
246
265
  dry_run=dry_run,
247
266
  poll=poll,
248
267
  poll_timeout=poll_timeout,
@@ -255,6 +274,7 @@ def train_command(
255
274
  backend_base=backend_base,
256
275
  synth_key=synth_key,
257
276
  dataset_override=dataset_override_path,
277
+ allow_experimental=allow_experimental,
258
278
  dry_run=dry_run,
259
279
  poll=poll,
260
280
  poll_timeout=poll_timeout,
@@ -303,7 +323,7 @@ def _wait_for_training_file(
303
323
  error_body = resp.json()
304
324
  except Exception:
305
325
  error_body = resp.text[:400]
306
- click.echo(f"\n[ERROR] Authentication failed when checking training file:")
326
+ click.echo("\n[ERROR] Authentication failed when checking training file:")
307
327
  click.echo(f" URL: {url}")
308
328
  click.echo(f" Status: {resp.status_code}")
309
329
  click.echo(f" Response: {error_body}")
@@ -339,12 +359,13 @@ def handle_rl(
339
359
  task_url_override: str | None,
340
360
  model_override: str | None,
341
361
  idempotency: str | None,
362
+ allow_experimental: bool | None,
342
363
  dry_run: bool,
343
364
  poll: bool,
344
365
  poll_timeout: float,
345
366
  poll_interval: float,
346
367
  ) -> None:
347
- overrides: Dict[str, Any] = {
368
+ overrides: dict[str, Any] = {
348
369
  "backend": backend_base,
349
370
  "task_url": task_url_override,
350
371
  "model": model_override,
@@ -354,6 +375,7 @@ def handle_rl(
354
375
  task_url=task_url_override or os.environ.get("TASK_APP_URL", ""),
355
376
  overrides=overrides,
356
377
  idempotency=idempotency,
378
+ allow_experimental=allow_experimental,
357
379
  )
358
380
 
359
381
  # Backend-side verification: try ALL org environment keys against /health and /task_info
@@ -371,7 +393,7 @@ def handle_rl(
371
393
  raise click.ClickException(
372
394
  f"Task app verification call failed: {type(_ve).__name__}: {_ve}"
373
395
  ) from _ve
374
- if vresp.status_code >= 400:
396
+ if vresp.status_code is not None and vresp.status_code >= 400:
375
397
  click.echo("Task app verification error:\n" + preview_json(vjs, limit=800))
376
398
  raise click.ClickException(f"Verification failed with status {vresp.status_code}")
377
399
  if not bool(vjs.get("any_ok")):
@@ -407,9 +429,6 @@ def handle_rl(
407
429
 
408
430
  click.echo(f"POST {create_url}")
409
431
  click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))
410
- if dry_run:
411
- click.echo("Dry run enabled; skipping submission")
412
- return
413
432
 
414
433
  resp = http_post(create_url, headers=headers, json_body=build.payload)
415
434
  try:
@@ -439,6 +458,7 @@ def handle_sft(
439
458
  backend_base: str,
440
459
  synth_key: str,
441
460
  dataset_override: Path | None,
461
+ allow_experimental: bool | None,
442
462
  dry_run: bool,
443
463
  poll: bool,
444
464
  poll_timeout: float,
@@ -449,7 +469,11 @@ def handle_sft(
449
469
 
450
470
  while True:
451
471
  try:
452
- build = build_sft_payload(config_path=cfg_path, dataset_override=dataset_path)
472
+ build = build_sft_payload(
473
+ config_path=cfg_path,
474
+ dataset_override=dataset_path,
475
+ allow_experimental=allow_experimental,
476
+ )
453
477
  break
454
478
  except TrainError as exc:
455
479
  click.echo(str(exc))
@@ -472,54 +496,49 @@ def handle_sft(
472
496
  validate_sft_jsonl(build.validation_file)
473
497
 
474
498
  upload_url = f"{backend_base}/learning/files"
475
- click.echo(f"\n=== Uploading Training Data ===")
499
+ click.echo("\n=== Uploading Training Data ===")
476
500
  click.echo(f"Dataset: {build.train_file}")
477
501
  click.echo(f"Destination: {upload_url}")
478
- if dry_run:
479
- click.echo("Dry run: skipping upload")
480
- train_file_id = "dry-run-train"
481
- val_file_id = None
482
- else:
483
- resp = post_multipart(
484
- upload_url, api_key=synth_key, file_field="file", file_path=build.train_file
502
+ resp = post_multipart(
503
+ upload_url, api_key=synth_key, file_field="file", file_path=build.train_file
504
+ )
505
+ js = (
506
+ resp.json()
507
+ if resp.headers.get("content-type", "").startswith("application/json")
508
+ else {}
509
+ )
510
+ if resp.status_code is not None and resp.status_code >= 400 or "id" not in js:
511
+ click.echo("\n[ERROR] Training file upload failed:")
512
+ click.echo(f" URL: {upload_url}")
513
+ click.echo(f" Status: {resp.status_code}")
514
+ click.echo(f" Response: {js or resp.text[:400]}")
515
+ click.echo(f" File: {build.train_file}")
516
+ raise click.ClickException(
517
+ f"Training file upload failed with status {resp.status_code}"
518
+ )
519
+ train_file_id = js["id"]
520
+ click.echo(f"✓ Training file uploaded (id={train_file_id})")
521
+ val_file_id = None
522
+ if build.validation_file:
523
+ click.echo(f"Uploading validation dataset: {build.validation_file}")
524
+ vresp = post_multipart(
525
+ upload_url,
526
+ api_key=synth_key,
527
+ file_field="file",
528
+ file_path=build.validation_file,
485
529
  )
486
- js = (
487
- resp.json()
488
- if resp.headers.get("content-type", "").startswith("application/json")
530
+ vjs = (
531
+ vresp.json()
532
+ if vresp.headers.get("content-type", "").startswith("application/json")
489
533
  else {}
490
534
  )
491
- if resp.status_code >= 400 or "id" not in js:
492
- click.echo(f"\n[ERROR] Training file upload failed:")
493
- click.echo(f" URL: {upload_url}")
494
- click.echo(f" Status: {resp.status_code}")
495
- click.echo(f" Response: {js or resp.text[:400]}")
496
- click.echo(f" File: {build.train_file}")
497
- raise click.ClickException(
498
- f"Training file upload failed with status {resp.status_code}"
499
- )
500
- train_file_id = js["id"]
501
- click.echo(f"✓ Training file uploaded (id={train_file_id})")
502
- val_file_id = None
503
- if build.validation_file:
504
- click.echo(f"Uploading validation dataset: {build.validation_file}")
505
- vresp = post_multipart(
506
- upload_url,
507
- api_key=synth_key,
508
- file_field="file",
509
- file_path=build.validation_file,
510
- )
511
- vjs = (
512
- vresp.json()
513
- if vresp.headers.get("content-type", "").startswith("application/json")
514
- else {}
535
+ if vresp.status_code is not None and vresp.status_code < 400 and "id" in vjs:
536
+ val_file_id = vjs["id"]
537
+ click.echo(f" Validation file uploaded (id={val_file_id})")
538
+ else:
539
+ click.echo(
540
+ f"[WARN] Validation upload failed ({vresp.status_code}): {vjs or vresp.text[:200]}"
515
541
  )
516
- if vresp.status_code < 400 and "id" in vjs:
517
- val_file_id = vjs["id"]
518
- click.echo(f"✓ Validation file uploaded (id={val_file_id})")
519
- else:
520
- click.echo(
521
- f"[WARN] Validation upload failed ({vresp.status_code}): {vjs or vresp.text[:200]}"
522
- )
523
542
  payload = dict(build.payload)
524
543
  payload["training_file_id"] = train_file_id
525
544
  if val_file_id:
@@ -527,18 +546,15 @@ def handle_sft(
527
546
  "data", {}
528
547
  )["validation_files"] = [val_file_id]
529
548
 
530
- click.echo(f"\n=== Checking File Processing Status ===")
549
+ click.echo("\n=== Checking File Processing Status ===")
531
550
  try:
532
551
  _wait_for_training_file(backend_base, synth_key, train_file_id)
533
552
  except click.ClickException as exc:
534
553
  raise click.ClickException(f"Training file {train_file_id} not ready: {exc}") from exc
535
554
 
536
- click.echo(f"\n=== Creating Training Job ===")
555
+ click.echo("\n=== Creating Training Job ===")
537
556
  click.echo("Job payload preview:")
538
557
  click.echo(preview_json(payload, limit=800))
539
- if dry_run:
540
- click.echo("Dry run: skipping job submission")
541
- return
542
558
 
543
559
  create_url = f"{backend_base}/learning/jobs"
544
560
  headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
@@ -550,7 +566,7 @@ def handle_sft(
550
566
  else {}
551
567
  )
552
568
  if resp.status_code not in (200, 201):
553
- click.echo(f"\n[ERROR] Job creation failed:")
569
+ click.echo("\n[ERROR] Job creation failed:")
554
570
  click.echo(f" URL: {create_url}")
555
571
  click.echo(f" Status: {resp.status_code}")
556
572
  click.echo(f" Response: {preview_json(js, limit=600)}")
@@ -560,14 +576,14 @@ def handle_sft(
560
576
  raise click.ClickException("Response missing job id")
561
577
  click.echo(f"✓ Job created (id={job_id})")
562
578
 
563
- click.echo(f"\n=== Starting Training Job ===")
579
+ click.echo("\n=== Starting Training Job ===")
564
580
  start_url = f"{backend_base}/learning/jobs/{job_id}/start"
565
581
  click.echo(f"POST {start_url}")
566
582
  start_resp = http_post(start_url, headers=headers, json_body={})
567
583
  if start_resp.status_code not in (200, 201):
568
584
  click.echo(f"[WARN] Job start returned status {start_resp.status_code}")
569
585
  else:
570
- click.echo(f"✓ Job started")
586
+ click.echo("✓ Job started")
571
587
 
572
588
  if not poll:
573
589
  click.echo(f"Started job {job_id} (polling disabled)")
@@ -2,9 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import json
4
4
  import os
5
+ from collections.abc import Iterable
5
6
  from dataclasses import dataclass
6
7
  from pathlib import Path
7
- from typing import Iterable
8
8
 
9
9
  import click
10
10
 
@@ -173,7 +173,7 @@ def discover_configs(explicit: list[str], *, requested_type: str | None) -> list
173
173
 
174
174
 
175
175
  def prompt_for_config(
176
- candidates: list[ConfigCandidate], *, requested_type: str | None
176
+ candidates: list[ConfigCandidate], *, requested_type: str | None, allow_autoselect: bool = False
177
177
  ) -> ConfigCandidate:
178
178
  if not candidates:
179
179
  raise click.ClickException("No training configs found. Pass --config explicitly.")
@@ -182,6 +182,11 @@ def prompt_for_config(
182
182
  last_config = _load_last_config()
183
183
  default_idx = 1
184
184
 
185
+ if allow_autoselect and len(candidates) == 1:
186
+ chosen = candidates[0]
187
+ _save_last_config(chosen.path)
188
+ return chosen
189
+
185
190
  if last_config:
186
191
  for idx, cand in enumerate(candidates):
187
192
  if cand.path.resolve() == last_config:
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import os
4
+ from collections.abc import Callable, Iterable, MutableMapping
4
5
  from dataclasses import dataclass
5
6
  from pathlib import Path
6
- from typing import Callable, Iterable, MutableMapping
7
7
 
8
8
  import click
9
9
 
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections.abc import Mapping
3
4
  from dataclasses import dataclass
4
5
  from datetime import datetime
5
- from typing import Any, Mapping
6
+ from typing import Any
6
7
 
7
8
  import click
8
9