freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
flash/cost/facts.py ADDED
@@ -0,0 +1,114 @@
1
+ """Static lookup facts for the cost model: GPU price/VRAM/compute + cheapest-fit
2
+ selection, model size/quant, and reward-grader latency. Pure tables + accessors."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from flash.catalog import MODELS
7
+ from flash.providers.base import GPU_INFO, GpuClass, providers_for
8
+
9
# ===== GPU facts =====
# Fallback throughput used for any GPU class that has no entry in the table below.
_DEFAULT_TFLOPS = 100.0

# Peak bf16 tensor TFLOPS keyed by managed GPU class name.
GPU_COMPUTE_TFLOPS: dict[str, float] = {
    "L4": 60.0,
    "RTX 4090": 165.0,
    "RTX 5090": 210.0,
    "RTX A6000": 155.0,
    "A40": 150.0,
    "RTX 6000 Ada": 182.0,
    "A100 PCIe": 312.0,
    "A100 SXM": 312.0,
    "H100": 990.0,
    "RTX Pro 6000": 250.0,
}


def gpu_tflops(name: str) -> float:
    """Return the peak bf16 tensor TFLOPS for a managed GPU class.

    Args:
        name: GPU class name, matched exactly against ``GPU_COMPUTE_TFLOPS``.

    Returns:
        The table value, or ``_DEFAULT_TFLOPS`` when the class is not listed.
    """
    try:
        return GPU_COMPUTE_TFLOPS[name]
    except KeyError:
        # Unlisted class: use the generic default rather than failing the estimate.
        return _DEFAULT_TFLOPS
28
+
29
+
30
def gpu_hourly_usd(name: str, provider: str | None = None) -> float:
    """Return a representative $/hr for GPU class ``name``, on ``provider`` when given.

    The nominal ``GpuClass.hourly_usd`` is the RunPod rate, which is wrong for a
    provider-specific quote (e.g. a Lambda RTX A6000 is $1.09/hr, not RunPod's $0.49).
    When ``provider`` is ``lambda`` or ``hyperstack`` and the class is offered there, the
    price comes from that provider's pricing module (live with a static fallback).
    Every other provider value (runpod/auto/None) yields the nominal rate.

    Args:
        name: GPU class name; must be a key of ``GPU_INFO``.
        provider: Optional provider name; compared case-insensitively after stripping.

    Returns:
        The hourly USD rate.

    Raises:
        KeyError: If ``name`` is not a known GPU class.
    """
    gpu = GPU_INFO.get(name)
    if gpu is None:
        raise KeyError(f"unknown GPU class {name!r}")

    wanted = (provider or "").strip().lower()
    if wanted == "lambda":
        if gpu.lambda_name:
            # Imported lazily so the provider pricing module loads only when it is needed.
            from flash.providers.lambdalabs.pricing import hourly_rate as lambda_rate

            return lambda_rate(name)
    elif wanted == "hyperstack":
        if gpu.hyperstack_name:
            from flash.providers.hyperstack.pricing import hourly_rate as hyperstack_rate

            return hyperstack_rate(name)

    # No provider given, or the class is not offered on the requested provider.
    return gpu.hourly_usd
51
+
52
+
53
def gpu_vram_gb(name: str) -> int:
    """Return the VRAM size in GB recorded in ``GPU_INFO`` for GPU class ``name``.

    Raises:
        KeyError: If ``name`` is not a known GPU class.
    """
    if (gpu := GPU_INFO.get(name)) is None:
        raise KeyError(f"unknown GPU class {name!r}")
    return gpu.vram_gb
58
+
59
+
60
def pick_gpu(required_vram_gb: int, *, provider: str | None = None) -> str:
    """Return the cheapest GPU class that fits ``required_vram_gb``, ranked by $/hr.

    No pin; every fitting class is eligible, validated or not. NOTE this is intentionally
    gate-free: the submit-time allocator restricts to the validated pool, so the
    actually-provisioned class can be pricier than the one priced here.

    Args:
        required_vram_gb: Minimum VRAM (GB) a candidate class must offer.
        provider: Restricts candidates to what that provider can provision. ``None``,
            ``""`` and ``"auto"`` mean "no restriction". The value is normalized
            (stripped, lower-cased) exactly like ``gpu_hourly_usd`` does, so ``"Lambda"``
            and ``" lambda "`` select the same pool as ``"lambda"``.

    Returns:
        The name of the cheapest fitting class; ties break on smaller VRAM, then name.

    Raises:
        ValueError: If no class satisfies the VRAM requirement and provider restriction.
    """
    # Normalize once so the filter and the price ranking agree on what was requested.
    # Filtering on the raw string would make a case/whitespace variant match no provider
    # and surface as a misleading "no GPU class fits" error.
    wanted = (provider or "auto").strip().lower() or "auto"
    unrestricted = wanted == "auto"

    candidates = [
        g
        for g in GPU_INFO.values()
        if g.vram_gb >= required_vram_gb and (unrestricted or wanted in providers_for(g.name))
    ]
    if not candidates:
        raise ValueError(f"no GPU class fits >= {required_vram_gb} GB")

    # Rank by the rate on the REQUESTED provider so a provider-specific quote picks that
    # provider's cheapest fit (not the cheapest by the RunPod nominal rate). "auto" prices
    # at the nominal rate, the same as passing no provider.
    best = min(
        candidates,
        key=lambda g: (gpu_hourly_usd(g.name, provider=wanted), g.vram_gb, g.name),
    )
    return best.name
79
+
80
+
81
+ # ===== Model-size facts (catalog-only; five dense text models, no MoE/open-model sizing) =====
82
def total_params_b(model_id: str) -> float:
    """Return the total parameter count (billions) of a catalog model.

    This is the curated ``params_b`` stat of the ``MODELS`` entry.

    Raises:
        ValueError: If ``model_id`` is not in the catalog; the message lists the
            catalog's model ids.
    """
    entry = MODELS.get(model_id)
    if entry is not None:
        return entry.params_b

    known = ", ".join(MODELS)
    raise ValueError(
        f"unknown model {model_id!r}; cost estimation supports catalog models only ({known})"
    )
91
+
92
+
93
def model_quant(model_id: str) -> str:
    """Return the quantization of a catalog entry, defaulting to ``"bf16"``.

    ``"bf16"`` is returned both for a model missing from ``MODELS`` and for an entry
    whose ``quant`` is empty/unset.
    """
    entry = MODELS.get(model_id)
    if entry is None:
        return "bf16"
    return entry.quant or "bf16"
97
+
98
+
99
def download_weight_gb(model_id: str) -> float:
    """Return the GB pulled from the HF hub at cold start.

    Sized as a full bf16 checkpoint: 2 bytes per parameter, so GB = billions of
    parameters x 2. Propagates ``ValueError`` from ``total_params_b`` for an unknown model.
    """
    bytes_per_param = 2.0
    return total_params_b(model_id) * bytes_per_param
102
+
103
+
104
# ===== Reward-grader latency (GRPO) =====
# One average grader latency (seconds per completion) shared by every env. Graders range
# from ~0.01s (regex/math) to ~3s (LLM judge/code), so ~1s is a middle-of-the-road default
# that a run can override.
AVG_REWARD_SECONDS_PER_COMPLETION = 1.0


def reward_seconds_per_completion(override: float | None = None) -> float:
    """Return the per-completion reward latency in seconds.

    Args:
        override: Explicit latency for this run. Negative values are clamped to 0.

    Returns:
        The clamped override when one is given, else ``AVG_REWARD_SECONDS_PER_COMPLETION``.
    """
    if override is None:
        return AVG_REWARD_SECONDS_PER_COMPLETION
    return max(0.0, override)
flash/cost/spec.py ADDED
@@ -0,0 +1,113 @@
1
+ """Map a parsed training ``JobSpec`` to a cost ``RunConfig`` / step count / estimate.
2
+
3
+ Used by ``flash train --cost`` for a pre-flight quote. The control plane bills completed runs
4
+ from their final recorded ``cost_usd`` instead of charging this estimate at submit time."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from flash.cost.analytical import estimate_cost
9
+ from flash.cost.types import CostEstimate, RunConfig
10
+
11
# Fallback SFT dataset size for an uncapped run whose env can't be counted locally.
# Most Freesolo training datasets land in the low-thousands of rows, so this is a
# representative middle estimate that keeps the quote in the right ballpark rather than
# hard-failing.
DEFAULT_UNCOUNTED_SFT_EXAMPLES = 1_000
16
+
17
+
18
def count_env_examples(env_id: str, params: dict | None = None) -> int | None:
    """Count the training rows in ``env_id``'s dataset (the worker's train split).

    Best-effort: this prices an uncapped SFT run on the real dataset size instead of a
    guess. Loading may need network access for managed Freesolo environments; if the
    environment cannot be loaded or counted in this interpreter the function returns
    ``None`` and the caller falls back to a default count instead of hard-failing.

    Args:
        env_id: Environment id to load. An empty id short-circuits to ``None``.
        params: Environment parameters; ``None`` is passed on as an empty dict.

    Returns:
        The number of rows, or ``None`` when the id is empty, the environment cannot be
        loaded, the dataset is ``None``, or the dataset cannot be sized.
    """
    if not env_id:
        return None
    try:
        from flash.envs import load_environment

        rows = load_environment(env_id, params or {}).dataset()
        # Count inside the guard: a dataset without ``__len__`` (e.g. an iterator) raises
        # TypeError, which must degrade to None like every other load failure.
        return None if rows is None else len(rows)
    except Exception:
        # Deliberate best-effort boundary: any failure means "could not count".
        return None
34
+
35
+
36
def spec_steps(spec) -> int:
    """Return the per-seed optimizer steps implied by a train spec (mirrors the worker).

    GRPO: ``train.steps`` (floored at 1) when set, else the recipe default.
    SFT: ``epochs x ceil(num_examples / realized_batch)``, capped by ``max_steps`` when
    that is set, where ``num_examples`` is ``max_examples`` if pinned, else the real env
    size, else ``DEFAULT_UNCOUNTED_SFT_EXAMPLES`` when the env cannot be counted.

    Args:
        spec: A parsed training ``JobSpec``. Reads ``algorithm``, ``model``, ``thinking``,
            ``train`` and ``environment``.

    Returns:
        The step count, always >= 1.
    """
    from flash.catalog import vocab_size_for
    from flash.engine.recipe import RECIPE
    from flash.engine.vram import resolve_params_b, sft_logits_fused, sft_realized_batch

    train = spec.train

    if spec.algorithm == "grpo":
        return RECIPE.rl.num_steps if train.steps is None else max(1, int(train.steps))

    # --- SFT ---
    # SFT-only optimizer-step cap; 0 means uncapped.
    step_cap = int(train.max_steps) if train.max_steps else 0

    if train.epochs is None:
        num_epochs = RECIPE.sft.num_epochs
    else:
        num_epochs = int(train.epochs)

    if train.batch_size is None:
        wanted_batch = RECIPE.sft.effective_batch
    else:
        wanted_batch = int(train.batch_size)

    # Mirror the worker's per-device micro-batch exactly, including the big-vocab logits
    # cap: with the fused CE off the worker vocab-sizes the micro-batch, which (with
    # ceil'd grad-accum) can change the realized global batch and thus the step count.
    # Feeding the same seq/vocab/fused keeps the priced step count equal to what runs.
    if train.max_length is not None:
        seq_len = int(train.max_length)
    elif spec.thinking:
        seq_len = RECIPE.sft.max_seq_len_thinking
    else:
        seq_len = RECIPE.sft.max_seq_len

    # params_b comes from the shared helper (catalog stat, else HF safetensors for an open
    # model), the same resolution the worker's run_sft uses. The fused-CE decision hinges
    # on the >=3B threshold, so an uncataloged >=3B model must not be priced as <3B.
    # Best-effort: no network -> None -> the prior <3B (cap-on) behavior.
    params_b = resolve_params_b(spec.model)
    fused = sft_logits_fused(params_b, seq_len)
    realized_batch = sft_realized_batch(
        wanted_batch,
        seq_len=seq_len,
        vocab=vocab_size_for(spec.model),
        fused=fused,
    )

    # max_examples is a CAP; 0 (like None) means "no cap" because the worker then trains
    # the full dataset, so max_examples=0 must not price a single step.
    pinned = int(train.max_examples) if train.max_examples else 0
    if pinned > 0:
        num_examples = pinned
    else:
        # No cap: price the env's real size when it can be counted. A managed Freesolo
        # environment may be unreachable here, so fall back to a representative default.
        num_examples = count_env_examples(spec.environment.id, spec.environment.params)
        if num_examples is None:
            num_examples = DEFAULT_UNCOUNTED_SFT_EXAMPLES

    # epochs x ceil(examples / realized_batch), floored at one step.
    steps_per_epoch = -(-num_examples // realized_batch)
    total = max(1, steps_per_epoch * num_epochs)
    if step_cap > 0:
        return min(total, step_cap)
    return total
85
+
86
+
87
def runconfig_from_spec(spec) -> RunConfig:
    """Build the cost ``RunConfig`` for a parsed ``JobSpec``.

    Each seed is its own job that re-pays the cold start, so both the total steps and the
    setup repeats scale with the seed count. The estimate does not pin a GPU; it does its
    own cheapest-fit selection (``provider="auto"``).

    Args:
        spec: A parsed training ``JobSpec``. Reads ``train``, ``gpu``, ``model``,
            ``algorithm``, ``thinking`` and ``environment``.

    Returns:
        The ``RunConfig`` to hand to the estimator.
    """
    train, gpu = spec.train, spec.gpu
    grpo = spec.algorithm == "grpo"
    # An unset/empty seed list still runs once.
    seed_count = max(1, len(train.seeds or (0,)))

    fields = {
        "model_id": spec.model,
        "method": spec.algorithm,
        "steps": spec_steps(spec) * seed_count,
        "setup_repeats": seed_count,
        "seq_len": train.max_length,
        # GRPO-only knobs are dropped for SFT.
        "completion_len": train.max_tokens if grpo else None,
        "batch_size": train.batch_size,
        "group_size": train.group_size if grpo else None,
        "lora_rank": train.lora_rank,
        "thinking": spec.thinking,
        "provider": "auto",
        "max_wall_seconds": gpu.max_wall_seconds,
        "environment": spec.environment.id or None,
    }
    return RunConfig(**fields)
109
+
110
+
111
def estimate_for_spec(spec) -> CostEstimate:
    """Return the pre-flight ``CostEstimate`` for a parsed training ``JobSpec``.

    Maps the spec to a ``RunConfig`` and prices it with ``estimate_cost``.
    """
    run_config = runconfig_from_spec(spec)
    return estimate_cost(run_config)
flash/cost/types.py ADDED
@@ -0,0 +1,158 @@
1
+ """The estimator's I/O types: ``RunConfig`` (input) and ``CostEstimate`` (result)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, replace
6
+
7
+ from flash.catalog import normalize_algorithm
8
+ from flash.engine.recipe import RECIPE
9
+ from flash.providers import PROVIDER_NAMES
10
+
11
+
12
@dataclass(frozen=True)
class RunConfig:
    """One training run to price (the estimator's input).

    Knobs left as ``None`` resolve to recipe defaults via :meth:`normalized`. Construction
    normalizes ``method`` and ``provider`` and rejects invalid values with ``ValueError``.
    """

    model_id: str
    method: str  # "sft" | "grpo"
    steps: int

    # Number of cold-start setups the bill covers. A multi-seed run reprovisions (and
    # re-pays boot) per seed, so this is the seed count.
    setup_repeats: int = 1

    # Engine context length (forwarded as [train].max_length, NOT prompt length). When
    # unset, the GRPO default mirrors the worker's max(1024, max_prompt_len + completion);
    # see normalized().
    seq_len: int | None = None
    completion_len: int | None = None  # GRPO only (max_tokens)
    batch_size: int | None = None  # SFT effective batch / GRPO prompts_per_step
    group_size: int | None = None  # GRPO completions per prompt (G)
    lora_rank: int | None = None
    thinking: bool = False
    # GRPO only: seconds to score one completion. None -> the single average grader latency.
    reward_seconds_per_completion: float | None = None

    max_wall_seconds: int | None = None  # per-seed wall cap (spec gpu.max_wall_seconds); None = 24h
    provider: str = "auto"
    environment: str | None = None  # Freesolo environment id; descriptive only

    def __post_init__(self) -> None:
        """Normalize ``method``/``provider`` and validate the numeric knobs.

        Raises:
            ValueError: On an unknown provider, ``steps`` or ``setup_repeats`` below 1,
                ``steps`` not divisible by ``setup_repeats``, or a positive-only knob
                set to a value below 1.
        """
        # The dataclass is frozen, so normalized values are written through object.
        assign = object.__setattr__
        assign(self, "method", normalize_algorithm(self.method))

        # Normalize like the allocator (case/whitespace, empty -> "auto") and reject an
        # unknown substrate up front; otherwise it filters out every candidate and shows
        # up later as a confusing "no GPU fits".
        substrate = (self.provider or "auto").strip().lower() or "auto"
        allowed = ("auto", *PROVIDER_NAMES)
        if substrate not in allowed:
            raise ValueError(f"unknown provider {self.provider!r} (auto, {', '.join(PROVIDER_NAMES)})")
        assign(self, "provider", substrate)

        if self.steps < 1:
            raise ValueError(f"steps must be >= 1, got {self.steps}")
        if self.setup_repeats < 1:
            raise ValueError(f"setup_repeats must be >= 1, got {self.setup_repeats}")

        # Steps are split evenly across seeds, so a non-divisible split would price
        # fractional steps per seed (impossible in a real run).
        if self.steps % self.setup_repeats:
            raise ValueError(
                f"steps ({self.steps}) must be a multiple of setup_repeats ({self.setup_repeats})"
            )

        # Reject 0/negative values for positive-only knobs (bogus quote). max_wall_seconds
        # is NOT checked here: the runner floors it to max(60, ...) and estimate_cost
        # mirrors that, so a non-positive cap is accepted (floored to 60s), not rejected.
        positive_only = ("seq_len", "batch_size", "group_size", "completion_len", "lora_rank")
        for knob in positive_only:
            value = getattr(self, knob)
            if value is None:
                continue
            if value < 1:
                raise ValueError(f"{knob} must be >= 1, got {value}")

    @property
    def is_grpo(self) -> bool:
        """Whether the (normalized) method is GRPO."""
        return self.method == "grpo"

    def normalized(self) -> RunConfig:
        """Return a copy with every ``None`` knob filled from the recipe for this method.

        SFT copies always carry ``completion_len=None`` and ``group_size=None``.
        """
        rank = RECIPE.lora.rank if self.lora_rank is None else self.lora_rank

        if not self.is_grpo:
            context = self.seq_len
            if context is None:
                context = RECIPE.sft.max_seq_len_thinking if self.thinking else RECIPE.sft.max_seq_len
            effective = RECIPE.sft.effective_batch if self.batch_size is None else self.batch_size
            return replace(
                self,
                seq_len=context,
                completion_len=None,
                batch_size=effective,
                group_size=None,
                lora_rank=rank,
            )

        completion = self.completion_len
        if completion is None:
            if self.thinking:
                completion = RECIPE.rl.max_completion_len_thinking
            else:
                completion = RECIPE.rl.max_completion_len

        # An explicit pin wins; otherwise mirror the allocator's GRPO sizing of an unset
        # max_length: max(1024, max_prompt_len + completion), not the bare max_prompt_len
        # (which under-sizes).
        if self.seq_len is None:
            context = max(1024, RECIPE.rl.max_prompt_len + int(completion))
        else:
            context = self.seq_len

        prompts = RECIPE.rl.prompts_per_step if self.batch_size is None else self.batch_size
        completions_per_prompt = RECIPE.rl.group_size if self.group_size is None else self.group_size
        return replace(
            self,
            seq_len=context,
            completion_len=completion,
            batch_size=prompts,
            group_size=completions_per_prompt,
            lora_rank=rank,
        )

    def train_knobs(self) -> dict[str, int]:
        """Return the knob dict that ``model_required_vram_gb`` consumes.

        Only an EXPLICIT ``batch_size`` is forwarded. Per the module's own note, an
        omitted SFT batch sizes as the worker's micro-batch (4) rather than the recipe's
        effective batch (32), which would over-provision. All other knobs come from the
        normalized copy.
        """
        filled = self.normalized()
        knobs: dict[str, int] = {"lora_rank": filled.lora_rank}
        if self.batch_size is not None:
            knobs["batch_size"] = self.batch_size

        # Remaining keys are added in a fixed order, and only when they resolved to a value.
        resolved = (
            ("max_length", filled.seq_len),
            ("max_tokens", filled.completion_len),
            ("group_size", filled.group_size),
        )
        for key, value in resolved:
            if value is not None:
                knobs[key] = value
        return knobs
113
+
114
+
115
@dataclass(frozen=True)
class CostEstimate:
    """A pre-flight cost estimate (the estimator's result).

    ``total_usd`` = ``wall_clock_hours * gpu_hourly_usd``, with no multiplier.
    """

    model_id: str
    method: str
    steps: int
    gpu: str
    provider: str
    gpu_vram_gb: int
    required_vram_gb: int
    gpu_hourly_usd: float
    setup_seconds: float  # cold start: boot + deps + model load (+ vLLM init for GRPO)
    seconds_per_step: float
    train_seconds: float  # steps * seconds_per_step (post wall-clock cap)
    wall_clock_seconds: float
    wall_capped: bool
    total_usd: float
    notes: tuple[str, ...] = ()

    @property
    def wall_clock_hours(self) -> float:
        """``wall_clock_seconds`` expressed in hours."""
        return self.wall_clock_seconds / 3600.0

    def breakdown(self) -> str:
        """Return the multi-line itemized breakdown shown in CLI output.

        One line each for run, GPU, setup, per-step time, train time, wall clock and
        total, followed by a ``Notes`` section when ``notes`` is non-empty.
        """
        # Optional suffixes: GRPO cold start also pays vLLM init; a capped run is flagged.
        vllm_part = " + vLLM init" if self.method == "grpo" else ""
        capped_part = " [CAPPED at the wall-clock limit]" if self.wall_capped else ""
        setup_minutes = self.setup_seconds / 60
        train_minutes = self.train_seconds / 60

        rows: list[str] = []
        rows.append(f"Run : {self.model_id} [{self.method.upper()}, {self.steps} steps]")
        rows.append(
            f"GPU : {self.gpu} on {self.provider} "
            f"({self.gpu_vram_gb} GB; run needs >= {self.required_vram_gb} GB) "
            f"@ ${self.gpu_hourly_usd:.2f}/hr"
        )
        rows.append(
            f"Setup : {setup_minutes:.1f} min (cold start: boot + deps + model load"
            f"{vllm_part})"
        )
        rows.append(f"Per step : {self.seconds_per_step:.2f} s")
        rows.append(f"Train : {train_minutes:.1f} min{capped_part}")
        rows.append(f"Wall clock : {self.wall_clock_hours:.2f} h")
        rows.append(f"TOTAL : ${self.total_usd:.2f}")

        if self.notes:
            rows.append("Notes :")
            for note in self.notes:
                rows.append(f" - {note}")
        return "\n".join(rows)
@@ -0,0 +1,6 @@
1
+ """Fine-tuning internals for the Flash package.
2
+
3
+ This subpackage holds the shared recipe, data loaders, graders, run accounting,
4
+ and the on-GPU worker entrypoint. The RunPod provider invokes ``flash.engine.worker``
5
+ on the provisioned GPU.
6
+ """
@@ -0,0 +1,36 @@
1
+ """Cost accounting + the standard run-metrics record for Flash runs.
2
+
3
+ GPU cost = gpu_hours * hourly_rate (per-second billing on RunPod; artifacts go via HF).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from dataclasses import asdict, dataclass, field
10
+
11
+
12
@dataclass
class RunMetrics:
    """The standard metrics record written once per phase/seed."""

    arm: str = "runpod"  # compute substrate
    phase: str = ""  # "sft" | "rl"
    seed: int = 0
    model_id: str = ""

    # --- Speed ---
    wall_seconds: float = 0.0
    setup_seconds: float = 0.0  # cold start / provisioning + model load
    train_throughput_toks_per_s: float = 0.0

    # --- Token accounting ---
    train_tokens: int = 0
    generated_tokens: int = 0  # RL: total sampled completion tokens

    # --- Misc / friction ---
    # cost_usd is computed/stamped downstream by the runner from the provider's $/hr
    # (see runner._persist_metrics), not by the worker.
    notes: dict = field(default_factory=dict)

    def to_json(self) -> str:
        """Serialize every field to a JSON object string indented by two spaces."""
        payload = asdict(self)
        return json.dumps(payload, indent=2)

    def save(self, path: str):
        """Write the JSON form of this record to ``path``, overwriting any existing file."""
        with open(path, "w") as out:
            out.write(self.to_json())
@@ -0,0 +1,116 @@
1
+ """Optional chalk GPU kernels (the ``freesolo-chalk`` package).
2
+
3
+ Chalk holds Freesolo's hand-written Triton/CUDA kernels that complement Liger: the RoPE the
4
+ qwen3.5 hybrid arch needs (Liger refuses it), the LoRA-delta matmul, fused MLP, the QKV
5
+ norm+RoPE attention epilogue, embedding gather, and FP8 frozen-base GEMMs.
6
+
7
+ Chalk ships a Liger-style one-call entry point, ``apply_chalk_kernel_to_qwen35(model, ...)``,
8
+ mirroring ``apply_liger_kernel_to_qwen3``: enablement is the call itself (no env flag), each kernel
9
+ is a boolean keyword, and it NEVER raises on a kernel failure (every installer self-tests +
10
+ arch-gates and falls back to the eager/Liger path; a no-op off-GPU). flash applies it
11
+ AUTOMATICALLY — like Liger — after the trainer builds the model, with the **gap-filling** kernels
12
+ Liger leaves on the eager path ON BY DEFAULT: RoPE, the LoRA-delta matmul, and embedding gather.
13
+ The kernels that OVERLAP Liger (fused MLP / SwiGLU — Liger owns MLP) or are situational (the
14
+ eval-only QKV epilogue, the Hopper-only FP8 frozen base) stay OPT-IN.
15
+
16
+ Liger is applied by TRL (``use_liger_kernel``); chalk composes ON TOP of the live Liger modules,
17
+ so flash calls chalk with ``liger=False``. Kernel selection is FIXED (deterministic): the
18
+ gap-fillers run and the overlapping/situational kernels stay off — there is no env override. If
19
+ ``freesolo-chalk`` isn't installed (no ``FLASH_CHALK_SPEC``, or on the control plane) the whole
20
+ module degrades to a no-op.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from collections.abc import Mapping
26
+
27
+ from flash._logging import get_logger
28
+
29
+ log = get_logger(__name__)
30
+
31
# Chalk kernel table: (apply_chalk_kernel_to_qwen35 keyword, enabled). Selection is FIXED:
# there is no env override, so these values are exactly what runs on every supported run.
#
# ON: the gap-fillers that complement Liger. They are applied automatically, like
# apply_liger_kernel, because each chalk installer self-tests on install and falls back to
# the eager/Liger path on any failure, which makes always applying them safe:
#   * rope             - the RoPE Liger REFUSES on the qwen3.5 hybrid arch (its only real gap)
#   * fused_lora_delta - the LoRA-delta matmul on the trainable path (Liger doesn't touch adapters)
#   * fused_embedding  - the embedding gather (Liger doesn't touch it)
#
# OFF: the overlapping / situational kernels. The fused MLP overlaps Liger's SwiGLU (Liger
# owns MLP), the attn epilogue is eval-only (needs q/k/v out of LORA_TARGETS), and the FP8
# frozen base is Hopper sm_90+ only.
#
# Each keyword is exactly chalk's apply_chalk_kernel_to_qwen35 kwarg.
_KERNELS: list[tuple[str, bool]] = [
    ("rope", True),
    ("fused_lora_delta", True),
    ("fused_embedding", True),
    ("fused_mlp", False),  # off (Liger owns MLP/SwiGLU)
    ("attn_epilogue", False),  # off (eval-only; needs q/k/v out of LoRA)
    ("fp8_frozen_base", False),  # off (Hopper sm_90+ only)
]


def _enabled_kwargs() -> dict[str, bool]:
    """Return the fixed ``apply_chalk_kernel_to_qwen35`` boolean kwargs.

    A fresh dict is built from ``_KERNELS`` on every call (gap-fillers on, the rest off).
    """
    return {keyword: enabled for keyword, enabled in _KERNELS}
55
+
56
+
57
def active_kernels(report: Mapping[str, object] | None) -> list[str]:
    """Return the sorted names of the chalk kernels that ENGAGED in an apply report.

    Meant for a metrics note recording which kernels ran, so chalk engagement is
    verifiable without the console.

    A report entry counts as engaged unless any of the following holds:
      * its key is ``"liger"`` (TRL applies Liger; chalk's report carries it as False here),
      * its value equals ``False`` or ``None``,
      * its value is a dict containing an ``"error"`` key.

    Args:
        report: Chalk's per-kernel report; ``None`` or empty yields ``[]``.
    """
    engaged: list[str] = []
    for kernel, outcome in (report or {}).items():
        if kernel == "liger":
            continue
        if outcome in (False, None):
            continue
        if isinstance(outcome, dict) and "error" in outcome:
            continue
        engaged.append(kernel)
    engaged.sort()
    return engaged
68
+
69
+
70
def install_chalk_kernels(model=None) -> dict:
    """Apply chalk's gap-filling kernels to ``model``; ON by default (like Liger).

    Uses chalk's Liger-style entry point ``apply_chalk_kernel_to_qwen35(model, liger=False, ...)``.
    Liger is already applied by TRL (``use_liger_kernel``), so chalk composes on top of the
    live Liger modules. Each kernel flag is a fixed boolean (gap-fillers on, the rest off).

    chalk's apply patches the LIVE module, so the worker calls this AFTER TRL builds the
    trainer (``model=trainer.model``); ``model is None`` is a safe no-op kept for
    defensive callers.

    Args:
        model: The built model to patch, or ``None``.

    Returns:
        chalk's per-kernel report, or ``{}`` when there is no model yet, freesolo-chalk
        cannot be imported, or the apply call raises. This function logs and swallows
        those failures instead of raising.
    """
    if model is None:
        # chalk patches the materialized module, so there is nothing to do pre-build.
        return {}

    kernel_flags = _enabled_kwargs()

    try:
        from chalk.transformers import apply_chalk_kernel_to_qwen35 as apply_chalk
    except ImportError:
        # chalk is installed by default (PyPI; chalk_extra_pip), so this only fires when an
        # install was disabled or failed. Always safe: the kernels degrade to the
        # eager/Liger path. Only the post-build call reaches this import (the pre-build
        # pass returned early above), so it logs at most once per run and needs no
        # per-process dedup.
        log.info(
            "freesolo-chalk is not installed on this worker (set FLASH_CHALK_SPEC to an installable "
            "spec, or check the default PyPI install); chalk kernels off, using eager/Liger."
        )
        return {}
    except Exception as exc:
        # A partially-installed or version-incompatible chalk can raise non-ImportError
        # errors at import time (e.g. a Triton/torch mismatch). This hook must never
        # abort training.
        log.warning("chalk import failed (ignored, kernels disabled): %s", exc)
        return {}

    try:
        # liger=False: TRL already applied Liger (use_liger_kernel); chalk composes on the
        # live Liger modules. The entry point never raises on a per-kernel failure, but the
        # call itself is guarded so a chalk API/version skew can never abort training.
        report = apply_chalk(model, liger=False, **kernel_flags)
    except Exception as exc:  # never block training on the optional kernel stack
        log.warning("chalk apply failed (ignored, kernels disabled): %s", exc)
        return {}

    engaged = active_kernels(report)
    if engaged:
        log.info("chalk kernels active: %s", ", ".join(engaged))
    return report or {}