freesolo-flash 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/PKG-INFO +3 -1
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/__init__.py +1 -1
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/catalog.py +8 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/cli/main/__init__.py +5 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/cli/main/commands.py +21 -2
- freesolo_flash-0.2.4/flash/cost/__init__.py +16 -0
- freesolo_flash-0.2.4/flash/cost/analytical.py +160 -0
- freesolo_flash-0.2.4/flash/cost/facts.py +126 -0
- freesolo_flash-0.2.4/flash/cost/spec.py +87 -0
- freesolo_flash-0.2.4/flash/cost/types.py +158 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/vram.py +19 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/worker/__init__.py +5 -7
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/base.py +10 -5
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/schema/__init__.py +2 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/schema/fields.py +8 -2
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/app.py +51 -3
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/auth.py +7 -2
- freesolo_flash-0.2.4/flash/server/billing.py +128 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/pyproject.toml +15 -2
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_allocator.py +1 -1
- freesolo_flash-0.2.4/tests/test_cli_estimate.py +224 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_client_server_integration.py +8 -0
- freesolo_flash-0.2.4/tests/test_cost_analytical.py +224 -0
- freesolo_flash-0.2.4/tests/test_cost_equation.py +47 -0
- freesolo_flash-0.2.4/tests/test_cost_estimate.py +77 -0
- freesolo_flash-0.2.4/tests/test_cost_hardware.py +131 -0
- freesolo_flash-0.2.4/tests/test_cost_models.py +36 -0
- freesolo_flash-0.2.4/tests/test_cost_rewards.py +65 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_open_model_policy.py +4 -4
- freesolo_flash-0.2.4/tests/test_server_billing.py +392 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/uv.lock +44 -40
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.dockerignore +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.env.example +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.github/workflows/ci.yml +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.github/workflows/main-source-guard.yml +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.github/workflows/publish-image.yml +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.github/workflows/publish.yml +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.github/workflows/worker-image.yml +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/.gitignore +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/Dockerfile +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/Dockerfile.worker +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/LICENSE +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/README.md +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/docker/make_rp_handler.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/_fileio.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/_logging.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/cli/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/cli/main/__main__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/cli/main/envpush.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/client/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/client/config.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/client/http.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/client/specs.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/accounting.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/chalk_kernels.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/multiturn_rollout.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/recipe.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/worker/__main__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/worker/lora.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/engine/worker/perf.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/envs/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/envs/adapter/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/envs/adapter/rubric.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/envs/base.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/envs/registry.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/mcp/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/mcp/server.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/_auth.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/_http.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/_poll.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/allocator.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/preflight.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/api.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/auth.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/gpus.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/jobs.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/preflight.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/pricing.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/train/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/train/deps.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/runpod/train/endpoints.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/_bootstrap.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/api.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/auth.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/gpus.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/jobs/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/jobs/builders.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/preflight.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/pricing.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/providers/vast/train.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/py.typed +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/runner/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/runner/deploy.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/runner/lifecycle.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/serve/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/serve/deploy.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/__main__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/db.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/server/envs.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/flash/spec.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/_helpers/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/_helpers/runner.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/_helpers/specs.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/_helpers/vast.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/conftest.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/fixtures/math_eval.jsonl +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/fixtures/math_train.jsonl +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/live/__init__.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/live/conftest.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/live/test_runpod_live.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/live/test_vast_live.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_agent_slm_cli_contract.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_algorithms.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_backend_jobspec_contract.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_cancel_remote.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_catalog_consistency.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_chalk_kernels.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_cli_commands.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_cli_errors.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_cli_managed.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_client.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_config_overrides.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_disk_gb.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_endpoint_name.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_env_install.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_env_publish.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_env_push.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_envs_coverage.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_flash_mvp.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_flash_worker.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_gpus.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_grpo_params.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_jobs.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_logging.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_login_perms.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_metrics_schema_agent_contract.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_multiturn_rollout.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_orchestrator_flash.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_preflight.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_pricing_cache.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_provider_routing.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_providers_symmetry.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_runmgmt.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_runpod_api_delete.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_serve.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_serve_modes.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_server_api.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_spec_and_validation.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_thinking_config.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_vast_api.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_vast_offers.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_vast_runner.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_verifiers.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_version.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_wandb_naming.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_worker_dryrun.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_worker_hardexit.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_worker_stack.py +0 -0
- {freesolo_flash-0.2.2 → freesolo_flash-0.2.4}/tests/test_worker_thinking.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: freesolo-flash
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: Flash — managed LoRA post-training (SFT/GRPO) for verifiers environments, driven by the `flash` CLI
|
|
5
5
|
Project-URL: Homepage, https://github.com/freesolo-co/flash
|
|
6
6
|
Project-URL: Repository, https://github.com/freesolo-co/flash
|
|
@@ -27,12 +27,14 @@ Requires-Dist: trl<1.7,>=1.6; extra == 'gpu'
|
|
|
27
27
|
Requires-Dist: verifiers>=0.1.10; extra == 'gpu'
|
|
28
28
|
Requires-Dist: vllm==0.19.1; extra == 'gpu'
|
|
29
29
|
Provides-Extra: server
|
|
30
|
+
Requires-Dist: datasets>=2.19; extra == 'server'
|
|
30
31
|
Requires-Dist: fastapi; extra == 'server'
|
|
31
32
|
Requires-Dist: httpx>=0.27; extra == 'server'
|
|
32
33
|
Requires-Dist: huggingface-hub>=0.34; extra == 'server'
|
|
33
34
|
Requires-Dist: prime>=0.6.3; extra == 'server'
|
|
34
35
|
Requires-Dist: runpod-flash; extra == 'server'
|
|
35
36
|
Requires-Dist: uvicorn; extra == 'server'
|
|
37
|
+
Requires-Dist: verifiers>=0.1.10; extra == 'server'
|
|
36
38
|
Description-Content-Type: text/markdown
|
|
37
39
|
|
|
38
40
|
# Flash
|
|
@@ -64,6 +64,9 @@ class ModelInfo:
|
|
|
64
64
|
# the raw tokenizer count). Drives the GRPO fp32-logits memory term and the per-device
|
|
65
65
|
# completion cap. Curated per model below; defaults to the open-model fallback.
|
|
66
66
|
vocab_size: int = _DEFAULT_VOCAB_SIZE
|
|
67
|
+
# Total parameters in billions — the numeric model size the cost estimator reads directly
|
|
68
|
+
# (no parsing of the ``params`` display string). Curated per catalog model below.
|
|
69
|
+
params_b: float = 0.0
|
|
67
70
|
|
|
68
71
|
def to_dict(self) -> dict[str, Any]:
|
|
69
72
|
return asdict(self)
|
|
@@ -79,6 +82,7 @@ MODELS: dict[str, ModelInfo] = {
|
|
|
79
82
|
id="openbmb/MiniCPM5-1B",
|
|
80
83
|
display_name="MiniCPM5 1B",
|
|
81
84
|
params="1.2B dense (Llama arch)",
|
|
85
|
+
params_b=1.2,
|
|
82
86
|
vocab_size=130_560,
|
|
83
87
|
algos=("sft", "grpo"),
|
|
84
88
|
min_vram_gb=12,
|
|
@@ -95,6 +99,7 @@ MODELS: dict[str, ModelInfo] = {
|
|
|
95
99
|
id="Qwen/Qwen3.5-0.8B",
|
|
96
100
|
display_name="Qwen3.5 0.8B",
|
|
97
101
|
params="0.9B (text-only fine-tune)",
|
|
102
|
+
params_b=0.9,
|
|
98
103
|
vocab_size=248_320,
|
|
99
104
|
algos=("sft", "grpo"),
|
|
100
105
|
min_vram_gb=12,
|
|
@@ -106,6 +111,7 @@ MODELS: dict[str, ModelInfo] = {
|
|
|
106
111
|
id="Qwen/Qwen3.5-2B",
|
|
107
112
|
display_name="Qwen3.5 2B",
|
|
108
113
|
params="2.3B (text-only fine-tune)",
|
|
114
|
+
params_b=2.3,
|
|
109
115
|
vocab_size=248_320,
|
|
110
116
|
algos=("sft", "grpo"),
|
|
111
117
|
min_vram_gb=16,
|
|
@@ -116,6 +122,7 @@ MODELS: dict[str, ModelInfo] = {
|
|
|
116
122
|
id="Qwen/Qwen3.5-4B",
|
|
117
123
|
display_name="Qwen3.5 4B",
|
|
118
124
|
params="4.7B (text-only fine-tune)",
|
|
125
|
+
params_b=4.7,
|
|
119
126
|
vocab_size=248_320,
|
|
120
127
|
algos=("sft", "grpo"),
|
|
121
128
|
min_vram_gb=32,
|
|
@@ -128,6 +135,7 @@ MODELS: dict[str, ModelInfo] = {
|
|
|
128
135
|
id="Qwen/Qwen3.5-9B",
|
|
129
136
|
display_name="Qwen3.5 9B",
|
|
130
137
|
params="9.7B (text-only fine-tune)",
|
|
138
|
+
params_b=9.7,
|
|
131
139
|
vocab_size=248_320,
|
|
132
140
|
algos=("sft", "grpo"),
|
|
133
141
|
min_vram_gb=16,
|
|
@@ -137,6 +137,11 @@ def main(argv: list[str] | None = None) -> int:
|
|
|
137
137
|
help="override a config value; repeatable",
|
|
138
138
|
)
|
|
139
139
|
train.add_argument("--dry-run", action="store_true")
|
|
140
|
+
train.add_argument(
|
|
141
|
+
"--cost",
|
|
142
|
+
action="store_true",
|
|
143
|
+
help="print the pre-flight USD cost for the config and exit (no submit)",
|
|
144
|
+
)
|
|
140
145
|
train.add_argument(
|
|
141
146
|
"--background",
|
|
142
147
|
action="store_true",
|
|
@@ -26,6 +26,7 @@ from flash.client import (
|
|
|
26
26
|
)
|
|
27
27
|
from flash.client.config import load_credentials
|
|
28
28
|
from flash.client.specs import spec_payload
|
|
29
|
+
from flash.cost.spec import runconfig_from_spec
|
|
29
30
|
from flash.runner import TERMINAL_STATES, new_run_id
|
|
30
31
|
from flash.schema import ConfigError, spec_from_file
|
|
31
32
|
|
|
@@ -262,12 +263,30 @@ def cmd_env_list(args) -> int:
|
|
|
262
263
|
return 0
|
|
263
264
|
|
|
264
265
|
|
|
266
|
+
def _cmd_train_cost(args) -> int:
    """Handle `flash train --cost`: print the pre-flight USD cost breakdown and exit.

    Nothing is submitted. The estimate is catalog-only and deterministic; an uncapped
    SFT run loads the env to count its train split.

    Returns the CLI exit code (always 0 on success).
    """
    from flash.cost import estimate_cost

    # No run id: a cost preview never creates or submits a run.
    parsed_spec = spec_from_file(
        args.config,
        run_id=None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )
    run_config = runconfig_from_spec(parsed_spec)
    estimate = estimate_cost(run_config)
    print(estimate.breakdown())
    return 0
|
|
280
|
+
|
|
281
|
+
|
|
265
282
|
def cmd_train(args) -> int:
|
|
283
|
+
if getattr(args, "cost", False):
|
|
284
|
+
return _cmd_train_cost(args)
|
|
266
285
|
spec = spec_from_file(
|
|
267
286
|
args.config,
|
|
268
287
|
run_id=new_run_id() if args.dry_run else None,
|
|
269
|
-
overrides=
|
|
270
|
-
extra_configs=
|
|
288
|
+
overrides=args.overrides,
|
|
289
|
+
extra_configs=args.extra_configs,
|
|
271
290
|
)
|
|
272
291
|
if args.dry_run:
|
|
273
292
|
# Fully local: validate the id-based config without credentials, a server, or a GPU.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Flash training-cost estimator: a deterministic, equation-based pre-flight estimate
|
|
2
|
+
(``estimate_cost``) of cost = wall-clock hours x market $/hr. No output multiplier."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from .analytical import estimate_cost
|
|
7
|
+
from .spec import estimate_for_spec, runconfig_from_spec
|
|
8
|
+
from .types import CostEstimate, RunConfig
|
|
9
|
+
|
|
10
|
+
# Public surface of ``flash.cost``: ``estimate_cost`` takes a ``RunConfig`` and returns a
# ``CostEstimate``; ``runconfig_from_spec`` / ``estimate_for_spec`` adapt a parsed JobSpec.
__all__ = [
    "CostEstimate",
    "RunConfig",
    "estimate_cost",
    "estimate_for_spec",
    "runconfig_from_spec",
]
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""The analytical cost model: total = wall-clock hours x GPU $/hr, where wall = cold-start
|
|
2
|
+
setup + steps x per-step time (a FLOPs/MFU estimate). GRPO splits each step into a vLLM
|
|
3
|
+
rollout + reward grading + policy/reference update."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
from flash.providers.allocator import required_vram_gb, vram_headroom
|
|
10
|
+
|
|
11
|
+
from .facts import (
|
|
12
|
+
download_weight_gb,
|
|
13
|
+
gpu_tflops,
|
|
14
|
+
gpu_vram_gb,
|
|
15
|
+
model_quant,
|
|
16
|
+
pick_gpu,
|
|
17
|
+
realized_hourly_usd,
|
|
18
|
+
reward_seconds_per_completion,
|
|
19
|
+
total_params_b,
|
|
20
|
+
)
|
|
21
|
+
from .types import CostEstimate, RunConfig
|
|
22
|
+
|
|
23
|
+
# FLOPs per token per active-parameter. seconds_per_step() multiplies these by the model's
# total_params_b(); for the dense catalog models total and active parameters coincide.
SFT_FLOPS_PER_TOKEN_PER_PARAM = 6.0  # forward (2) + backward (4)
GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM = 2.0  # autoregressive rollout forward
GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM = 8.0  # policy fwd+bwd (6) + frozen-ref fwd (2)

# Model-FLOPs utilization (fraction of peak sustained), calibrated against real RunPod/Vast
# wall clock. LoRA + small batches sit well below dense-pretraining MFU.
MFU_TRAIN = 0.35  # GRPO policy/reference update
MFU_SFT_TRAIN = 0.25  # SFT fwd/bwd (smaller effective batch, long sequences)
MFU_DECODE = 0.12  # batched vLLM rollout (decode is memory-bandwidth-bound)

# Reward grading is CONCURRENT: a step's completions score in parallel slots, so the reward
# wall is ceil(completions / slots) waves x latency, not completions x latency.
REWARD_CONCURRENCY = 16.0

# Cold-start overhead (seconds): container boot + deps + model download (+ vLLM init for GRPO).
WORKER_BOOT_S = 180.0
DEPS_INSTALL_S = 120.0
VLLM_INIT_S = 120.0
# Effective HF snapshot download rate (hf_transfer) in gigabytes PER SECOND (not gigabits):
# setup_seconds() divides download_weight_gb() by this value to get seconds.
DOWNLOAD_RATE_GBPS = 0.4

# Per-seed wall cap (seconds, i.e. 24 h) used when the config sets no max_wall_seconds.
DEFAULT_WALL_CAP_S = 24 * 3600  # spec gpu.max_wall_seconds default
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _fmt_duration(seconds: float) -> str:
|
|
48
|
+
"""Human duration for notes: seconds < 1m, minutes < 1h, else whole/1-decimal hours."""
|
|
49
|
+
if seconds < 60:
|
|
50
|
+
return f"{seconds:.0f}s"
|
|
51
|
+
if seconds < 3600:
|
|
52
|
+
return f"{seconds / 60:.0f}m"
|
|
53
|
+
hours = seconds / 3600
|
|
54
|
+
return f"{hours:.0f}h" if abs(hours - round(hours)) < 1e-9 else f"{hours:.1f}h"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def setup_seconds(config: RunConfig) -> float:
    """Return the cold-start wall time (seconds) billed before the first optimizer step.

    The total is fixed worker boot plus dependency install, plus the checkpoint download at
    ``DOWNLOAD_RATE_GBPS``. GRPO runs additionally pay ``VLLM_INIT_S`` for vLLM start-up.
    """
    download_s = download_weight_gb(config.model_id) / DOWNLOAD_RATE_GBPS
    vllm_s = VLLM_INIT_S if config.is_grpo else 0.0
    return WORKER_BOOT_S + DEPS_INSTALL_S + download_s + vllm_s
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def seconds_per_step(config: RunConfig, gpu: str) -> float:
    """Return the steady-state wall time (seconds) of one optimizer step on ``gpu``.

    Compute time is FLOPs / (peak FLOP/s x MFU). SFT is one fwd/bwd pass over
    ``batch_size x seq_len`` tokens. A GRPO step sums three phases:

    - a vLLM rollout of ``batch_size x group_size`` completions;
    - concurrent reward grading;
    - the policy + frozen-reference update over the generated tokens.
    """
    cfg = config.normalized()
    param_count = total_params_b(cfg.model_id) * 1e9
    peak_flops = gpu_tflops(gpu) * 1e12  # FLOP/s

    if cfg.is_grpo:
        n_completions = cfg.batch_size * cfg.group_size
        n_gen_tokens = n_completions * cfg.completion_len
        rollout_s = (GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM * param_count * n_gen_tokens) / (
            peak_flops * MFU_DECODE
        )
        update_s = (GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM * param_count * n_gen_tokens) / (
            peak_flops * MFU_TRAIN
        )
        # Completions grade in waves of REWARD_CONCURRENCY slots; a partial wave still
        # costs one full latency, hence the ceil.
        waves = math.ceil(n_completions / REWARD_CONCURRENCY)
        reward_s = waves * reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        return rollout_s + reward_s + update_s

    sft_flops = SFT_FLOPS_PER_TOKEN_PER_PARAM * param_count * (cfg.batch_size * cfg.seq_len)
    return sft_flops / (peak_flops * MFU_SFT_TRAIN)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def select_gpu(config: RunConfig) -> tuple[str, int]:
    """Return ``(gpu_class, required_vram_gb)`` for ``config``.

    The choice mirrors the allocator's cheapest-fitting class, with no pin and no
    validation gate. Sizing is catalog-only, so the result is offline and deterministic.

    Raises:
        ValueError: the model is not in the catalog (raised by ``total_params_b`` before
            any VRAM sizing), or no GPU class fits (raised by ``pick_gpu``).
    """
    # Catalog check first, so a non-catalog model is rejected before any (HF) sizing.
    total_params_b(config.model_id)
    required = required_vram_gb(
        config.model_id,
        config.method,
        train=config.train_knobs(),
        thinking=config.thinking,
    )
    return pick_gpu(required, provider=config.provider), required
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _notes(config: RunConfig, raw_train_s: float, wall_capped: bool, cap_s: float) -> tuple[str, ...]:
    """Build the human-readable notes attached to a ``CostEstimate``.

    Args:
        config: the run being priced (normalized here before reading model/GRPO knobs).
        raw_train_s: per-seed training seconds before any wall-cap clamp.
        wall_capped: whether training was clamped to fit ``cap_s``.
        cap_s: the per-seed wall cap, in seconds.
    """
    cfg = config.normalized()
    out: list[str] = []

    quant = model_quant(cfg.model_id)
    if quant != "bf16":
        out.append(f"{quant}: smaller VRAM footprint -> cheaper GPU class fits")

    if cfg.is_grpo:
        total = cfg.batch_size * cfg.group_size
        latency = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        env_part = f", env {cfg.environment}" if cfg.environment else ""
        out.append(
            f"GRPO step = vLLM rollout of {cfg.batch_size}x{cfg.group_size}={total} completions "
            f"@ {cfg.completion_len} tok + reward ({latency:.2f}s/completion"
            f"{env_part}) + policy+reference update"
        )

    out.append(f"GPU sized with {vram_headroom() - 1:.0%} VRAM headroom; market (spot/queue) $/hr")

    if wall_capped:
        # Multi-seed runs cap each seed separately, so say so.
        scope = "per-seed " if config.setup_repeats != 1 else ""
        out.append(
            f"training clamped to fit the {_fmt_duration(cap_s)} {scope}wall cap "
            f"(after setup; uncapped: {_fmt_duration(raw_train_s)})"
        )
    return tuple(out)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def estimate_cost(config: RunConfig, *, wall_cap_s: float = DEFAULT_WALL_CAP_S) -> CostEstimate:
    """Return the deterministic pre-flight cost estimate for ``config`` (the analytical ground truth).

    Each seed (``config.setup_repeats``) is priced as its own job, with its own cold start
    and its own wall cap. A seed's wall is setup plus its share of ``config.steps``. When
    that exceeds the cap, training is clamped to whatever the cap leaves after setup.
    Totals are the per-seed figures times the seed count, billed at the GPU's realized $/hr.

    Args:
        config: the run to price.
        wall_cap_s: per-seed wall cap (seconds) used when ``config.max_wall_seconds`` is unset.
    """
    gpu, required = select_gpu(config)
    rate = realized_hourly_usd(gpu)

    # Same max(60, max_wall_seconds) floor as the runner, so a sub-60s cap isn't underpriced.
    if config.max_wall_seconds is None:
        cap_s = wall_cap_s
    else:
        cap_s = max(60.0, float(config.max_wall_seconds))

    n_seeds = config.setup_repeats
    seed_setup_s = setup_seconds(config)
    step_s = seconds_per_step(config, gpu)
    seed_train_uncapped_s = (config.steps / n_seeds) * step_s

    # The cap bounds a seed's whole wall (setup is billed too), so training absorbs the clamp.
    capped = (seed_setup_s + seed_train_uncapped_s) > cap_s
    seed_setup_s = min(seed_setup_s, cap_s)
    if capped:
        seed_train_s = max(0.0, cap_s - seed_setup_s)
    else:
        seed_train_s = seed_train_uncapped_s

    setup_total_s = seed_setup_s * n_seeds
    train_total_s = seed_train_s * n_seeds
    wall_s = setup_total_s + train_total_s

    return CostEstimate(
        model_id=config.model_id,
        method=config.method,
        steps=config.steps,
        gpu=gpu,
        provider=config.provider,
        gpu_vram_gb=gpu_vram_gb(gpu),
        required_vram_gb=required,
        gpu_hourly_usd=rate,
        setup_seconds=setup_total_s,
        seconds_per_step=step_s,
        train_seconds=train_total_s,
        wall_clock_seconds=wall_s,
        wall_capped=capped,
        total_usd=wall_s / 3600.0 * rate,
        notes=_notes(config, seed_train_uncapped_s, capped, cap_s),
    )
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Static lookup facts for the cost model: GPU price/VRAM/compute + cheapest-fit
|
|
2
|
+
selection, model size/quant, and reward-grader latency. Pure tables + accessors."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from flash.catalog import MODELS
|
|
7
|
+
from flash.providers.base import GPU_INFO, GpuClass, providers_for
|
|
8
|
+
|
|
9
|
+
# ===== GPU facts =====
# Peak bf16 tensor TFLOPS per managed GPU class name, read by gpu_tflops(). The analytical
# model divides FLOPs by (this peak x an MFU constant), so each figure directly scales that
# class's step-time and cost estimate.
# NOTE(review): confirm every entry uses the same datasheet convention (dense vs. with
# sparsity, FP32 vs. FP16 accumulate); mixed conventions would skew cost across classes.
GPU_COMPUTE_TFLOPS: dict[str, float] = {
    "RTX A4000": 77.0,
    "RTX 2000 Ada": 89.0,
    "RTX A4500": 89.0,
    "RTX 4000 Ada": 90.0,
    "RTX A5000": 89.0,
    "RTX 3090": 71.0,
    "L4": 60.0,
    "RTX Pro 4000": 95.0,
    "RTX 4090": 165.0,
    "RTX 5090": 210.0,
    "RTX A6000": 155.0,
    "A40": 150.0,
    "RTX 6000 Ada": 182.0,
    "L40S": 181.0,
    "A100 SXM 40GB": 312.0,
    "A100 PCIe": 312.0,
    "A100 SXM": 312.0,
    "H100 NVL": 835.0,
    "H100": 990.0,
    "RTX Pro 6000": 250.0,
    "RTX Pro 6000 WK": 250.0,
}
# Peak assumed by gpu_tflops() for any class not listed above.
_DEFAULT_TFLOPS = 100.0
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def gpu_tflops(name: str) -> float:
    """Peak bf16 tensor TFLOPS for the managed GPU class ``name``.

    A class absent from ``GPU_COMPUTE_TFLOPS`` falls back to ``_DEFAULT_TFLOPS`` rather
    than raising.
    """
    if name in GPU_COMPUTE_TFLOPS:
        return GPU_COMPUTE_TFLOPS[name]
    return _DEFAULT_TFLOPS
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def gpu_hourly_usd(name: str) -> float:
    """Static fallback (on-demand list) $/hr for the GPU class ``name``.

    Raises:
        KeyError: ``name`` is not a known class in ``GPU_INFO``.
    """
    info = GPU_INFO.get(name)
    if info is not None:
        return info.hourly_usd
    raise KeyError(f"unknown GPU class {name!r}")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Realized (spot/queue) $/hr per class -- the discount below on-demand list (RTX 5090 lists
# $0.99, bills ~$0.87). ``realized_hourly_usd`` CLAMPS to the list price so it can never
# over-quote; a class with no clean observed rate falls back to list.
# Keys are GPU class names (the names pick_gpu() passes in from GPU_INFO); values are USD per
# hour. pick_gpu() ranks candidates by these realized rates, so they also drive GPU selection.
REALIZED_HOURLY_USD: dict[str, float] = {
    "RTX 3090": 0.239,
    "RTX 4090": 0.426,
    "RTX 5090": 0.871,
    "RTX A5000": 0.304,
    "RTX 6000 Ada": 0.601,
    "A100 PCIe": 1.035,
    "A100 SXM": 1.133,
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def realized_hourly_usd(name: str) -> float:
    """Market (spot/queue) $/hr for the GPU class ``name``, never above its list price.

    A class without an observed rate in ``REALIZED_HOURLY_USD`` bills at list.

    Raises:
        KeyError: ``name`` is not a known GPU class (raised by ``gpu_hourly_usd``).
    """
    ceiling = gpu_hourly_usd(name)
    observed = REALIZED_HOURLY_USD.get(name)
    if observed is None:
        return ceiling
    return min(observed, ceiling)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def gpu_vram_gb(name: str) -> int:
    """VRAM (GB) of the GPU class ``name``.

    Raises:
        KeyError: ``name`` is not a known class in ``GPU_INFO``.
    """
    info = GPU_INFO.get(name)
    if info is not None:
        return info.vram_gb
    raise KeyError(f"unknown GPU class {name!r}")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def pick_gpu(required_vram_gb: int, *, provider: str | None = None) -> str:
    """Cheapest GPU class that fits ``required_vram_gb``.

    Candidates are ranked by the REALIZED (market) $/hr they are BILLED at (ties: vram,
    then name). This keeps selection consistent with the bill and approximates the
    allocator, which provisions the cheapest live offer. There is no pin and no validation
    gate: every fitting class is eligible. A ``provider`` other than ``None``/``"auto"``
    restricts candidates to classes that provider can provision.

    Raises:
        ValueError: no eligible class has at least ``required_vram_gb`` GB. When a provider
            filter was applied, the message names it, since a fitting class may still exist
            on another provider.
    """
    unrestricted = provider in (None, "auto")

    def _selectable(g: GpuClass) -> bool:
        return unrestricted or provider in providers_for(g.name)

    candidates = [g for g in GPU_INFO.values() if g.vram_gb >= required_vram_gb and _selectable(g)]
    if not candidates:
        scope = "" if unrestricted else f" on provider {provider!r}"
        raise ValueError(f"no GPU class fits >= {required_vram_gb} GB{scope}")
    best = min(candidates, key=lambda g: (realized_hourly_usd(g.name), g.vram_gb, g.name))
    return best.name
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ===== Model-size facts (catalog-only; five dense text models, no MoE/open-model sizing) =====
def total_params_b(model_id: str) -> float:
    """Total parameter count (billions) for a catalog model -- the curated ``params_b`` stat.

    Raises:
        ValueError: ``model_id`` is not in the catalog, or its catalog entry has no positive
            ``params_b``. The ``ModelInfo.params_b`` field defaults to 0.0, which would
            otherwise price every FLOP and download term -- and so the whole run -- at $0.
    """
    info = MODELS.get(model_id)
    if info is None:
        raise ValueError(
            f"unknown model {model_id!r}; cost estimation supports catalog models only "
            f"({', '.join(MODELS)})"
        )
    # `not x > 0` also rejects NaN, which a plain `x <= 0` would let through.
    if not info.params_b > 0:
        raise ValueError(
            f"catalog model {model_id!r} has no curated params_b; cannot estimate its cost"
        )
    return info.params_b
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def model_quant(model_id: str) -> str:
    """Quantization label of the catalog entry for ``model_id``.

    Returns the entry's ``quant`` (e.g. ``"4bit-qlora"``). Returns ``"bf16"`` when the model
    is not in the catalog or its ``quant`` is empty/unset.
    """
    info = MODELS.get(model_id)
    if info is None or not info.quant:
        return "bf16"
    return info.quant
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def download_weight_gb(model_id: str) -> float:
    """GB fetched from the HF hub on a cold start: the full bf16 checkpoint at 2 bytes/param.

    A non-catalog ``model_id`` raises ``ValueError`` (from ``total_params_b``).
    """
    bytes_per_param = 2.0  # bf16 checkpoint
    return bytes_per_param * total_params_b(model_id)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ===== Reward-grader latency (GRPO) =====
# One average grader latency (seconds per completion) shared by every env. Real graders range
# from ~0.01s (regex/math checks) to ~3s (LLM judge / code execution); 1s is a middle-of-the-road
# default that an individual run can override.
AVG_REWARD_SECONDS_PER_COMPLETION: float = 1.0
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def reward_seconds_per_completion(override: float | None = None) -> float:
|
|
123
|
+
"""Per-completion reward latency (s): the explicit override, else the single average."""
|
|
124
|
+
if override is not None:
|
|
125
|
+
return max(0.0, override)
|
|
126
|
+
return AVG_REWARD_SECONDS_PER_COMPLETION
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Map a parsed training ``JobSpec`` to a cost ``RunConfig`` / step count / estimate.
|
|
2
|
+
|
|
3
|
+
Shared by ``flash train --cost`` and the control plane's submit-time charge, so both price the
|
|
4
|
+
same work on the same catalog-only, cheapest-fit basis."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from flash.cost.analytical import estimate_cost
|
|
9
|
+
from flash.cost.types import CostEstimate, RunConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def count_env_examples(env_id: str, params: dict | None = None) -> int | None:
|
|
13
|
+
"""Training rows in ``env_id``'s dataset (the worker's train split), or ``None`` if it can't
|
|
14
|
+
be loaded. Best-effort -- prices an uncapped SFT run on the real dataset size, not a guess."""
|
|
15
|
+
if not env_id:
|
|
16
|
+
return None
|
|
17
|
+
try:
|
|
18
|
+
from flash.envs import load_environment
|
|
19
|
+
|
|
20
|
+
rows = load_environment(env_id, params or {}).dataset("train")
|
|
21
|
+
except Exception:
|
|
22
|
+
return None
|
|
23
|
+
return len(rows) if rows is not None else None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def spec_steps(spec) -> int:
    """Per-seed optimizer steps a train spec implies (mirrors the worker).

    * GRPO: ``train.steps`` (floored at 1) when set, else the recipe's ``rl.num_steps``.
    * SFT: ``epochs * ceil(num_examples / realized_batch)`` (at least 1), then capped by
      ``train.max_steps`` when that is positive. ``num_examples`` is ``train.max_examples`` when
      pinned to a positive value, else the real size of the env's train split.

    Raises:
        ValueError: SFT without a positive ``max_examples`` pin whose environment can't be loaded
            to count its training rows.
    """
    from flash.engine.recipe import RECIPE
    from flash.engine.vram import sft_realized_batch

    train = spec.train
    if spec.algorithm == "grpo":
        return RECIPE.rl.num_steps if train.steps is None else max(1, int(train.steps))

    # --- SFT ---
    # SFT-only optimizer-step cap; a falsy max_steps becomes 0, meaning "uncapped".
    step_cap = int(train.max_steps) if train.max_steps else 0
    n_epochs = RECIPE.sft.num_epochs if train.epochs is None else int(train.epochs)
    batch = sft_realized_batch(
        RECIPE.sft.effective_batch if train.batch_size is None else int(train.batch_size)
    )

    # max_examples is a CAP: 0 behaves like None ("train the full dataset"), so a zero pin must
    # not collapse the estimate to a single step.
    pinned = int(train.max_examples) if train.max_examples else 0
    if pinned > 0:
        n_examples = pinned
    else:
        # Uncapped: the worker trains the whole env dataset, so price its real size.
        n_examples = count_env_examples(spec.environment.id, spec.environment.params)
        if n_examples is None:
            raise ValueError(
                f"could not load environment {spec.environment.id!r} to count its training "
                f"examples for the cost; install it (`slm env install {spec.environment.id}`) "
                "or pin [train].max_examples"
            )

    steps_per_epoch = -(-n_examples // batch)  # integer ceil(n_examples / batch)
    total = max(1, steps_per_epoch * n_epochs)
    return total if step_cap <= 0 else min(total, step_cap)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def runconfig_from_spec(spec) -> RunConfig:
    """Translate a parsed ``JobSpec`` into the cost ``RunConfig`` the estimator prices.

    Every seed runs as its own job and re-pays the cold start, so the step total and
    ``setup_repeats`` are both multiplied by the seed count (an empty/unset seed list counts as
    one seed). No GPU is pinned: ``provider="auto"`` lets the estimate do its own cheapest-fit.
    GRPO-only fields (``completion_len``, ``group_size``) are ``None`` for other algorithms.
    """
    train = spec.train
    gpu = spec.gpu
    grpo = spec.algorithm == "grpo"
    n_seeds = max(1, len(train.seeds or (0,)))
    total_steps = spec_steps(spec) * n_seeds

    return RunConfig(
        model_id=spec.model,
        method=spec.algorithm,
        steps=total_steps,
        setup_repeats=n_seeds,
        seq_len=train.max_length,
        completion_len=train.max_tokens if grpo else None,
        batch_size=train.batch_size,
        group_size=train.group_size if grpo else None,
        lora_rank=train.lora_rank,
        thinking=spec.thinking,
        provider="auto",
        max_wall_seconds=gpu.max_wall_seconds,
        environment=spec.environment.id or None,
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def estimate_for_spec(spec) -> CostEstimate:
    """Pre-flight ``CostEstimate`` for a parsed training ``JobSpec``.

    Builds the ``RunConfig`` via ``runconfig_from_spec`` and prices it with ``estimate_cost``.
    A ``ValueError`` raised while deriving the step count (e.g. an SFT env that can't be loaded
    and has no ``max_examples`` pin) propagates to the caller.
    """
    run_config = runconfig_from_spec(spec)
    return estimate_cost(run_config)
|