freesolo-flash 0.2.3__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/PKG-INFO +3 -1
  2. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/__init__.py +1 -1
  3. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/catalog.py +8 -0
  4. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/cli/main/__init__.py +5 -0
  5. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/cli/main/commands.py +21 -2
  6. freesolo_flash-0.2.4/flash/cost/__init__.py +16 -0
  7. freesolo_flash-0.2.4/flash/cost/analytical.py +160 -0
  8. freesolo_flash-0.2.4/flash/cost/facts.py +126 -0
  9. freesolo_flash-0.2.4/flash/cost/spec.py +87 -0
  10. freesolo_flash-0.2.4/flash/cost/types.py +158 -0
  11. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/vram.py +19 -0
  12. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/worker/__init__.py +5 -7
  13. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/base.py +10 -5
  14. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/schema/__init__.py +2 -0
  15. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/app.py +51 -3
  16. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/auth.py +7 -2
  17. freesolo_flash-0.2.4/flash/server/billing.py +128 -0
  18. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/pyproject.toml +15 -2
  19. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_allocator.py +1 -1
  20. freesolo_flash-0.2.4/tests/test_cli_estimate.py +224 -0
  21. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_client_server_integration.py +8 -0
  22. freesolo_flash-0.2.4/tests/test_cost_analytical.py +224 -0
  23. freesolo_flash-0.2.4/tests/test_cost_equation.py +47 -0
  24. freesolo_flash-0.2.4/tests/test_cost_estimate.py +77 -0
  25. freesolo_flash-0.2.4/tests/test_cost_hardware.py +131 -0
  26. freesolo_flash-0.2.4/tests/test_cost_models.py +36 -0
  27. freesolo_flash-0.2.4/tests/test_cost_rewards.py +65 -0
  28. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_open_model_policy.py +4 -4
  29. freesolo_flash-0.2.4/tests/test_server_billing.py +392 -0
  30. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/uv.lock +44 -40
  31. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.dockerignore +0 -0
  32. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.env.example +0 -0
  33. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.github/workflows/ci.yml +0 -0
  34. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.github/workflows/main-source-guard.yml +0 -0
  35. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.github/workflows/publish-image.yml +0 -0
  36. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.github/workflows/publish.yml +0 -0
  37. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.github/workflows/worker-image.yml +0 -0
  38. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/.gitignore +0 -0
  39. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/Dockerfile +0 -0
  40. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/Dockerfile.worker +0 -0
  41. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/LICENSE +0 -0
  42. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/README.md +0 -0
  43. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/docker/make_rp_handler.py +0 -0
  44. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/_fileio.py +0 -0
  45. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/_logging.py +0 -0
  46. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/cli/__init__.py +0 -0
  47. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/cli/main/__main__.py +0 -0
  48. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/cli/main/envpush.py +0 -0
  49. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/client/__init__.py +0 -0
  50. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/client/config.py +0 -0
  51. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/client/http.py +0 -0
  52. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/client/specs.py +0 -0
  53. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/__init__.py +0 -0
  54. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/accounting.py +0 -0
  55. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/chalk_kernels.py +0 -0
  56. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/multiturn_rollout.py +0 -0
  57. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/recipe.py +0 -0
  58. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/worker/__main__.py +0 -0
  59. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/worker/lora.py +0 -0
  60. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/engine/worker/perf.py +0 -0
  61. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/envs/__init__.py +0 -0
  62. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/envs/adapter/__init__.py +0 -0
  63. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/envs/adapter/rubric.py +0 -0
  64. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/envs/base.py +0 -0
  65. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/envs/registry.py +0 -0
  66. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/mcp/__init__.py +0 -0
  67. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/mcp/server.py +0 -0
  68. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/__init__.py +0 -0
  69. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/_auth.py +0 -0
  70. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/_http.py +0 -0
  71. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/_poll.py +0 -0
  72. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/allocator.py +0 -0
  73. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/preflight.py +0 -0
  74. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/__init__.py +0 -0
  75. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/api.py +0 -0
  76. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/auth.py +0 -0
  77. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/gpus.py +0 -0
  78. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/jobs.py +0 -0
  79. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/preflight.py +0 -0
  80. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/pricing.py +0 -0
  81. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/train/__init__.py +0 -0
  82. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/train/deps.py +0 -0
  83. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/runpod/train/endpoints.py +0 -0
  84. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/__init__.py +0 -0
  85. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/_bootstrap.py +0 -0
  86. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/api.py +0 -0
  87. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/auth.py +0 -0
  88. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/gpus.py +0 -0
  89. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/jobs/__init__.py +0 -0
  90. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/jobs/builders.py +0 -0
  91. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/preflight.py +0 -0
  92. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/pricing.py +0 -0
  93. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/providers/vast/train.py +0 -0
  94. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/py.typed +0 -0
  95. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/runner/__init__.py +0 -0
  96. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/runner/deploy.py +0 -0
  97. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/runner/lifecycle.py +0 -0
  98. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/schema/fields.py +0 -0
  99. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/serve/__init__.py +0 -0
  100. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/serve/deploy.py +0 -0
  101. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/__init__.py +0 -0
  102. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/__main__.py +0 -0
  103. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/db.py +0 -0
  104. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/server/envs.py +0 -0
  105. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/flash/spec.py +0 -0
  106. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/__init__.py +0 -0
  107. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/_helpers/__init__.py +0 -0
  108. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/_helpers/runner.py +0 -0
  109. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/_helpers/specs.py +0 -0
  110. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/_helpers/vast.py +0 -0
  111. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/conftest.py +0 -0
  112. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/fixtures/math_eval.jsonl +0 -0
  113. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/fixtures/math_train.jsonl +0 -0
  114. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/live/__init__.py +0 -0
  115. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/live/conftest.py +0 -0
  116. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/live/test_runpod_live.py +0 -0
  117. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/live/test_vast_live.py +0 -0
  118. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_agent_slm_cli_contract.py +0 -0
  119. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_algorithms.py +0 -0
  120. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_backend_jobspec_contract.py +0 -0
  121. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_cancel_remote.py +0 -0
  122. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_catalog_consistency.py +0 -0
  123. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_chalk_kernels.py +0 -0
  124. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_cli_commands.py +0 -0
  125. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_cli_errors.py +0 -0
  126. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_cli_managed.py +0 -0
  127. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_client.py +0 -0
  128. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_config_overrides.py +0 -0
  129. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_disk_gb.py +0 -0
  130. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_endpoint_name.py +0 -0
  131. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_env_install.py +0 -0
  132. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_env_publish.py +0 -0
  133. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_env_push.py +0 -0
  134. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_envs_coverage.py +0 -0
  135. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_flash_mvp.py +0 -0
  136. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_flash_worker.py +0 -0
  137. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_gpus.py +0 -0
  138. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_grpo_params.py +0 -0
  139. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_jobs.py +0 -0
  140. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_logging.py +0 -0
  141. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_login_perms.py +0 -0
  142. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_metrics_schema_agent_contract.py +0 -0
  143. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_multiturn_rollout.py +0 -0
  144. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_orchestrator_flash.py +0 -0
  145. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_preflight.py +0 -0
  146. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_pricing_cache.py +0 -0
  147. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_provider_routing.py +0 -0
  148. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_providers_symmetry.py +0 -0
  149. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_runmgmt.py +0 -0
  150. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_runpod_api_delete.py +0 -0
  151. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_serve.py +0 -0
  152. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_serve_modes.py +0 -0
  153. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_server_api.py +0 -0
  154. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_spec_and_validation.py +0 -0
  155. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_thinking_config.py +0 -0
  156. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_vast_api.py +0 -0
  157. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_vast_offers.py +0 -0
  158. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_vast_runner.py +0 -0
  159. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_verifiers.py +0 -0
  160. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_version.py +0 -0
  161. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_wandb_naming.py +0 -0
  162. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_worker_dryrun.py +0 -0
  163. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_worker_hardexit.py +0 -0
  164. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_worker_stack.py +0 -0
  165. {freesolo_flash-0.2.3 → freesolo_flash-0.2.4}/tests/test_worker_thinking.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: freesolo-flash
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: Flash — managed LoRA post-training (SFT/GRPO) for verifiers environments, driven by the `flash` CLI
5
5
  Project-URL: Homepage, https://github.com/freesolo-co/flash
6
6
  Project-URL: Repository, https://github.com/freesolo-co/flash
@@ -27,12 +27,14 @@ Requires-Dist: trl<1.7,>=1.6; extra == 'gpu'
27
27
  Requires-Dist: verifiers>=0.1.10; extra == 'gpu'
28
28
  Requires-Dist: vllm==0.19.1; extra == 'gpu'
29
29
  Provides-Extra: server
30
+ Requires-Dist: datasets>=2.19; extra == 'server'
30
31
  Requires-Dist: fastapi; extra == 'server'
31
32
  Requires-Dist: httpx>=0.27; extra == 'server'
32
33
  Requires-Dist: huggingface-hub>=0.34; extra == 'server'
33
34
  Requires-Dist: prime>=0.6.3; extra == 'server'
34
35
  Requires-Dist: runpod-flash; extra == 'server'
35
36
  Requires-Dist: uvicorn; extra == 'server'
37
+ Requires-Dist: verifiers>=0.1.10; extra == 'server'
36
38
  Description-Content-Type: text/markdown
37
39
 
38
40
  # Flash
@@ -8,4 +8,4 @@ GPU (RunPod or Vast.ai) behind the scenes.
8
8
 
9
9
  __all__ = ["__version__"]
10
10
 
11
- __version__ = "0.2.3"
11
+ __version__ = "0.2.4"
@@ -64,6 +64,9 @@ class ModelInfo:
64
64
  # the raw tokenizer count). Drives the GRPO fp32-logits memory term and the per-device
65
65
  # completion cap. Curated per model below; defaults to the open-model fallback.
66
66
  vocab_size: int = _DEFAULT_VOCAB_SIZE
67
+ # Total parameters in billions — the numeric model size the cost estimator reads directly
68
+ # (no parsing of the ``params`` display string). Curated per catalog model below.
69
+ params_b: float = 0.0
67
70
 
68
71
  def to_dict(self) -> dict[str, Any]:
69
72
  return asdict(self)
@@ -79,6 +82,7 @@ MODELS: dict[str, ModelInfo] = {
79
82
  id="openbmb/MiniCPM5-1B",
80
83
  display_name="MiniCPM5 1B",
81
84
  params="1.2B dense (Llama arch)",
85
+ params_b=1.2,
82
86
  vocab_size=130_560,
83
87
  algos=("sft", "grpo"),
84
88
  min_vram_gb=12,
@@ -95,6 +99,7 @@ MODELS: dict[str, ModelInfo] = {
95
99
  id="Qwen/Qwen3.5-0.8B",
96
100
  display_name="Qwen3.5 0.8B",
97
101
  params="0.9B (text-only fine-tune)",
102
+ params_b=0.9,
98
103
  vocab_size=248_320,
99
104
  algos=("sft", "grpo"),
100
105
  min_vram_gb=12,
@@ -106,6 +111,7 @@ MODELS: dict[str, ModelInfo] = {
106
111
  id="Qwen/Qwen3.5-2B",
107
112
  display_name="Qwen3.5 2B",
108
113
  params="2.3B (text-only fine-tune)",
114
+ params_b=2.3,
109
115
  vocab_size=248_320,
110
116
  algos=("sft", "grpo"),
111
117
  min_vram_gb=16,
@@ -116,6 +122,7 @@ MODELS: dict[str, ModelInfo] = {
116
122
  id="Qwen/Qwen3.5-4B",
117
123
  display_name="Qwen3.5 4B",
118
124
  params="4.7B (text-only fine-tune)",
125
+ params_b=4.7,
119
126
  vocab_size=248_320,
120
127
  algos=("sft", "grpo"),
121
128
  min_vram_gb=32,
@@ -128,6 +135,7 @@ MODELS: dict[str, ModelInfo] = {
128
135
  id="Qwen/Qwen3.5-9B",
129
136
  display_name="Qwen3.5 9B",
130
137
  params="9.7B (text-only fine-tune)",
138
+ params_b=9.7,
131
139
  vocab_size=248_320,
132
140
  algos=("sft", "grpo"),
133
141
  min_vram_gb=16,
@@ -137,6 +137,11 @@ def main(argv: list[str] | None = None) -> int:
137
137
  help="override a config value; repeatable",
138
138
  )
139
139
  train.add_argument("--dry-run", action="store_true")
140
+ train.add_argument(
141
+ "--cost",
142
+ action="store_true",
143
+ help="print the pre-flight USD cost for the config and exit (no submit)",
144
+ )
140
145
  train.add_argument(
141
146
  "--background",
142
147
  action="store_true",
@@ -26,6 +26,7 @@ from flash.client import (
26
26
  )
27
27
  from flash.client.config import load_credentials
28
28
  from flash.client.specs import spec_payload
29
+ from flash.cost.spec import runconfig_from_spec
29
30
  from flash.runner import TERMINAL_STATES, new_run_id
30
31
  from flash.schema import ConfigError, spec_from_file
31
32
 
@@ -262,12 +263,30 @@ def cmd_env_list(args) -> int:
262
263
  return 0
263
264
 
264
265
 
266
def _cmd_train_cost(args) -> int:
    """Handle ``flash train --cost``: print the pre-flight USD estimate and exit.

    The config is parsed the same way a normal train invocation parses it (same
    overrides and extra configs, no run id), mapped to a cost ``RunConfig`` and
    priced. Only the estimate's breakdown is printed; nothing is submitted.

    Returns:
        Process exit status ``0``.
    """
    from flash.cost import estimate_cost

    parsed_spec = spec_from_file(
        args.config,
        run_id=None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )
    run_config = runconfig_from_spec(parsed_spec)
    estimate = estimate_cost(run_config)
    print(estimate.breakdown())
    return 0
280
+
281
+
265
282
  def cmd_train(args) -> int:
283
+ if getattr(args, "cost", False):
284
+ return _cmd_train_cost(args)
266
285
  spec = spec_from_file(
267
286
  args.config,
268
287
  run_id=new_run_id() if args.dry_run else None,
269
- overrides=getattr(args, "overrides", None),
270
- extra_configs=getattr(args, "extra_configs", None),
288
+ overrides=args.overrides,
289
+ extra_configs=args.extra_configs,
271
290
  )
272
291
  if args.dry_run:
273
292
  # Fully local: validate the id-based config without credentials, a server, or a GPU.
@@ -0,0 +1,16 @@
1
+ """Flash training-cost estimator: a deterministic, equation-based pre-flight estimate
2
+ (``estimate_cost``) of cost = wall-clock hours x market $/hr. No output multiplier."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from .analytical import estimate_cost
7
+ from .spec import estimate_for_spec, runconfig_from_spec
8
+ from .types import CostEstimate, RunConfig
9
+
10
+ __all__ = [
11
+ "CostEstimate",
12
+ "RunConfig",
13
+ "estimate_cost",
14
+ "estimate_for_spec",
15
+ "runconfig_from_spec",
16
+ ]
@@ -0,0 +1,160 @@
1
+ """The analytical cost model: total = wall-clock hours x GPU $/hr, where wall = cold-start
2
+ setup + steps x per-step time (a FLOPs/MFU estimate). GRPO splits each step into a vLLM
3
+ rollout + reward grading + policy/reference update."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+
9
+ from flash.providers.allocator import required_vram_gb, vram_headroom
10
+
11
+ from .facts import (
12
+ download_weight_gb,
13
+ gpu_tflops,
14
+ gpu_vram_gb,
15
+ model_quant,
16
+ pick_gpu,
17
+ realized_hourly_usd,
18
+ reward_seconds_per_completion,
19
+ total_params_b,
20
+ )
21
+ from .types import CostEstimate, RunConfig
22
+
23
+ # FLOPs per token per active-parameter.
24
+ SFT_FLOPS_PER_TOKEN_PER_PARAM = 6.0 # forward (2) + backward (4)
25
+ GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM = 2.0 # autoregressive rollout forward
26
+ GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM = 8.0 # policy fwd+bwd (6) + frozen-ref fwd (2)
27
+
28
+ # Model-FLOPs utilization (fraction of peak sustained), calibrated against real RunPod/Vast
29
+ # wall clock. LoRA + small batches sit well below dense-pretraining MFU.
30
+ MFU_TRAIN = 0.35 # GRPO policy/reference update
31
+ MFU_SFT_TRAIN = 0.25 # SFT fwd/bwd (smaller effective batch, long sequences)
32
+ MFU_DECODE = 0.12 # batched vLLM rollout (decode is memory-bandwidth-bound)
33
+
34
+ # Reward grading is CONCURRENT: a step's completions score in parallel slots, so the reward
35
+ # wall is ceil(completions / slots) waves x latency, not completions x latency.
36
+ REWARD_CONCURRENCY = 16.0
37
+
38
+ # Cold-start overhead (seconds): container boot + deps + model download (+ vLLM init for GRPO).
39
+ WORKER_BOOT_S = 180.0
40
+ DEPS_INSTALL_S = 120.0
41
+ VLLM_INIT_S = 120.0
42
+ DOWNLOAD_RATE_GBPS = 0.4 # effective HF snapshot download (hf_transfer)
43
+
44
+ DEFAULT_WALL_CAP_S = 24 * 3600 # spec gpu.max_wall_seconds default
45
+
46
+
47
+ def _fmt_duration(seconds: float) -> str:
48
+ """Human duration for notes: seconds < 1m, minutes < 1h, else whole/1-decimal hours."""
49
+ if seconds < 60:
50
+ return f"{seconds:.0f}s"
51
+ if seconds < 3600:
52
+ return f"{seconds / 60:.0f}m"
53
+ hours = seconds / 3600
54
+ return f"{hours:.0f}h" if abs(hours - round(hours)) < 1e-9 else f"{hours:.1f}h"
55
+
56
+
57
def setup_seconds(config: RunConfig) -> float:
    """Return the cold-start wall time (seconds) billed before the first optimizer step.

    Sums worker boot, dependency install and the model download (checkpoint GB
    divided by the download rate); GRPO runs additionally pay the vLLM init time.
    """
    download_s = download_weight_gb(config.model_id) / DOWNLOAD_RATE_GBPS
    cold_start_s = WORKER_BOOT_S + DEPS_INSTALL_S + download_s
    return cold_start_s + VLLM_INIT_S if config.is_grpo else cold_start_s
63
+
64
+
65
def seconds_per_step(config: RunConfig, gpu: str) -> float:
    """Return the steady-state wall time (seconds) of one optimizer step on ``gpu``.

    SFT: forward+backward FLOPs over ``batch_size * seq_len`` tokens at the SFT MFU.
    GRPO: rollout time + concurrent reward grading + policy/reference update, all
    over the ``batch_size * group_size`` completions generated per step.
    """
    cfg = config.normalized()
    param_count = total_params_b(cfg.model_id) * 1e9
    peak_flops = gpu_tflops(gpu) * 1e12  # FLOP/s

    if cfg.is_grpo:
        n_completions = cfg.batch_size * cfg.group_size
        rollout_tokens = n_completions * cfg.completion_len
        rollout_s = (GRPO_GEN_FLOPS_PER_TOKEN_PER_PARAM * param_count * rollout_tokens) / (
            peak_flops * MFU_DECODE
        )
        update_s = (GRPO_UPDATE_FLOPS_PER_TOKEN_PER_PARAM * param_count * rollout_tokens) / (
            peak_flops * MFU_TRAIN
        )
        per_wave_s = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        # Grading runs in parallel slots; a partially filled wave still costs one latency.
        waves = math.ceil(n_completions / REWARD_CONCURRENCY)
        grading_s = waves * per_wave_s
        return rollout_s + grading_s + update_s

    step_tokens = cfg.batch_size * cfg.seq_len
    step_flops = SFT_FLOPS_PER_TOKEN_PER_PARAM * param_count * step_tokens
    return step_flops / (peak_flops * MFU_SFT_TRAIN)
83
+
84
+
85
def select_gpu(config: RunConfig) -> tuple[str, int]:
    """Return ``(gpu_class, required_vram_gb)`` for the run.

    The required VRAM comes from the allocator's sizing for the model, method,
    train knobs and thinking flag; the GPU is the cheapest class that fits it
    for ``config.provider``.
    """
    # Called only for its ValueError on a non-catalog model; the result is not needed here.
    total_params_b(config.model_id)
    required = required_vram_gb(
        config.model_id,
        config.method,
        train=config.train_knobs(),
        thinking=config.thinking,
    )
    return pick_gpu(required, provider=config.provider), required
97
+
98
+
99
def _notes(config: RunConfig, raw_train_s: float, wall_capped: bool, cap_s: float) -> tuple[str, ...]:
    """Build the human-readable notes attached to an estimate.

    Args:
        config: The run being priced.
        raw_train_s: Unclamped per-seed training time, shown when the cap applies.
        wall_capped: Whether training was clamped to fit the wall cap.
        cap_s: The wall cap in seconds.

    Returns:
        Notes in a fixed order: quantization (non-bf16 only), GRPO step makeup
        (GRPO only), GPU sizing/pricing basis, wall-cap clamp (capped only).
    """
    cfg = config.normalized()
    lines: list[str] = []

    quant = model_quant(cfg.model_id)
    if quant != "bf16":
        lines.append(f"{quant}: smaller VRAM footprint -> cheaper GPU class fits")

    if cfg.is_grpo:
        total_completions = cfg.batch_size * cfg.group_size
        latency = reward_seconds_per_completion(cfg.reward_seconds_per_completion)
        env_part = f", env {cfg.environment}" if cfg.environment else ""
        lines.append(
            f"GRPO step = vLLM rollout of {cfg.batch_size}x{cfg.group_size}={total_completions} completions "
            f"@ {cfg.completion_len} tok + reward ({latency:.2f}s/completion"
            f"{env_part}"
            ") + policy+reference update"
        )

    lines.append(f"GPU sized with {vram_headroom() - 1:.0%} VRAM headroom; market (spot/queue) $/hr")

    if wall_capped:
        scope = "per-seed " if config.setup_repeats != 1 else ""
        lines.append(
            f"training clamped to fit the {_fmt_duration(cap_s)} {scope}wall cap "
            f"(after setup; uncapped: {_fmt_duration(raw_train_s)})"
        )
    return tuple(lines)
121
+
122
+
123
def estimate_cost(config: RunConfig, *, wall_cap_s: float = DEFAULT_WALL_CAP_S) -> CostEstimate:
    """Return the deterministic pre-flight cost estimate for ``config``.

    One seed is priced (cold start + steps x per-step time), clamped to the
    per-seed wall cap, then multiplied by the seed count; the total is
    wall-clock hours x the realized $/hr of the selected GPU class.

    Args:
        config: The run to price.
        wall_cap_s: Wall cap used when ``config.max_wall_seconds`` is ``None``.
    """
    gpu, need = select_gpu(config)
    hourly = realized_hourly_usd(gpu)

    # An explicit cap is floored at 60s; the fallback ``wall_cap_s`` is used as given.
    if config.max_wall_seconds is None:
        cap_s = wall_cap_s
    else:
        cap_s = max(60.0, float(config.max_wall_seconds))

    # config.steps is the total across seeds, so divide to get one seed's share.
    seeds = config.setup_repeats
    raw_setup_one = setup_seconds(config)
    sps = seconds_per_step(config, gpu)
    raw_train_one = (config.steps / seeds) * sps

    # The cap covers setup + training for one seed; setup is kept and training is trimmed.
    wall_capped = (raw_setup_one + raw_train_one) > cap_s
    setup_one = min(raw_setup_one, cap_s)
    if wall_capped:
        train_one = max(0.0, cap_s - setup_one)
    else:
        train_one = raw_train_one

    setup_total = setup_one * seeds
    train_total = train_one * seeds
    wall_total = setup_total + train_total

    return CostEstimate(
        model_id=config.model_id,
        method=config.method,
        steps=config.steps,
        gpu=gpu,
        provider=config.provider,
        gpu_vram_gb=gpu_vram_gb(gpu),
        required_vram_gb=need,
        gpu_hourly_usd=hourly,
        setup_seconds=setup_total,
        seconds_per_step=sps,
        train_seconds=train_total,
        wall_clock_seconds=wall_total,
        wall_capped=wall_capped,
        total_usd=wall_total / 3600.0 * hourly,
        notes=_notes(config, raw_train_one, wall_capped, cap_s),
    )
@@ -0,0 +1,126 @@
1
+ """Static lookup facts for the cost model: GPU price/VRAM/compute + cheapest-fit
2
+ selection, model size/quant, and reward-grader latency. Pure tables + accessors."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from flash.catalog import MODELS
7
+ from flash.providers.base import GPU_INFO, GpuClass, providers_for
8
+
9
+ # ===== GPU facts =====
10
+ GPU_COMPUTE_TFLOPS: dict[str, float] = {
11
+ "RTX A4000": 77.0,
12
+ "RTX 2000 Ada": 89.0,
13
+ "RTX A4500": 89.0,
14
+ "RTX 4000 Ada": 90.0,
15
+ "RTX A5000": 89.0,
16
+ "RTX 3090": 71.0,
17
+ "L4": 60.0,
18
+ "RTX Pro 4000": 95.0,
19
+ "RTX 4090": 165.0,
20
+ "RTX 5090": 210.0,
21
+ "RTX A6000": 155.0,
22
+ "A40": 150.0,
23
+ "RTX 6000 Ada": 182.0,
24
+ "L40S": 181.0,
25
+ "A100 SXM 40GB": 312.0,
26
+ "A100 PCIe": 312.0,
27
+ "A100 SXM": 312.0,
28
+ "H100 NVL": 835.0,
29
+ "H100": 990.0,
30
+ "RTX Pro 6000": 250.0,
31
+ "RTX Pro 6000 WK": 250.0,
32
+ }
33
+ _DEFAULT_TFLOPS = 100.0
34
+
35
+
36
def gpu_tflops(name: str) -> float:
    """Return the table's TFLOPS figure for GPU class ``name``.

    A class missing from ``GPU_COMPUTE_TFLOPS`` gets ``_DEFAULT_TFLOPS``.
    """
    try:
        return GPU_COMPUTE_TFLOPS[name]
    except KeyError:
        return _DEFAULT_TFLOPS
39
+
40
+
41
def gpu_hourly_usd(name: str) -> float:
    """Return the ``hourly_usd`` recorded in ``GPU_INFO`` for class ``name``.

    Raises:
        KeyError: ``name`` is not a known GPU class.
    """
    if (entry := GPU_INFO.get(name)) is not None:
        return entry.hourly_usd
    raise KeyError(f"unknown GPU class {name!r}")
47
+
48
+
49
+ # Realized (spot/queue) $/hr per class -- the discount below on-demand list (RTX 5090 lists
50
+ # $0.99, bills ~$0.87). ``realized_hourly_usd`` CLAMPS to the list price so it can never
51
+ # over-quote; a class with no clean observed rate falls back to list.
52
+ REALIZED_HOURLY_USD: dict[str, float] = {
53
+ "RTX 3090": 0.239,
54
+ "RTX 4090": 0.426,
55
+ "RTX 5090": 0.871,
56
+ "RTX A5000": 0.304,
57
+ "RTX 6000 Ada": 0.601,
58
+ "A100 PCIe": 1.035,
59
+ "A100 SXM": 1.133,
60
+ }
61
+
62
+
63
def realized_hourly_usd(name: str) -> float:
    """Return the observed market $/hr for ``name``, never above its list price.

    A class with no entry in ``REALIZED_HOURLY_USD`` is priced at list.

    Raises:
        KeyError: ``name`` is not a known GPU class.
    """
    ceiling = gpu_hourly_usd(name)
    if name not in REALIZED_HOURLY_USD:
        return ceiling
    return min(REALIZED_HOURLY_USD[name], ceiling)
67
+
68
+
69
def gpu_vram_gb(name: str) -> int:
    """Return the ``vram_gb`` recorded in ``GPU_INFO`` for class ``name``.

    Raises:
        KeyError: ``name`` is not a known GPU class.
    """
    if (entry := GPU_INFO.get(name)) is not None:
        return entry.vram_gb
    raise KeyError(f"unknown GPU class {name!r}")
74
+
75
+
76
def pick_gpu(required_vram_gb: int, *, provider: str | None = None) -> str:
    """Return the name of the cheapest GPU class with at least ``required_vram_gb``.

    Classes are ranked by realized (market) $/hr, then VRAM, then name. When
    ``provider`` is neither ``None`` nor ``"auto"``, only classes listed by
    ``providers_for`` for that provider are considered.

    Raises:
        ValueError: No eligible class has enough VRAM.
    """
    any_provider = provider in (None, "auto")

    def _rank(gpu: GpuClass) -> tuple[float, int, str]:
        return (realized_hourly_usd(gpu.name), gpu.vram_gb, gpu.name)

    eligible = [
        gpu
        for gpu in GPU_INFO.values()
        if gpu.vram_gb >= required_vram_gb
        and (any_provider or provider in providers_for(gpu.name))
    ]
    if not eligible:
        raise ValueError(f"no GPU class fits >= {required_vram_gb} GB")
    return sorted(eligible, key=_rank)[0].name
91
+
92
+
93
+ # ===== Model-size facts (catalog-only; five dense text models, no MoE/open-model sizing) =====
94
def total_params_b(model_id: str) -> float:
    """Return the curated ``params_b`` (billions of parameters) of a catalog model.

    Raises:
        ValueError: ``model_id`` is not in ``MODELS``; the message lists the known ids.
    """
    entry = MODELS.get(model_id)
    if entry is not None:
        return entry.params_b
    known_ids = ", ".join(MODELS)
    raise ValueError(
        f"unknown model {model_id!r}; cost estimation supports catalog models only "
        f"({known_ids})"
    )
103
+
104
+
105
def model_quant(model_id: str) -> str:
    """Return the catalog entry's ``quant`` value, or ``"bf16"``.

    ``"bf16"`` is returned both for an unknown model and for an entry whose
    ``quant`` is empty/falsy.
    """
    entry = MODELS.get(model_id)
    if entry is None:
        return "bf16"
    return entry.quant or "bf16"
109
+
110
+
111
def download_weight_gb(model_id: str) -> float:
    """Return the checkpoint size in GB fetched at cold start.

    Computed as two bytes per parameter (a bf16 checkpoint) times the model's
    parameter count in billions.

    Raises:
        ValueError: ``model_id`` is not a catalog model.
    """
    bytes_per_param = 2.0
    params_in_billions = total_params_b(model_id)
    return params_in_billions * bytes_per_param
114
+
115
+
116
+ # ===== Reward-grader latency (GRPO) =====
117
+ # A single average grader latency (s/completion) for every env. Graders span ~0.01s (regex/math)
118
+ # to ~3s (LLM judge/code); ~1s is a middle-of-the-road default (a run can override it).
119
+ AVG_REWARD_SECONDS_PER_COMPLETION = 1.0
120
+
121
+
122
+ def reward_seconds_per_completion(override: float | None = None) -> float:
123
+ """Per-completion reward latency (s): the explicit override, else the single average."""
124
+ if override is not None:
125
+ return max(0.0, override)
126
+ return AVG_REWARD_SECONDS_PER_COMPLETION
@@ -0,0 +1,87 @@
1
+ """Map a parsed training ``JobSpec`` to a cost ``RunConfig`` / step count / estimate.
2
+
3
+ Shared by ``flash train --cost`` and the control plane's submit-time charge, so both price the
4
+ same work on the same catalog-only, cheapest-fit basis."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from flash.cost.analytical import estimate_cost
9
+ from flash.cost.types import CostEstimate, RunConfig
10
+
11
+
12
+ def count_env_examples(env_id: str, params: dict | None = None) -> int | None:
13
+ """Training rows in ``env_id``'s dataset (the worker's train split), or ``None`` if it can't
14
+ be loaded. Best-effort -- prices an uncapped SFT run on the real dataset size, not a guess."""
15
+ if not env_id:
16
+ return None
17
+ try:
18
+ from flash.envs import load_environment
19
+
20
+ rows = load_environment(env_id, params or {}).dataset("train")
21
+ except Exception:
22
+ return None
23
+ return len(rows) if rows is not None else None
24
+
25
+
26
def spec_steps(spec) -> int:
    """Return the per-seed optimizer step count implied by a train spec.

    GRPO: ``train.steps`` (at least 1) when set, else the recipe default.
    SFT: ``epochs x ceil(num_examples / realized_batch)`` (at least 1), capped by
    ``max_steps`` when that is positive. ``num_examples`` is ``max_examples`` when
    it is a positive pin, otherwise the environment's real train-split size.

    Raises:
        ValueError: An uncapped SFT run whose environment cannot be loaded to
            count its training examples.
    """
    from flash.engine.recipe import RECIPE
    from flash.engine.vram import sft_realized_batch

    train = spec.train
    if spec.algorithm == "grpo":
        return RECIPE.rl.num_steps if train.steps is None else max(1, int(train.steps))

    # --- SFT ---
    step_cap = int(train.max_steps) if train.max_steps else 0  # 0 means "no step cap"
    if train.epochs is None:
        n_epochs = RECIPE.sft.num_epochs
    else:
        n_epochs = int(train.epochs)
    if train.batch_size is None:
        wanted_batch = RECIPE.sft.effective_batch
    else:
        wanted_batch = int(train.batch_size)
    realized_batch = sft_realized_batch(wanted_batch)

    # max_examples is a cap; 0/None means "train the full dataset", so count the env instead.
    n_examples = int(train.max_examples) if train.max_examples else 0
    if n_examples <= 0:
        n_examples = count_env_examples(spec.environment.id, spec.environment.params)
        if n_examples is None:
            raise ValueError(
                f"could not load environment {spec.environment.id!r} to count its training "
                f"examples for the cost; install it (`slm env install {spec.environment.id}`) "
                "or pin [train].max_examples"
            )

    batches_per_epoch = -(-n_examples // realized_batch)  # integer ceil(examples / batch)
    total_steps = max(1, batches_per_epoch * n_epochs)
    return min(total_steps, step_cap) if step_cap > 0 else total_steps
59
+
60
+
61
def runconfig_from_spec(spec) -> RunConfig:
    """Translate a parsed ``JobSpec`` into the cost model's ``RunConfig``.

    ``steps`` is the per-seed step count times the number of seeds, and
    ``setup_repeats`` is that seed count (at least 1). GRPO-only knobs
    (``completion_len``, ``group_size``) are ``None`` for SFT. The provider is
    always ``"auto"``; the spec's GPU section contributes only the wall cap.
    """
    train, gpu_section = spec.train, spec.gpu
    grpo = spec.algorithm == "grpo"
    seed_count = max(1, len(train.seeds or (0,)))
    total_steps = spec_steps(spec) * seed_count

    completion_len = train.max_tokens if grpo else None
    group_size = train.group_size if grpo else None

    return RunConfig(
        model_id=spec.model,
        method=spec.algorithm,
        steps=total_steps,
        setup_repeats=seed_count,
        seq_len=train.max_length,
        completion_len=completion_len,
        batch_size=train.batch_size,
        group_size=group_size,
        lora_rank=train.lora_rank,
        thinking=spec.thinking,
        provider="auto",
        max_wall_seconds=gpu_section.max_wall_seconds,
        environment=spec.environment.id or None,
    )
83
+
84
+
85
def estimate_for_spec(spec) -> CostEstimate:
    """Return the pre-flight ``CostEstimate`` for a parsed training ``JobSpec``.

    Convenience wrapper: maps the spec to a ``RunConfig`` and prices it with the
    default wall cap.
    """
    run_config = runconfig_from_spec(spec)
    return estimate_cost(run_config)
+ return estimate_cost(runconfig_from_spec(spec))