freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/cost/facts.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Static lookup facts for the cost model: GPU price/VRAM/compute + cheapest-fit
|
|
2
|
+
selection, model size/quant, and reward-grader latency. Pure tables + accessors."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from flash.catalog import MODELS
|
|
7
|
+
from flash.providers.base import GPU_INFO, GpuClass, providers_for
|
|
8
|
+
|
|
9
|
+
# ===== GPU facts =====
# Peak bf16 tensor TFLOPS per managed GPU class (read through gpu_tflops()).
GPU_COMPUTE_TFLOPS: dict[str, float] = dict(
    [
        ("L4", 60.0),
        ("RTX 4090", 165.0),
        ("RTX 5090", 210.0),
        ("RTX A6000", 155.0),
        ("A40", 150.0),
        ("RTX 6000 Ada", 182.0),
        ("A100 PCIe", 312.0),
        ("A100 SXM", 312.0),
        ("H100", 990.0),
        ("RTX Pro 6000", 250.0),
    ]
)
# Throughput assumed for a class that is missing from the table above.
_DEFAULT_TFLOPS = 100.0
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def gpu_tflops(name: str) -> float:
    """Peak bf16 tensor TFLOPS of a managed GPU class; ``_DEFAULT_TFLOPS`` when it is unlisted."""
    try:
        return GPU_COMPUTE_TFLOPS[name]
    except KeyError:
        return _DEFAULT_TFLOPS
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def gpu_hourly_usd(name: str, provider: str | None = None) -> float:
    """Representative $/hr for GPU class ``name``, priced on ``provider`` when one is given.

    ``GpuClass.hourly_usd`` is the nominal RunPod rate and is WRONG for a provider-specific quote
    (e.g. a Lambda RTX A6000 is $1.09/hr, not RunPod's $0.49). So when ``provider`` is ``lambda``
    or ``hyperstack`` (case/whitespace-insensitive) and the class has a name on that provider, the
    rate comes from that provider's pricing module (live with a static fallback). Every other case
    (runpod/auto/None, or a class the provider doesn't list) uses the nominal rate.

    Raises:
        KeyError: ``name`` is not a known GPU class.
    """
    gpu = GPU_INFO.get(name)
    if gpu is None:
        raise KeyError(f"unknown GPU class {name!r}")

    which = (provider or "").strip().lower()
    if which == "lambda" and gpu.lambda_name:
        from flash.providers.lambdalabs.pricing import hourly_rate as lambda_rate

        return lambda_rate(name)
    if which == "hyperstack" and gpu.hyperstack_name:
        from flash.providers.hyperstack.pricing import hourly_rate as hyperstack_rate

        return hyperstack_rate(name)
    # runpod / auto / None, or a class without a listing on the requested provider.
    return gpu.hourly_usd
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def gpu_vram_gb(name: str) -> int:
    """VRAM (GB) of managed GPU class ``name``.

    Raises:
        KeyError: ``name`` is not a known GPU class.
    """
    gpu = GPU_INFO.get(name)
    if gpu is not None:
        return gpu.vram_gb
    raise KeyError(f"unknown GPU class {name!r}")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def pick_gpu(required_vram_gb: int, *, provider: str | None = None) -> str:
    """Cheapest GPU class that fits ``required_vram_gb``.

    Every class with at least ``required_vram_gb`` of VRAM is a candidate. Candidates are ranked
    by ``gpu_hourly_usd(name, provider=...)``: the nominal RunPod rate for auto/runpod, or the
    provider's own pricing (live with a static fallback) for lambda/hyperstack. Ties are broken
    by smaller VRAM, then by name.

    No pin; every fitting class is eligible, validated or not. NOTE this is intentionally
    gate-free: the submit-time allocator restricts to the validated pool, so the
    actually-provisioned class can be pricier than the one priced here.

    ``provider`` is normalized the same way ``RunConfig`` normalizes it: case- and
    whitespace-insensitive, with ``None`` or an empty string meaning ``"auto"``. Any value other
    than ``"auto"`` restricts candidates to the classes ``providers_for`` says that provider can
    provision.

    Raises:
        ValueError: no class fits (on ``provider``, when one is requested).
    """
    # Normalize up front. Otherwise e.g. "Lambda" or "" would match nothing in providers_for()
    # and surface as a misleading "no GPU class fits" error.
    p = (provider or "auto").strip().lower() or "auto"
    restricted = p != "auto"

    def _selectable(g: GpuClass) -> bool:
        return not restricted or p in providers_for(g.name)

    candidates = [g for g in GPU_INFO.values() if g.vram_gb >= required_vram_gb and _selectable(g)]
    if not candidates:
        where = f" on provider {p!r}" if restricted else ""
        raise ValueError(f"no GPU class fits >= {required_vram_gb} GB{where}")
    # Rank by the rate on the REQUESTED provider so a provider-specific quote picks that
    # provider's cheapest fit (not the cheapest by the RunPod nominal rate).
    best = min(candidates, key=lambda g: (gpu_hourly_usd(g.name, provider=p), g.vram_gb, g.name))
    return best.name
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ===== Model-size facts (catalog-only; five dense text models, no MoE/open-model sizing) =====
|
|
82
|
+
def total_params_b(model_id: str) -> float:
    """Total parameter count (billions) of a catalog model, i.e. its curated ``params_b`` stat.

    Raises:
        ValueError: ``model_id`` is not in the catalog (the message lists the supported ids).
    """
    entry = MODELS.get(model_id)
    if entry is not None:
        return entry.params_b
    supported = ", ".join(MODELS)
    raise ValueError(
        f"unknown model {model_id!r}; cost estimation supports catalog models only "
        f"({supported})"
    )
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def model_quant(model_id: str) -> str:
    """Quantization of a catalog entry.

    Falls back to ``"bf16"`` for an unknown model or an entry with no ``quant`` set.
    """
    entry = MODELS.get(model_id)
    if entry is None:
        return "bf16"
    return entry.quant or "bf16"
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def download_weight_gb(model_id: str) -> float:
    """GB fetched from the HF hub at cold start, assuming a full bf16 checkpoint (2 bytes/param).

    Raises ``ValueError`` (via ``total_params_b``) for a model outside the catalog.
    """
    bytes_per_param = 2.0  # bf16
    return bytes_per_param * total_params_b(model_id)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ===== Reward-grader latency (GRPO) =====
# One average grader latency (seconds per completion) shared by every env. Real graders range
# from ~0.01s (regex/math checks) to ~3s (LLM judge / code); 1s is a middle-of-the-road default
# that a run can override via reward_seconds_per_completion(override=...).
AVG_REWARD_SECONDS_PER_COMPLETION = 1.0
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def reward_seconds_per_completion(override: float | None = None) -> float:
|
|
111
|
+
"""Per-completion reward latency (s): the explicit override, else the single average."""
|
|
112
|
+
if override is not None:
|
|
113
|
+
return max(0.0, override)
|
|
114
|
+
return AVG_REWARD_SECONDS_PER_COMPLETION
|
flash/cost/spec.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Map a parsed training ``JobSpec`` to a cost ``RunConfig`` / step count / estimate.
|
|
2
|
+
|
|
3
|
+
Used by ``flash train --cost`` for a pre-flight quote. The control plane bills completed runs
|
|
4
|
+
from their final recorded ``cost_usd`` instead of charging this estimate at submit time."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from flash.cost.analytical import estimate_cost
|
|
9
|
+
from flash.cost.types import CostEstimate, RunConfig
|
|
10
|
+
|
|
11
|
+
# Fallback SFT dataset size, used when an uncapped run's environment can't be counted locally.
# Most Freesolo training datasets have rows in the low thousands, so this representative middle
# value keeps the quote in the right ballpark instead of hard-failing.
DEFAULT_UNCOUNTED_SFT_EXAMPLES = 1_000
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def count_env_examples(env_id: str, params: dict | None = None) -> int | None:
|
|
19
|
+
"""Training rows in ``env_id``'s dataset (the worker's train split), or ``None`` if it can't
|
|
20
|
+
be loaded. Best-effort -- prices an uncapped SFT run on the real dataset size, not a guess.
|
|
21
|
+
|
|
22
|
+
Loading may need network access for managed Freesolo environments. If the environment
|
|
23
|
+
cannot be loaded in this interpreter, this returns ``None`` and the caller falls back to a
|
|
24
|
+
default count instead of hard-failing."""
|
|
25
|
+
if not env_id:
|
|
26
|
+
return None
|
|
27
|
+
try:
|
|
28
|
+
from flash.envs import load_environment
|
|
29
|
+
|
|
30
|
+
rows = load_environment(env_id, params or {}).dataset()
|
|
31
|
+
except Exception:
|
|
32
|
+
return None
|
|
33
|
+
return len(rows) if rows is not None else None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def spec_steps(spec) -> int:
    """Per-seed optimizer steps implied by a train spec, mirroring the worker.

    GRPO: ``train.steps`` floored at 1, or the recipe's ``rl.num_steps`` when unset.

    SFT: ``epochs x ceil(num_examples / realized_batch)``, floored at 1 and capped by a positive
    ``max_steps``. ``num_examples`` is ``max_examples`` when pinned (> 0). Otherwise it is the real
    env size (``count_env_examples``), or ``DEFAULT_UNCOUNTED_SFT_EXAMPLES`` when that can't be
    counted.
    """
    from flash.catalog import vocab_size_for
    from flash.engine.recipe import RECIPE
    from flash.engine.vram import resolve_params_b, sft_logits_fused, sft_realized_batch

    train = spec.train
    if spec.algorithm == "grpo":
        return RECIPE.rl.num_steps if train.steps is None else max(1, int(train.steps))

    # --- SFT ---
    step_cap = int(train.max_steps) if train.max_steps else 0  # 0 => uncapped
    n_epochs = RECIPE.sft.num_epochs if train.epochs is None else int(train.epochs)
    requested = RECIPE.sft.effective_batch if train.batch_size is None else int(train.batch_size)

    if train.max_length is not None:
        seq_len = int(train.max_length)
    elif spec.thinking:
        seq_len = RECIPE.sft.max_seq_len_thinking
    else:
        seq_len = RECIPE.sft.max_seq_len

    # Reproduce the worker's per-device micro-batch exactly, including the big-vocab logits cap.
    # With fused CE off, the worker sizes the micro-batch by vocab. Combined with ceil'd
    # grad-accum, that can move the realized global batch and therefore the step count.
    # params_b comes from the same resolver run_sft uses (catalog stat, else HF safetensors).
    # The fused-CE decision hinges on the >=3B threshold, so an uncataloged >=3B model must not
    # be priced as <3B. Without network access the resolver returns None, which gives the prior
    # <3B (cap-on) behavior.
    fused = sft_logits_fused(resolve_params_b(spec.model), seq_len)
    realized = sft_realized_batch(
        requested, seq_len=seq_len, vocab=vocab_size_for(spec.model), fused=fused
    )

    # max_examples is a CAP. Like None, 0 means "train the full dataset", so it must not price a
    # single step.
    pinned = int(train.max_examples) if train.max_examples else 0
    if pinned > 0:
        examples = pinned
    else:
        # A managed environment may be unreachable here; fall back to the representative default.
        counted = count_env_examples(spec.environment.id, spec.environment.params)
        examples = DEFAULT_UNCOUNTED_SFT_EXAMPLES if counted is None else counted

    batches_per_epoch = -(-examples // realized)  # ceil division
    steps = max(1, batches_per_epoch * n_epochs)
    return steps if step_cap <= 0 else min(steps, step_cap)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def runconfig_from_spec(spec) -> RunConfig:
    """Translate a parsed ``JobSpec`` into the cost model's ``RunConfig``.

    Every seed runs as its own job and pays its own cold start. Both the step total and
    ``setup_repeats`` are therefore multiplied by the seed count; an empty or missing seed list
    counts as one seed. No GPU is pinned: the estimator does its own cheapest-fit
    (``provider="auto"``). The GRPO-only knobs (``completion_len``, ``group_size``) are ``None``
    for SFT.
    """
    train, gpu = spec.train, spec.gpu
    grpo = spec.algorithm == "grpo"
    n_seeds = max(1, len(train.seeds or (0,)))
    return RunConfig(
        model_id=spec.model,
        method=spec.algorithm,
        steps=n_seeds * spec_steps(spec),
        setup_repeats=n_seeds,
        seq_len=train.max_length,
        completion_len=train.max_tokens if grpo else None,
        batch_size=train.batch_size,
        group_size=train.group_size if grpo else None,
        lora_rank=train.lora_rank,
        thinking=spec.thinking,
        provider="auto",
        max_wall_seconds=gpu.max_wall_seconds,
        environment=spec.environment.id or None,
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def estimate_for_spec(spec) -> CostEstimate:
    """Pre-flight ``CostEstimate`` for a parsed training ``JobSpec`` (mapped via ``runconfig_from_spec``)."""
    config = runconfig_from_spec(spec)
    return estimate_cost(config)
|
flash/cost/types.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""The estimator's I/O types: ``RunConfig`` (input) and ``CostEstimate`` (result)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
|
|
7
|
+
from flash.catalog import normalize_algorithm
|
|
8
|
+
from flash.engine.recipe import RECIPE
|
|
9
|
+
from flash.providers import PROVIDER_NAMES
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class RunConfig:
    """A single training run to price.

    Knobs left as ``None`` are filled from the recipe by ``normalized``. ``__post_init__``
    canonicalizes ``method``/``provider`` and rejects impossible values with ``ValueError``.
    """

    model_id: str
    method: str  # "sft" | "grpo"
    steps: int  # total steps over all seeds; must split evenly across setup_repeats

    # Number of cold-start setups the bill covers. A multi-seed run reprovisions (and pays boot
    # again) for each seed, so this equals the seed count.
    setup_repeats: int = 1

    # Engine context length (forwarded as [train].max_length -- NOT the prompt length). If unset,
    # the GRPO default mirrors the worker's max(1024, max_prompt_len + completion); see normalized().
    seq_len: int | None = None
    completion_len: int | None = None  # GRPO only (max_tokens)
    batch_size: int | None = None  # SFT effective batch / GRPO prompts_per_step
    group_size: int | None = None  # GRPO completions per prompt (G)
    lora_rank: int | None = None
    thinking: bool = False
    # GRPO only: seconds to grade one completion; None -> the single average grader latency.
    reward_seconds_per_completion: float | None = None

    max_wall_seconds: int | None = None  # per-seed wall cap (spec gpu.max_wall_seconds); None = 24h
    provider: str = "auto"
    environment: str | None = None  # Freesolo environment id; descriptive only

    def __post_init__(self) -> None:
        object.__setattr__(self, "method", normalize_algorithm(self.method))

        # Normalize the way the allocator does (case/whitespace; empty -> "auto"). An unknown
        # substrate is rejected here; otherwise it would filter out every GPU later and surface
        # as a confusing "no GPU fits".
        canonical = (self.provider or "auto").strip().lower() or "auto"
        allowed = ("auto", *PROVIDER_NAMES)
        if canonical not in allowed:
            raise ValueError(f"unknown provider {self.provider!r} (auto, {', '.join(PROVIDER_NAMES)})")
        object.__setattr__(self, "provider", canonical)

        if self.steps < 1:
            raise ValueError(f"steps must be >= 1, got {self.steps}")
        if self.setup_repeats < 1:
            raise ValueError(f"setup_repeats must be >= 1, got {self.setup_repeats}")
        # Steps are divided evenly among seeds. A remainder would mean pricing a fractional
        # per-seed step count, which no real run can do.
        if self.steps % self.setup_repeats:
            raise ValueError(
                f"steps ({self.steps}) must be a multiple of setup_repeats ({self.setup_repeats})"
            )
        # Positive-only knobs: 0 or a negative value would yield a bogus quote. max_wall_seconds
        # is deliberately absent. The runner floors it to max(60, ...) and estimate_cost mirrors
        # that, so a non-positive cap is accepted (floored to 60s) rather than rejected.
        for knob in ("seq_len", "batch_size", "group_size", "completion_len", "lora_rank"):
            value = getattr(self, knob)
            if value is not None and value < 1:
                raise ValueError(f"{knob} must be >= 1, got {value}")

    @property
    def is_grpo(self) -> bool:
        """True for a GRPO run (``method`` is already normalized)."""
        return self.method == "grpo"

    def normalized(self) -> RunConfig:
        """Return a copy with every ``None`` knob replaced by this method's recipe default."""
        rank = RECIPE.lora.rank if self.lora_rank is None else self.lora_rank

        if not self.is_grpo:
            if self.seq_len is not None:
                seq = self.seq_len
            else:
                seq = RECIPE.sft.max_seq_len_thinking if self.thinking else RECIPE.sft.max_seq_len
            batch = RECIPE.sft.effective_batch if self.batch_size is None else self.batch_size
            return replace(
                self, seq_len=seq, completion_len=None, batch_size=batch, group_size=None, lora_rank=rank
            )

        comp = self.completion_len
        if comp is None:
            comp = RECIPE.rl.max_completion_len_thinking if self.thinking else RECIPE.rl.max_completion_len
        # An explicit pin wins. Otherwise mirror the allocator's GRPO sizing of an unset
        # max_length: max(1024, max_prompt_len + completion). Bare max_prompt_len under-sizes.
        if self.seq_len is not None:
            seq = self.seq_len
        else:
            seq = max(1024, RECIPE.rl.max_prompt_len + int(comp))
        batch = RECIPE.rl.prompts_per_step if self.batch_size is None else self.batch_size
        group = RECIPE.rl.group_size if self.group_size is None else self.group_size
        return replace(
            self, seq_len=seq, completion_len=comp, batch_size=batch, group_size=group, lora_rank=rank
        )

    def train_knobs(self) -> dict[str, int]:
        """Knob dict consumed by ``model_required_vram_gb``.

        Only an EXPLICIT batch_size is forwarded. An omitted SFT batch then sizes as the worker's
        micro-batch (4) rather than the recipe's effective batch (32), which would over-provision.
        """
        full = self.normalized()
        knobs: dict[str, int] = {"lora_rank": full.lora_rank}
        if self.batch_size is not None:
            knobs["batch_size"] = self.batch_size
        optional = (
            ("max_length", full.seq_len),
            ("max_tokens", full.completion_len),
            ("group_size", full.group_size),
        )
        knobs.update({key: val for key, val in optional if val is not None})
        return knobs
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class CostEstimate:
    """A pre-flight price quote: ``total_usd`` = ``wall_clock_hours * gpu_hourly_usd``, with no multiplier."""

    model_id: str
    method: str
    steps: int
    gpu: str
    provider: str
    gpu_vram_gb: int
    required_vram_gb: int
    gpu_hourly_usd: float
    setup_seconds: float  # cold start: boot + deps + model load (+ vLLM init for GRPO)
    seconds_per_step: float
    train_seconds: float  # steps * seconds_per_step (after the wall-clock cap)
    wall_clock_seconds: float
    wall_capped: bool
    total_usd: float
    notes: tuple[str, ...] = ()

    @property
    def wall_clock_hours(self) -> float:
        """``wall_clock_seconds`` expressed in hours."""
        return self.wall_clock_seconds / 3600.0

    def breakdown(self) -> str:
        """Itemized, newline-separated breakdown for CLI output (one line per note appended)."""
        setup_detail = "boot + deps + model load"
        if self.method == "grpo":
            setup_detail += " + vLLM init"
        train_suffix = " [CAPPED at the wall-clock limit]" if self.wall_capped else ""
        rows = [
            f"Run : {self.model_id} [{self.method.upper()}, {self.steps} steps]",
            (
                f"GPU : {self.gpu} on {self.provider} ({self.gpu_vram_gb} GB; "
                f"run needs >= {self.required_vram_gb} GB) @ ${self.gpu_hourly_usd:.2f}/hr"
            ),
            f"Setup : {self.setup_seconds / 60:.1f} min (cold start: {setup_detail})",
            f"Per step : {self.seconds_per_step:.2f} s",
            f"Train : {self.train_seconds / 60:.1f} min{train_suffix}",
            f"Wall clock : {self.wall_clock_hours:.2f} h",
            f"TOTAL : ${self.total_usd:.2f}",
        ]
        if self.notes:
            rows += ["Notes :", *(f" - {note}" for note in self.notes)]
        return "\n".join(rows)
|
flash/engine/accounting.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Cost accounting + the standard run-metrics record for Flash runs.
|
|
2
|
+
|
|
3
|
+
GPU cost = gpu_hours * hourly_rate (per-second billing on RunPod; artifacts go via HF).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from dataclasses import asdict, dataclass, field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class RunMetrics:
    """Standard metrics record written per phase/seed.

    Field order here is the key order of the JSON produced by ``to_json``.
    """

    arm: str = "runpod"  # compute substrate
    phase: str = ""  # "sft" | "rl"
    seed: int = 0
    model_id: str = ""
    # Speed
    wall_seconds: float = 0.0
    setup_seconds: float = 0.0  # cold start / provisioning + model load
    train_throughput_toks_per_s: float = 0.0
    # Token accounting
    train_tokens: int = 0
    generated_tokens: int = 0  # RL: total sampled completion tokens
    # Misc / friction. cost_usd is computed/stamped downstream by the runner from the
    # provider's $/hr (see runner._persist_metrics), not by the worker.
    notes: dict = field(default_factory=dict)

    def to_json(self) -> str:
        """All fields as a 2-space-indented JSON object.

        Raises:
            TypeError: ``notes`` holds a value ``json`` cannot encode.
        """
        return json.dumps(asdict(self), indent=2)

    def save(self, path: str) -> None:
        """Write ``to_json()`` to ``path``, replacing any existing file.

        Serialization happens BEFORE the file is opened. ``open(path, "w")`` truncates
        immediately, so serializing inside the ``with`` would leave an empty file behind
        (destroying a previously saved record) whenever ``notes`` can't be JSON-encoded.
        """
        payload = self.to_json()
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Optional chalk GPU kernels (the ``freesolo-chalk`` package).
|
|
2
|
+
|
|
3
|
+
Chalk holds Freesolo's hand-written Triton/CUDA kernels that complement Liger: the RoPE the
|
|
4
|
+
qwen3.5 hybrid arch needs (Liger refuses it), the LoRA-delta matmul, fused MLP, the QKV
|
|
5
|
+
norm+RoPE attention epilogue, embedding gather, and FP8 frozen-base GEMMs.
|
|
6
|
+
|
|
7
|
+
Chalk ships a Liger-style one-call entry point, ``apply_chalk_kernel_to_qwen35(model, ...)``,
|
|
8
|
+
mirroring ``apply_liger_kernel_to_qwen3``: enablement is the call itself (no env flag), each kernel
|
|
9
|
+
is a boolean keyword, and it NEVER raises on a kernel failure (every installer self-tests +
|
|
10
|
+
arch-gates and falls back to the eager/Liger path; a no-op off-GPU). flash applies it
|
|
11
|
+
AUTOMATICALLY — like Liger — after the trainer builds the model, with the **gap-filling** kernels
|
|
12
|
+
Liger leaves on the eager path ON BY DEFAULT: RoPE, the LoRA-delta matmul, and embedding gather.
|
|
13
|
+
The kernels that OVERLAP Liger (fused MLP / SwiGLU — Liger owns MLP) or are situational (the
|
|
14
|
+
eval-only QKV epilogue, the Hopper-only FP8 frozen base) stay OPT-IN.
|
|
15
|
+
|
|
16
|
+
Liger is applied by TRL (``use_liger_kernel``); chalk composes ON TOP of the live Liger modules,
|
|
17
|
+
so flash calls chalk with ``liger=False``. Kernel selection is FIXED (deterministic): the
|
|
18
|
+
gap-fillers run and the overlapping/situational kernels stay off — there is no env override. If
|
|
19
|
+
``freesolo-chalk`` isn't installed (no ``FLASH_CHALK_SPEC``, or on the control plane) the whole
|
|
20
|
+
module degrades to a no-op.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from collections.abc import Mapping
|
|
26
|
+
|
|
27
|
+
from flash._logging import get_logger
|
|
28
|
+
|
|
29
|
+
log = get_logger(__name__)
|
|
30
|
+
|
|
31
|
+
# Chalk kernel table: (apply_chalk_kernel_to_qwen35 keyword, enabled), in the order passed to
# chalk. Selection is FIXED. There is no env override, and these values are exactly what runs
# on every supported run.
#
# ON -- the GAP-FILLERS that complement Liger, applied automatically like apply_liger_kernel.
# Always-applying them is safe because each chalk installer self-tests on install and falls back
# to the eager/Liger path on any failure:
#   * rope             -- the RoPE Liger REFUSES on the qwen3.5 hybrid arch (its only real gap)
#   * fused_lora_delta -- the LoRA-delta matmul on the trainable path (Liger doesn't touch adapters)
#   * fused_embedding  -- the embedding gather (Liger doesn't touch it)
#
# OFF -- the OVERLAPPING / situational kernels:
#   * fused_mlp        -- overlaps Liger's SwiGLU (Liger owns MLP)
#   * attn_epilogue    -- eval-only; needs q/k/v out of LORA_TARGETS
#   * fp8_frozen_base  -- Hopper sm_90+ only
# Each keyword is exactly chalk's apply_chalk_kernel_to_qwen35 kwarg.
_KERNELS: list[tuple[str, bool]] = [
    (kernel, True) for kernel in ("rope", "fused_lora_delta", "fused_embedding")
] + [
    (kernel, False) for kernel in ("fused_mlp", "attn_epilogue", "fp8_frozen_base")
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _enabled_kwargs() -> dict[str, bool]:
    """Keyword -> on/off map for ``apply_chalk_kernel_to_qwen35``, built from ``_KERNELS``."""
    return {kernel: enabled for kernel, enabled in _KERNELS}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def active_kernels(report: Mapping[str, object] | None) -> list[str]:
|
|
58
|
+
"""The chalk kernels that actually ENGAGED (truthy, non-error result) in an apply report.
|
|
59
|
+
|
|
60
|
+
For a metrics note recording which kernels ran (so chalk engagement is verifiable without the
|
|
61
|
+
console). Excludes ``liger`` (TRL applies Liger; chalk's report carries it as False here).
|
|
62
|
+
"""
|
|
63
|
+
return sorted(
|
|
64
|
+
k
|
|
65
|
+
for k, v in (report or {}).items()
|
|
66
|
+
if k != "liger" and v not in (False, None) and not (isinstance(v, dict) and "error" in v)
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def install_chalk_kernels(model=None) -> dict:
    """Apply chalk's gap-filling kernels to ``model`` -- ON by default (like Liger).

    This uses chalk's Liger-style entry point ``apply_chalk_kernel_to_qwen35(model, liger=False, ...)``.
    Liger is already applied by TRL (``use_liger_kernel``), so chalk composes on top of the live
    Liger modules. Each kernel is a fixed boolean: gap-fillers on, the rest off.

    chalk's apply patches the LIVE module, so the worker calls this AFTER TRL builds the trainer
    (``model=trainer.model``). ``model is None`` is a safe no-op kept for defensive callers.

    Returns:
        chalk's per-kernel report. It is ``{}`` when there is no model yet, when freesolo-chalk
        can't be imported, when the apply call raises, or when chalk returns an unusable
        (non-mapping) report. This hook never raises into training.
    """
    if model is None:
        # chalk's apply patches the materialized module -> nothing to do before the model is built.
        return {}

    kwargs = _enabled_kwargs()
    try:
        from chalk.transformers import apply_chalk_kernel_to_qwen35
    except ImportError as e:
        if isinstance(e, ModuleNotFoundError) and e.name == "chalk":
            # chalk is installed by default (PyPI; chalk_extra_pip), so this only fires if an
            # install was disabled or failed. Always safe: the kernels degrade to the eager/Liger
            # path. Only the post-build call reaches this import (the pre-build pass returns
            # early), so it logs at most once per run -- no per-process dedup needed.
            log.info(
                "freesolo-chalk is not installed on this worker (set FLASH_CHALK_SPEC to an installable "
                "spec, or check the default PyPI install); chalk kernels off, using eager/Liger."
            )
        else:
            # The chalk package itself was found, but a submodule, name or dependency is missing
            # (version skew / partial install). Don't misreport that as "not installed".
            log.warning("chalk import failed (ignored, kernels disabled): %s", e)
        return {}
    except Exception as e:
        # A partially-installed / version-incompatible chalk can raise non-ImportError errors at
        # import time (e.g. a Triton/torch mismatch). This hook must never abort training.
        log.warning("chalk import failed (ignored, kernels disabled): %s", e)
        return {}

    try:
        # liger=False: TRL already applied Liger (use_liger_kernel); chalk composes on the live
        # Liger modules. apply_chalk_kernel_to_qwen35 never raises on a per-kernel failure, but
        # guard the call itself so a chalk API/version skew can never abort training.
        report = apply_chalk_kernel_to_qwen35(model, liger=False, **kwargs)
    except Exception as e:  # never block training on the optional kernel stack
        log.warning("chalk apply failed (ignored, kernels disabled): %s", e)
        return {}

    if report is not None and not isinstance(report, Mapping):
        # An API-skewed chalk could return a non-mapping report. active_kernels() below would then
        # raise OUTSIDE the guarded call above and abort training.
        log.warning(
            "chalk apply returned an unexpected report type %s (ignored)", type(report).__name__
        )
        return {}

    active = active_kernels(report)
    if active:
        log.info("chalk kernels active: %s", ", ".join(active))
    return report or {}
|