freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,18 @@
1
+ """Static per-GPU hourly rates for RunPod-provisionable Flash classes."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def static_rates() -> dict[str, float]:
    """Return the static $/hr snapshot for each RunPod-provisionable GPU.

    Keys are friendly GPU names taken from ``flash.providers.base.GPU_INFO``. Only entries
    with a truthy ``enum_member`` are included. Values are the catalog's static ``hourly_usd``
    snapshot, not live prices.
    """
    from flash.providers.base import GPU_INFO

    rates: dict[str, float] = {}
    for gpu_name, gpu_info in GPU_INFO.items():
        # Entries without an ``enum_member`` are left out (presumably classes RunPod cannot
        # provision — confirm against flash.providers.base).
        if not gpu_info.enum_member:
            continue
        rates[gpu_name] = gpu_info.hourly_usd
    return rates
11
+
12
+
13
def hourly_rate(gpu_name: str) -> float:
    """Return the static $/hr for a single GPU.

    ``gpu_name`` is first normalized with ``canonical_gpu`` and then looked up in
    ``static_rates()``. A GPU with no static rate raises ``KeyError``.
    """
    from flash.providers.base import canonical_gpu

    canonical = canonical_gpu(gpu_name)
    rates = static_rates()
    return rates[canonical]
@@ -0,0 +1,79 @@
1
+ """Client for the shared RunPod endpoint-slot quota (freesolo backend).
2
+
3
+ The control plane provisions one live RunPod endpoint per run and must keep its in-flight
4
+ endpoints under RunPod's account-wide cap. When an operator internal key is configured, the cap
5
+ is enforced CROSS-PROCESS via the freesolo backend's ``runpod_endpoint_slots`` store
6
+ (``POST /api/runpod/internal/slots/*``): an advisory-locked atomic claim, so >1 control-plane
7
+ replica can never together exceed the cap, and a startup reconcile recovers the true in-use
8
+ count after a crash. Without an internal key (local/dev single process) the caller falls back to
9
+ an in-process semaphore — see ``endpoints.py``.
10
+
11
+ This module is the thin network boundary only; the queue/fallback policy lives in ``endpoints``.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ import socket
19
+ import urllib.error
20
+ import urllib.request
21
+
22
+ from flash.server.auth import INTERNAL_KEY_ENV, freesolo_base_url
23
+
24
+ _TIMEOUT_S = 10.0
25
+ _CLAIM_PATH = "/api/runpod/internal/slots/claim"
26
+ _RELEASE_PATH = "/api/runpod/internal/slots/release"
27
+ _RECONCILE_PATH = "/api/runpod/internal/slots/reconcile"
28
+
29
+
30
class SlotStoreError(RuntimeError):
    """Raised when the shared slot store is unreachable or answers with an error."""
32
+
33
+
34
def internal_key() -> str | None:
    """Return the operator internal key from the environment.

    Surrounding whitespace is stripped. An unset or blank value yields ``None``, which callers
    treat as local/dev mode (use the in-process semaphore instead of the shared store).
    """
    raw_value = os.environ.get(INTERNAL_KEY_ENV, "")
    cleaned = raw_value.strip()
    return cleaned if cleaned else None
37
+
38
+
39
def claimed_by_ident() -> str:
    """Return a ``host:pid`` tag identifying this replica.

    The tag is recorded on a slot row purely for debugging which replica holds it.
    """
    host = socket.gethostname()
    pid = os.getpid()
    return f"{host}:{pid}"
42
+
43
+
44
def _post(path: str, body: dict) -> dict:
    """POST ``body`` as JSON to the slot-store ``path`` and return the decoded JSON object.

    The request goes to ``freesolo_base_url() + path`` and authenticates with the operator
    internal key as a bearer token. An empty response body decodes to ``{}``.

    Raises:
        SlotStoreError: the internal key is unset; the request failed at the network/HTTP
            layer (including a malformed or truncated HTTP response); or the response is not
            a JSON object.
    """
    import http.client

    key = internal_key()
    if not key:
        raise SlotStoreError(f"{INTERNAL_KEY_ENV} is not configured")
    req = urllib.request.Request(
        f"{freesolo_base_url()}{path}",
        data=json.dumps(body).encode("utf-8"),
        method="POST",
        headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
    )
    # HTTPError < URLError < OSError, and timeouts/connection resets are OSErrors too. But
    # http.client.HTTPException (e.g. BadStatusLine from getresponse(), IncompleteRead from
    # read()) is NOT an OSError. Without catching it, a garbled response would escape as a
    # non-SlotStoreError.
    try:
        with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp:
            raw = resp.read()
    except (
        urllib.error.HTTPError,
        urllib.error.URLError,
        OSError,
        http.client.HTTPException,
    ) as exc:
        raise SlotStoreError(str(exc)) from exc
    try:
        out = json.loads(raw or b"{}")
    except ValueError as exc:  # includes JSONDecodeError and UnicodeDecodeError
        raise SlotStoreError(f"bad slot-store response: {exc}") from exc
    # Every caller does ``.get(...)`` on the result, so a JSON null/list/scalar must surface
    # as a store error rather than an AttributeError at the call site.
    if not isinstance(out, dict):
        raise SlotStoreError(
            f"bad slot-store response: expected a JSON object, got {type(out).__name__}"
        )
    return out
63
+
64
+
65
def claim(name: str, *, cap: int, claimed_by: str | None = None) -> tuple[bool, int]:
    """Try to atomically claim the slot ``name`` against ``cap``.

    Returns ``(claimed, in_use)``. A full cap is not an error: ``claimed`` comes back False so
    the caller can queue and retry instead of oversubscribing RunPod. A missing or falsy
    ``inUse`` reads as 0.
    """
    request_body = {"name": name, "cap": cap, "claimedBy": claimed_by}
    response = _post(_CLAIM_PATH, request_body)
    was_claimed = bool(response.get("claimed"))
    in_use = int(response.get("inUse") or 0)
    return was_claimed, in_use
70
+
71
+
72
def release(name: str) -> bool:
    """Release the slot ``name`` (idempotent on the server).

    Returns True only when the server reports that a row was actually removed.
    """
    response = _post(_RELEASE_PATH, {"name": name})
    return bool(response.get("released"))
75
+
76
+
77
def reconcile(live_names: list[str]) -> dict:
    """Ask the store to reclaim slots whose endpoints are not in ``live_names``.

    Returns the server's decoded reply (documented as ``{"inUse", "reclaimed"}``).
    """
    names = list(live_names)
    return _post(_RECONCILE_PATH, {"liveNames": names})
@@ -0,0 +1,150 @@
1
+ """RunPod Flash fine-tuning endpoints (queue-based, one dedicated GPU per run).
2
+
3
+ Flash provisions a dedicated RunPod GPU (RTX 4090 / 5090, no Docker), installs
4
+ ``WORKER_DEPS``, runs the handler, returns the metrics dict, and scales to zero.
5
+
6
+ Flash's live ("ad-hoc") provisioning does not bundle local project code, so the
7
+ handler fetches the ``flash`` package from the HF dataset repo (uploaded by
8
+ ``upload_code`` before submit), adds it to ``PYTHONPATH``, and runs
9
+ ``flash.engine.worker`` to train. The worker streams the adapter + checkpoints to
10
+ the same HF repo for serving and preemption-resilient resume.
11
+
12
+ This is a package: the worker dependency stack + per-run env / chalk selection live in
13
+ ``.deps`` (the leaf), the endpoint lifecycle + worker handler in ``.endpoints``; this
14
+ ``__init__`` owns code upload + submit and re-exports the package's public surface so the
15
+ import path ``flash.providers.runpod.train`` is unchanged.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import asyncio
21
+ import inspect
22
+ import os
23
+
24
+ # Re-export the package's public surface so ``from flash.providers.runpod.train import <name>``
25
+ # keeps working unchanged for callers and tests.
26
+ from flash.providers.runpod.train.deps import ( # noqa: F401
27
+ DEFAULT_CHALK_SPEC,
28
+ DEFAULT_EXECUTION_TIMEOUT_MS,
29
+ WORKER_DEPS,
30
+ WORKER_IMAGE,
31
+ WORKER_SYSTEM_DEPS,
32
+ _effective_worker_env,
33
+ build_worker_env,
34
+ chalk_extra_pip,
35
+ logger,
36
+ resolve_worker_deps,
37
+ strip_runpod_volume_env,
38
+ worker_image_for_gpu,
39
+ )
40
+ from flash.providers.runpod.train.endpoints import ( # noqa: F401
41
+ _ENDPOINT_CACHE,
42
+ FLASH_SDK_LOCK,
43
+ _patch_runpod_backoff,
44
+ _run_suffix,
45
+ _select_endpoint_resources,
46
+ _train_body,
47
+ endpoint_name,
48
+ get_train_endpoint,
49
+ isolate_flash_state,
50
+ min_cuda_for,
51
+ stop_endpoint,
52
+ terminate_endpoint,
53
+ )
54
+ from flash.spec import JobSpec
55
+
56
+
57
def upload_code(repo: str | None = None) -> str:
    """Mirror the installed ``flash`` package into ``code/flash`` of the run's HF artifact repo.

    ``repo`` is the per-run artifact repo (``spec.train.hf_repo``). The worker fetches
    ``code/**`` from the same repo named in its submit payload and downloads it to
    ``/runcode``, so the code has to land in that per-run repo.

    Only the ``flash`` package is shipped — never the client's project tree. There are no
    built-in example environments: Freesolo SDK support is installed via
    ``registry.worker_pip_for_env``, and environment ids are resolved by the adapter at load
    time. Managed runs must therefore reference a published Freesolo environment by ``id``
    (publish a local env first with ``flash env push``).

    Returns ``repo``. Raises ``RuntimeError`` when ``repo`` is empty/None.
    """
    from huggingface_hub import HfApi

    import flash

    if not repo:
        raise RuntimeError(
            "hf_repo must be set (the run's [train] hf_repo: HF dataset repo for code + artifacts)"
        )
    hf_token = os.environ.get("HF_TOKEN")
    # Resolve symlinks so we read the REAL installed tree, not a link that a redeploy may have
    # re-pointed (e.g. /current -> /releases/<sha>). The worker re-imports exactly this package.
    package_root = os.path.realpath(os.path.dirname(os.path.abspath(flash.__file__)))
    hub = HfApi(token=hf_token)

    # Artifact repos always hold run code, adapters and metrics, so they must be private.
    # create_repo(exist_ok=True) does not touch an EXISTING repo's visibility, so the explicit
    # update_repo_settings call also flips a previously-public repo to private (no-op otherwise).
    hub.create_repo(repo, repo_type="dataset", exist_ok=True, private=True)
    hub.update_repo_settings(repo_id=repo, repo_type="dataset", private=True)

    hub.upload_folder(
        folder_path=package_root,
        path_in_repo="code/flash",
        repo_id=repo,
        repo_type="dataset",
        ignore_patterns=["__pycache__/*", "*.pyc"],
        # Exact mirror: "**" is relative to path_in_repo, so only stale files under code/flash
        # (e.g. a renamed module from an earlier upload) are purged; unchanged files are kept.
        delete_patterns=["**"],
    )
    return repo
105
+
106
+
107
def submit_train(
    spec: JobSpec, seed: int, log=None, runtime_secrets: dict[str, str] | None = None
) -> dict:
    """Run one training job on a dedicated Flash GPU and return its metrics dict.

    The execution timeout is ``spec.gpu.max_wall_seconds`` floored at 60 s. The payload carries
    the artifact repo, the serialized spec, the phase/seed, the per-run worker env, and
    ``extra_pip``. ``extra_pip`` is the environment's own pip list (or the registry's default
    for ``spec.environment.id``) plus the chalk spec.

    When ``log`` is given, a one-line submission notice is printed to it.

    Raises ``RuntimeError`` when the handler's result is not a dict.
    """
    wall_seconds = max(60, int(spec.gpu.max_wall_seconds))
    from flash.envs.registry import worker_pip_for_env

    handler = get_train_endpoint(
        spec.gpu.type,
        execution_timeout_ms=wall_seconds * 1000,
        name_suffix=_run_suffix(spec.run_id),
        disk_gb=spec.gpu.disk_gb,
        spec=spec,
    )
    payload = {
        "hf_repo": spec.train.hf_repo,
        "job_spec_json": spec.to_json(),
        "phase": spec.phase,
        "seed": int(seed),
        "env": build_worker_env(spec, seed, runtime_secrets=runtime_secrets),
    }
    # The worker pip-installs extra_pip for EVERY job (baked-image RunPod _train_body and Vast
    # bootstrap alike), which is why the chalk spec rides here to reach a default run.
    base_pip = list(spec.environment.pip) or worker_pip_for_env(spec.environment.id)
    payload["extra_pip"] = base_pip + chalk_extra_pip(spec)

    if log is not None:
        print(
            f"submitting Flash job: gpu={spec.gpu.type} phase={spec.phase} "
            f"seed={seed} model={spec.model}",
            file=log,
            flush=True,
        )

    async def _invoke():
        # The Flash handler may be sync or async; await only when it hands back an awaitable.
        result = handler(payload)
        return await result if inspect.isawaitable(result) else result

    metrics = asyncio.run(_invoke())
    if isinstance(metrics, dict):
        return metrics
    raise RuntimeError(f"flash job returned no metrics: {metrics!r}")
@@ -0,0 +1,395 @@
1
+ """Worker dependency stack + per-run env / chalk-kernel selection (leaf module).
2
+
3
+ The substrate-neutral training dependency list (``WORKER_DEPS``), the prebuilt worker
4
+ image, the per-run worker env builder, and the chalk-kernel install-selection helpers.
5
+ This is the leaf of the ``train`` package: it imports nothing else from the package and
6
+ defines the shared ``logger`` the other submodules import.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+
13
+ from flash._logging import get_logger
14
+ from flash.client.runtime_secrets import DEFAULT_RUNTIME_SECRET_KEYS
15
+ from flash.providers.base import get_gpu_info
16
+ from flash.spec import JobSpec
17
+
18
# Shared logger for the whole ``train`` package: this leaf module defines it, and the package
# ``__init__`` and sibling submodules import it from here. The name is a literal (NOT
# ``__name__``) so it stays "flash.providers.runpod.train" after the module was split into a
# package — callers/tests assert that exact name.
logger = get_logger("flash.providers.runpod.train")
21
+
22
+
23
# Substrate-neutral worker training stack (handed out by resolve_worker_deps):
# * trl 1.6 — colocate default; adds the GRPO ``tools=`` / ``rollout_func`` multi-turn hooks used
#   for Freesolo EnvironmentMultiTurn training. trl 1.6 requires transformers>=4.56 (satisfied
#   by the 5.6+ pin), and GRPOConfig is field-compatible with the 1.5 usage here.
# * vllm 0.19.1 — Qwen3.5/3.6 archs, native RL APIs, transformers-5 compatible metadata.
# * transformers 5.x — qwen3_5 / qwen3_5_moe model types.
# * bitsandbytes — the 8-bit paged AdamW optimizer state (LoRA+ coexists with it).
#
# Resolver / driver notes:
# * vllm 0.17/0.18 hard-pin transformers<5 (uv refuses the combo), so 0.19.1 is the first
#   transformers-5-compatible vllm line.
# * vllm >=0.20 pins torch 2.11, whose default PyPI wheels are CUDA-13 builds. RunPod 4090/5090
#   hosts filtered at min_cuda 12.8 often run 12.8/12.9 drivers where cu13 torch sees NO GPU
#   (observed: "cuda not available" + vLLM "cumem allocator not supported"). vllm 0.19.1 pins
#   torch 2.10 (cu128 default), which matches those drivers.
# * trl's *optional* [vllm] extra caps at 0.18, but plain trl is installed, so the only
#   constraint that matters is runtime API compat — validated per-model on real RTX 4090/5090
#   workers before promotion to default (see bench/results/phase1).
WORKER_DEPS: list[str] = [
    "torch==2.10.0",
    "transformers>=5.6,<5.13",
    "trl>=1.6,<1.7",
    "peft>=0.19",
    "vllm==0.19.1",
    "bitsandbytes>=0.49",
    "datasets>=4.7,<6",
    # >=0.2.49 is the first release exposing Environment.sft_completion +
    # datasets.target_messages, which the worker's SFT/multi-turn path calls (flash #162).
    "freesolo>=0.2.49",
    "huggingface_hub>=0.25",
    "accelerate>=1.4",
    # The HF ``kernels`` Hub package is intentionally NOT listed: the versions compatible with
    # torch 2.10 break transformers 5.6-5.10's hub_kernels integration at IMPORT
    # (LayerRepository now requires a version; transformers passes none -> ValueError on every
    # ``import transformers``). Hub FlashAttention is therefore off; attention uses SDPA (already
    # a flash/efficient backend on Ampere/Ada) plus the Liger fused kernels, which are the
    # dominant LoRA speedup anyway. A pinned flash-attn wheel is a possible future per-arch
    # experiment, kept out of the defaults to avoid a fragile cold-start install.
    # Weights & Biases client (its credential reaches the worker as a user runtime secret —
    # see build_worker_env).
    "wandb>=0.17",
    # Liger fused Triton kernels (pure Triton -> JITs on every arch incl. Blackwell): fused
    # linear cross-entropy for SFT (use_liger_kernel) and the chunked GRPO loss
    # (use_liger_loss) — the big large-vocab (Qwen3.5 ~248k) memory/throughput win.
    "liger-kernel>=0.5",
    # Fused Triton kernels for Gated-DeltaNet (Qwen3.5/3.6). Without them transformers falls
    # back to a pure-PyTorch delta rule that is far slower and heavier (measured A/B, H100 SXM,
    # Qwen3.5 hidden-2560 LoRA: seq4096 435->105 ms/step & 9.9->6.1 GB = 4.2x/1.6x; seq16384
    # 3106->247 ms & 32->17 GB = 12.6x/1.9x; forward loss matches to 1.8e-4). Installed from
    # git because the PyPI ``flash-linear-attention`` wheel is a broken stub missing
    # ``fla.modules``. PINNED to a commit (not the moving default branch) so cold-start installs
    # are reproducible; bump only after validating, in lockstep with Dockerfile.worker.
    "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention.git@f0e213dbd8b5fb90c3c7eca869ac1706d5377139",
    # fla's gated chunk_bwd is INCORRECT on Hopper (H100) with Triton>=3.4 (fla #640). Its
    # ``tilelang`` backend is the correct path there, so fla stays on every arch (the worker's
    # _ensure_fla_fastpath_on_hopper makes tilelang live on sm90 before any model import).
    # PINNED like the fla SHA; keep in lockstep with Dockerfile.worker + perf.py's runtime
    # reinstall.
    "tilelang==0.1.11",
    "apache-tvm-ffi==0.1.11",  # pin: 0.1.12 double-registers TVM-FFI -> `import tilelang` aborts
    # NB: ``causal_conv1d`` (which, with fla's cu_seqlens, lets GatedDeltaNet hybrids PACK
    # boundary-correctly — engine.worker.packing) is NOT listed: it is a CUDA-extension build, so
    # Dockerfile.worker compiles it best-effort with TORCH_CUDA_ARCH_LIST set. A pip entry would
    # compile at the no-GPU image-build step for the wrong/native arch. When it is absent the GDN
    # packing path stays off (gdn_packing_available) and Qwen3.5/3.6 train unpacked.
    # NB: freesolo-chalk (custom Triton/CUDA kernels complementing Liger) is NOT listed either.
    # Its gap-fillers are default-on (flash/engine/chalk_kernels.py), so the submit path
    # (chalk_extra_pip) appends the version-pinned ``freesolo-chalk`` (DEFAULT_CHALK_SPEC) to
    # every job's ``extra_pip``, which the baked-image RunPod _train_body pip-installs. Override
    # the install SOURCE with FLASH_CHALK_SPEC (an exact version / git URL / wheel).
]
93
# Download speed: Flash's runtime already ships hf_transfer and exports
# HF_HUB_ENABLE_HF_TRANSFER=1 on workers (measured: Qwen3-4B's ~8 GB pulled in 6.3 s,
# NIC-saturated — bench/results/phase6). Listing hf_transfer in WORKER_DEPS would be redundant.
WORKER_SYSTEM_DEPS: list[str] = ["build-essential"]  # Triton/Inductor need a C compiler

# Prebuilt worker image with the full training stack baked in (built by Dockerfile.worker /
# .github/workflows/worker-image.yml). It is PUBLIC under the org namespace, so pulls need no
# registry login, but it must be published to GHCR and made public before any path can use it.
# * RunPod baked-image submit (jobs.deploy_train_endpoint / build_function_input) uses it by
#   default: a self-contained serverless worker whose rp_handler runs _train_body.
#   FLASH_WORKER_IMAGE overrides the tag (e.g. for a hotfix); the boot-install fallback is only
#   reachable if BOTH are cleared, which is not a normal configuration.
# * RunPod Flash live endpoints (train.endpoints.get_train_endpoint) do NOT use it by default.
#   Flash needs its serverless runtime baked in, which this image lacks, so the live path
#   boot-installs resolve_worker_deps() on Flash's default template. It only uses an image when
#   the operator sets FLASH_WORKER_IMAGE to a RunPod-serverless-compatible one.
# FLASH_WORKER_IMAGE is therefore an operator override consulted by every RunPod path.
WORKER_IMAGE: str = "ghcr.io/freesolo-co/flash-worker:cu128"
# Env var names for the opt-in per-SM image selection in worker_image_for_gpu.
WORKER_IMAGE_TEMPLATE_ENV: str = "FLASH_WORKER_IMAGE_TEMPLATE"
WORKER_IMAGE_PER_SM_ENV: str = "FLASH_WORKER_IMAGE_PER_SM"
113
+
114
+
115
+ def _truthy(value: str | None) -> bool:
116
+ return (value or "").strip().lower() in {"1", "true", "yes", "on"}
117
+
118
+
119
+ def _append_tag_suffix(image: str, suffix: str) -> str:
120
+ slash = image.rfind("/")
121
+ colon = image.rfind(":")
122
+ if colon > slash:
123
+ return f"{image[:colon]}:{image[colon + 1:]}-{suffix}"
124
+ return f"{image}-{suffix}"
125
+
126
+
127
def worker_image_for_gpu(friendly_gpu: str | None, *, allow_default: bool = True) -> str | None:
    """Choose the RunPod worker image for a GPU class, or ``None``.

    Resolution order:

    1. A non-blank ``FLASH_WORKER_IMAGE`` always wins. It is the only image also treated as
       compatible with live RunPod Flash endpoints.
    2. With ``allow_default=False`` (live endpoints that want no default image), ``None`` is
       returned. Per-SM/template images are durable queue-worker images (their CMD runs
       rp_handler.py), not Flash live-function images, so they must never reach that path.
    3. For a non-empty ``friendly_gpu``, per-SM images are opt-in:
       ``FLASH_WORKER_IMAGE_TEMPLATE`` is ``str.format``-ed with ``base_image``, ``gpu``,
       ``gpu_short``, ``sm`` and ``sm_num``. Failing that, a truthy
       ``FLASH_WORKER_IMAGE_PER_SM`` appends ``-<sm>`` to ``WORKER_IMAGE``'s tag.
    4. Otherwise the base ``WORKER_IMAGE``.
    """
    explicit = os.environ.get("FLASH_WORKER_IMAGE", "").strip()
    if explicit:
        return explicit
    if not allow_default:
        return None
    if not friendly_gpu:
        return WORKER_IMAGE

    gpu = get_gpu_info(friendly_gpu)
    tag_template = os.environ.get(WORKER_IMAGE_TEMPLATE_ENV, "").strip()
    if tag_template:
        return tag_template.format(
            base_image=WORKER_IMAGE,
            gpu=gpu.name,
            gpu_short=gpu.short,
            sm=gpu.sm,
            sm_num=gpu.sm.removeprefix("sm"),
        )
    if _truthy(os.environ.get(WORKER_IMAGE_PER_SM_ENV)):
        return _append_tag_suffix(WORKER_IMAGE, gpu.sm)
    return WORKER_IMAGE
156
+
157
+
158
def resolve_worker_deps() -> list[str]:
    """Return a fresh copy of the dependency list Flash installs on the GPU worker.

    The pinned ``WORKER_DEPS`` is authoritative. Flash is fully managed, so there is no per-run
    override.

    fla stays on EVERY arch, Hopper included. The worker's ``_ensure_fla_fastpath_on_hopper``
    makes fla's correct ``tilelang`` backend live on sm90 before any model import, because the
    Triton>=3.4 miscompute (fla #640) is fixed by that backend rather than by dropping fla. That
    keeps Hopper GDN training ~4-13x faster and ~2x leaner than the pure-PyTorch delta fallback.
    """
    # Copy so callers can mutate their list without touching the module-level pin.
    return [*WORKER_DEPS]
169
+
170
+
171
+ def _effective_worker_env(spec=None) -> dict[str, str]:
172
+ """The env the WORKER process will actually see, for chalk-selection decisions.
173
+
174
+ chalk install-on-call is selected by ``FLASH_*`` flags read on the worker from its own process
175
+ env, which ``build_worker_env`` builds as the control-plane ``os.environ`` allowlist with the
176
+ run's ``[worker_env]`` overrides merged ON TOP (per-run ``spec.worker_env`` wins). A run that
177
+ opts into chalk via its ``[worker_env]`` block therefore sets the flag the worker reads — so the
178
+ SAME merge must decide whether chalk is selected and whether its spec is added to ``extra_pip``;
179
+ reading bare ``os.environ`` here would miss a per-run ``[worker_env]`` opt-in and the kernels
180
+ would never install for that run.
181
+
182
+ Returns ``os.environ`` overlaid with ``spec.worker_env`` (string-coerced). ``spec=None`` (no
183
+ per-run env) collapses to plain ``os.environ``.
184
+ """
185
+ eff: dict[str, str] = dict(os.environ)
186
+ for k, v in (getattr(spec, "worker_env", None) or {}).items():
187
+ eff[str(k)] = str(v)
188
+ return eff
189
+
190
+
191
# Chalk install spec used when FLASH_CHALK_SPEC is unset. It is VERSION-PINNED to a bounded range
# (like the rest of WORKER_DEPS), so default runs are reproducible and a breaking freesolo-chalk
# release cannot silently reach production jobs: 0.1.x patches are accepted, 0.2 is not. Bump it
# intentionally after validating a new line. Operators can pin exactly with
# FLASH_CHALK_SPEC=freesolo-chalk==X.Y.Z.
DEFAULT_CHALK_SPEC: str = "freesolo-chalk>=0.1.0,<0.2.0"
196
+
197
+
198
def chalk_extra_pip(spec=None) -> list[str]:
    """Return the freesolo-chalk pip spec(s) to append to the worker's ``extra_pip``.

    ``submit_train`` appends this to every job's ``extra_pip``. The baked-image RunPod path
    (``_train_body`` -> ``pip install *extra_pip``) installs that list regardless of
    ``WORKER_IMAGE`` — unlike ``resolve_worker_deps``, which the durable
    ``build_function_input`` baked-image path skips. Chalk's gap-fillers are default-on, so
    chalk is installed for every run.

    The install source comes from ``FLASH_CHALK_SPEC`` in the EFFECTIVE worker env: the run's
    ``[worker_env]`` merged over ``os.environ``, i.e. what ``build_worker_env`` hands the worker.
    A per-run override is therefore honoured. When the value is unset or blank, the
    version-pinned :data:`DEFAULT_CHALK_SPEC` (PyPI) is used.

    The value is split shell-style, so one override can name several requirements (an exact
    version, a git URL, a wheel/path). Per-kernel ``FLASH_<K>`` flags are NOT consulted here.
    The result is non-empty unless the override tokenizes to nothing (e.g. ``''``).

    Raises:
        ValueError: ``FLASH_CHALK_SPEC`` has unbalanced shell quoting (from ``shlex.split``).
    """
    import shlex

    override = _effective_worker_env(spec).get("FLASH_CHALK_SPEC", "").strip()
    spec_str = override or DEFAULT_CHALK_SPEC
    # shlex honours quoting; drop tokens that are empty/whitespace (e.g. a quoted '').
    return [token for token in shlex.split(spec_str) if token.strip()]
222
+
223
+
224
# RunPod worker execution cap: 6 hours, expressed in milliseconds.
DEFAULT_EXECUTION_TIMEOUT_MS: int = 6 * 60 * 60 * 1000
225
+
226
+
227
# Module-local alias of the shared runtime-secret key set from flash.client.runtime_secrets.
# NOTE(review): presumably consumed by build_worker_env when injecting user secrets — confirm.
_RUNTIME_SECRET_KEYS = DEFAULT_RUNTIME_SECRET_KEYS
228
+
229
+ # RunPod serverless mounts a network volume at this FIXED path (can't mount over ~/.cache), so the
230
+ # redirect is conditional per-run env, not a static image ENV.
231
+ _WEIGHT_CACHE_MOUNT = "/runpod-volume"
232
+
233
+
234
def weight_cache_env(mount: str = _WEIGHT_CACHE_MOUNT) -> dict[str, str]:
    """Return the worker env that puts the HF cache on the persistent volume at ``mount``.

    This is used only when a weight-cache volume is attached (jobs.weight_cache_endpoint_kwargs).
    Setting ``HF_HOME`` is the entire feature: model downloads become a one-time cost per region
    instead of per run. hf_transfer already saturates the NIC; this only makes the download land
    somewhere persistent.

    It is DELIBERATELY HF-only. The volume is SHARED platform-wide (multi-tenant), and HF
    snapshots are inert DATA. Executable kernel-JIT caches (Triton/Inductor/tilelang/
    torch-extensions) are NOT redirected: the worker *executes* those artifacts, so sharing them
    across tenants would let a buggy or hostile run's environment code poison a later unrelated
    run in the same region. JIT caches stay per-worker/ephemeral, and the ~10-15 min first-use
    compile is still paid per cold worker.

    Concurrent writes: two cold runs landing in the same not-yet-warm region both
    snapshot_download the model onto this shared mount. huggingface_hub guards that with a
    per-blob file lock plus atomic rename (content-addressed blobs), so the worst case is
    duplicated download work, not a corrupt snapshot. The preload step
    (flash/providers/runpod/preload.py) pre-warms each region so real runs usually hit a
    populated cache.
    """
    hf_home = f"{mount}/hf-cache"
    return {"HF_HOME": hf_home}
255
+
256
+
257
def drop_unmounted_cache_env(env: dict, mount: str = _WEIGHT_CACHE_MOUNT) -> dict:
    """Strip any ``mount``-rooted cache vars when the volume isn't actually mounted (mutates+returns).

    Defense-in-depth for the cold/no-volume fallback: if the cache attach degraded to ``{}`` (an SDK
    error) or the worker simply has no volume, ``HF_HOME`` would otherwise point at a non-existent
    ``/runpod-volume`` path. Dropping it lets HF fall back to the default ephemeral cache (a correct
    cold run) instead of writing under a missing/ephemeral mount. Reads the real filesystem
    (``os.path.isdir``) but takes the env as an argument, so tests drive it by monkeypatching isdir.

    A value counts as ``mount``-rooted only when it IS the mount or lies beneath it
    (``<mount>/...``). A sibling path that merely shares the string prefix (e.g.
    ``/runpod-volume-old/...``) is not under the mount and is left untouched.

    Args:
        env: Worker env mapping; modified in place.
        mount: Mount point whose presence is checked and whose subtree is stripped.

    Returns:
        The same ``env`` object (mutated only when ``mount`` is not a directory).
    """
    if os.path.isdir(mount):
        return env
    # Normalize a trailing slash so "/x/" and "/x" behave the same; keep "/" intact.
    root = mount.rstrip("/") or "/"
    # Path-boundary prefix: "/runpod-volume/" must not match "/runpod-volume-old".
    prefix = root if root.endswith("/") else root + "/"
    doomed = [k for k, v in env.items() if str(v) == root or str(v).startswith(prefix)]
    for k in doomed:
        env.pop(k, None)
    return env
271
+
272
+
273
def strip_runpod_volume_env(env: dict, mount: str = _WEIGHT_CACHE_MOUNT) -> dict:
    """Remove the RunPod weight-cache redirect from an env bound for a NON-RunPod worker (mutates).

    Instance providers (Lambda/Hyperstack) reuse this module's shared ``build_worker_env``, which
    redirects ``HF_HOME`` onto the RunPod network-volume mount (``/runpod-volume``) whenever the run
    carries a weight-cache volume. That mount exists ONLY on RunPod serverless — on a rented instance
    it is a nonexistent path with no cross-run persistence — so the instance submit path strips any
    ``/runpod-volume``-rooted cache var here (unconditionally: instance providers never mount it). A
    user ``[worker_env]`` ``HF_HOME`` override is not ``/runpod-volume``-rooted, so it is preserved.

    "Rooted" means the value IS the mount or lies beneath it (``<mount>/...``). A path that merely
    shares the string prefix (e.g. ``/runpod-volume-data/...``) is a different directory and is kept.

    Args:
        env: Worker env mapping; modified in place.
        mount: Mount point whose subtree is stripped from env values.

    Returns:
        The same ``env`` object.
    """
    # Normalize a trailing slash so "/x/" and "/x" behave the same; keep "/" intact.
    root = mount.rstrip("/") or "/"
    # Path-boundary prefix: "/runpod-volume/" must not match "/runpod-volume-data".
    prefix = root if root.endswith("/") else root + "/"
    doomed = [k for k, v in env.items() if str(v) == root or str(v).startswith(prefix)]
    for k in doomed:
        env.pop(k, None)
    return env
286
+
287
+
288
def build_worker_env(
    spec: JobSpec,
    seed: int,
    runtime_secrets: dict[str, str] | None = None,
) -> dict:
    """Per-run env passed to the worker (platform creds + recipe overrides).

    Provider and artifact credentials still come from the control-plane process environment.
    User runtime secrets (W&B plus [environment].secrets) are injected from ``runtime_secrets``
    below so the control plane never stores user-owned secret values in the spec.

    Layers are applied in order; later layers overwrite earlier ones:
    1. fixed defaults;
    2. control-plane ``os.environ`` allowlists;
    3. the weight-cache ``HF_HOME`` redirect;
    4. spec step/epoch counts;
    5. per-run ``[worker_env]``;
    6. runtime secrets.

    The run-identity keys (RUN_ID / HF_REPO / FLASH_ARM) can be overridden by neither
    ``[worker_env]`` nor runtime secrets.

    Args:
        spec: The run's job spec (run id, model, train/gpu/environment sections, worker_env).
        seed: NOTE(review): not read by this function — confirm whether the worker should
            receive it in its env.
        runtime_secrets: User-owned secret values. Only keys in ``_RUNTIME_SECRET_KEYS`` or
            ``spec.environment.secrets`` with a truthy value are forwarded.

    Returns:
        A new ``dict`` of env var name -> value for the worker.
    """
    # CUDA allocator conf. Colocate (TRL trainer + vLLM on one GPU) fragments over a long run, so
    # expandable_segments (which reclaims fragmentation) is the right default — EXCEPT under GRPO
    # vLLM sleep mode, whose CuMemAllocator memory pool is incompatible with expandable_segments
    # (vLLM asserts and the run crashes at engine init). RL is fully managed — no sleep/alloc knob:
    # we set the sleep-SAFE non-expandable conf for RL here (the launcher can't yet know the model's
    # resolved sleep decision), and the worker upgrades RL runs to expandable_segments when it
    # resolves sleep OFF for the model/context (engine.worker.finalize_alloc_conf_for_sleep, same
    # deterministic _memory_mode gate run_rl uses). SFT has no vLLM engine, so expandable is safe.
    # torch >= 2.10 renamed the env to PYTORCH_ALLOC_CONF — set BOTH names for either stack.
    _is_rl = str(getattr(spec, "algorithm", "")).lower() not in ("sft",)
    _alloc_conf = (
        "garbage_collection_threshold:0.8,max_split_size_mb:256"
        if _is_rl
        else "expandable_segments:True"
    )
    env: dict[str, str] = {
        "RUN_ID": spec.run_id,
        # Compute substrate, read back by engine.worker for the RunMetrics record.
        "FLASH_ARM": "runpod",
        "BENCH_HF_MODEL": spec.model,
        "PYTORCH_CUDA_ALLOC_CONF": _alloc_conf,
        "PYTORCH_ALLOC_CONF": _alloc_conf,
    }
    # HF artifact creds + managed environment hub creds + optional reward-judge creds: a Freesolo
    # environment whose reward calls an LLM judge (e.g. OpenRouter gpt-oss-120b) needs the API key ON
    # THE WORKER, where the reward runs. FLASH_JUDGE_MODEL is the judge model id the optimizer-authored
    # env reads (agents/common/prompt.py) to pick the JudgeRubric client model; forward the operator's
    # control-plane override so SFT-eval/GRPO-reward/rejection-sampling judges don't silently fall
    # back to the env's generated default. Forward any that the operator has set; absent ones are
    # simply not passed (the env then uses its own default model).
    for key in (
        "HF_TOKEN",
        "GITHUB_TOKEN",
        "OPENROUTER_API_KEY",
        "OPENAI_API_KEY",
        "FLASH_JUDGE_MODEL",
    ):
        # Empty values are treated as unset for credentials.
        if os.environ.get(key):
            env[key] = os.environ[key]
    # Seed the worker's own HF_REPO env from the run's [train] hf_repo (adapter/checkpoint/
    # code storage + heartbeats). The worker reads HF_REPO from its own process env; that env
    # is now sourced from the spec, not the operator's HF_REPO.
    env["HF_REPO"] = spec.train.hf_repo
    # When the shared weight-cache volume is attached, redirect HF_HOME onto the mount so model
    # weights persist across runs (HF blobs only — NOT the executable kernel-JIT caches; see
    # weight_cache_env). Gated on a volume being assigned: without one the mount doesn't exist, so
    # pointing HF_HOME there would just break the worker (the worker also self-corrects at runtime
    # if /runpod-volume isn't mounted — see _train_body). A per-run [worker_env] override still wins
    # (merged last, below).
    # CONFIDENTIALITY: HF_HOME here is process-global, so the run's environment/reward code's runtime
    # HF downloads ALSO land on this SHARED mount — the catalog gate (runner._assign_weight_cache_volume)
    # scopes only the spec model, not assets fetched at execution time with the forwarded HF_TOKEN. See
    # the TRUST MODEL note in flash/runner/__init__.py; proper base-model-only scoping (explicit
    # cache_dir for prefetch + ephemeral HF cache for env code, or a read-only mount) is a worker-side
    # follow-up.
    if getattr(spec.gpu, "network_volume", None):
        env.update(weight_cache_env())
    if spec.train.steps is not None:
        env["RL_STEPS"] = str(spec.train.steps)
    if spec.train.epochs is not None:
        env["SFT_EPOCHS"] = str(spec.train.epochs)
    # Forward the worker-side knobs the worker / vLLM actually read. flash is fully
    # managed: there are no per-run env tuning knobs — the only per-run config is the spec's
    # structured fields, and the worker hardcodes the vLLM-util / quant / heartbeat defaults.
    for k in (
        "SFT_PER_DEVICE_BS",
        "VLLM_USE_V1",
        # Upload the worker console (which optimizations engaged) on SUCCESS too, not just on crash.
        # run_mode() in _train_body reads this from the `env` dict it builds (os.environ updated with
        # this forwarded input_data["env"] allowlist), NOT from its own process os.environ — so a
        # control-plane `FLASH_UPLOAD_CONSOLE=1` only reaches run_mode if it's forwarded here.
        "FLASH_UPLOAD_CONSOLE",
        # The chalk install SOURCE (an exact version / git URL / wheel). Kernel SELECTION is fixed
        # in engine.chalk_kernels (no env flags); this only points install_chalk_kernels at a
        # specific chalk build, and is also consumed at submit time to add chalk to extra_pip.
        "FLASH_CHALK_SPEC",
    ):
        # Forward when SET, even if empty: an explicit "" is a meaningful override.
        if k in os.environ:
            env[k] = os.environ[k]
    # Per-run worker_env overrides win over the global os.environ allowlist: this is what lets
    # ONE run differ (e.g. a per-run optimizer or LoRA-init A/B) while every other concurrent run
    # keeps the global default. Run-IDENTITY keys are control-plane-owned and excluded: the poller,
    # deploy, and artifact paths all key off spec.run_id / spec.train.hf_repo, so letting a
    # [worker_env] override RUN_ID/HF_REPO would make the worker upload under a different repo/prefix
    # and orphan the artifacts (the poller would never find DONE/metrics, deploy can't locate the
    # adapter). FLASH_ARM identifies the substrate.
    _RESERVED_WORKER_ENV = {"RUN_ID", "HF_REPO", "FLASH_ARM"}
    for k, v in (getattr(spec, "worker_env", None) or {}).items():
        if str(k).upper() in _RESERVED_WORKER_ENV:
            continue  # control plane owns run identity; a per-run override would orphan artifacts
        env[str(k)] = str(v)
    # Runtime secrets are applied last, so they must honor the same run-identity reservation: a
    # user-declared [environment].secrets entry named e.g. HF_REPO would otherwise clobber the
    # control-plane value and orphan the artifacts exactly as a [worker_env] override would.
    # `or ()` tolerates a spec whose secrets list is unset (None).
    allowed_runtime_secrets = set(_RUNTIME_SECRET_KEYS) | set(spec.environment.secrets or ())
    for k, v in (runtime_secrets or {}).items():
        if str(k).upper() in _RESERVED_WORKER_ENV:
            continue  # never let a secret rewrite run identity
        if k in allowed_runtime_secrets and v:
            env[k] = str(v)
    return env