freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Rubric/reward/introspection helpers for the verifiers adapter.
|
|
2
|
+
|
|
3
|
+
Leaf helpers split out of ``flash.envs.adapter``: rubric flattening, judge discovery,
|
|
4
|
+
reward-func invocation, eval-metric guarding, and env-shape introspection. None reference
|
|
5
|
+
the rest of the adapter package (no import cycle); the package ``__init__`` re-exports them.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import inspect
|
|
12
|
+
|
|
13
|
+
# Judge-related kwarg names a reward func may declare; their values are read off a
# JudgeRubric. This tuple is the single source of truth shared by ``_judge_kwargs`` and
# ``_AVAILABLE_REWARD_KWARGS``.
_JUDGE_KWARG_NAMES = ("judge", "judge_client", "judge_model", "judge_prompt")

# Non-judge kwargs the adapter can hand to a reward func. These mirror the keys of the
# ``available`` dict built in VerifiersEnvironment._reward_available. Deriving the frozenset
# below from the shared name tuples, instead of keeping a separate hand-written set, removes
# the "keep in sync" hazard: a kwarg added in one place but not the other would re-trigger
# the false "requires unavailable arg" failure.
_BASE_REWARD_KWARG_NAMES = (
    "completion", "prompt", "answer", "info",
    "state", "parser", "task",
)

# Every kwarg name this adapter is able to supply to a reward func.
_AVAILABLE_REWARD_KWARGS = frozenset((*_BASE_REWARD_KWARG_NAMES, *_JUDGE_KWARG_NAMES))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _reward_requires_unavailable_args(func) -> str | None:
    """Return the first required parameter of ``func`` this adapter cannot supply, else None.

    Group/batch reward funcs declare plural required params (``completions``, ``prompts``,
    ``answers``, ...). The worker scores one completion at a time and has no batch, so such a
    func would be called without its required argument and silently score 0.0, leaving
    train/eval on an all-zero signal. Reporting the offending name lets the caller fail fast.

    ``*args`` / ``**kwargs`` parameters and parameters with defaults are never reported. An
    uninspectable callable yields None, because ``_invoke_reward`` then passes every kwarg.
    """
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Builtins / uninspectable callables: _invoke_reward passes everything.
        return None
    variadic_kinds = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
    unsupplied = (
        param.name
        for param in signature.parameters.values()
        if param.kind not in variadic_kinds
        and param.default is inspect.Parameter.empty
        and param.name not in _AVAILABLE_REWARD_KWARGS
    )
    return next(unsupplied, None)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _run_async(coro):
    """Run an awaitable to completion from sync code and return its result.

    ``asyncio.run`` only accepts coroutine objects. Callers such as ``_invoke_reward`` forward
    anything ``inspect.isawaitable`` accepts, including objects that only define
    ``__await__``. The awaitable is therefore wrapped in a tiny coroutine first, so every
    awaitable kind works.

    With no running event loop, a fresh loop runs in the current thread. If a loop is already
    running (rare for the worker), ``asyncio.run`` would refuse to nest. In that case a fresh
    loop is started on a single helper thread and this call blocks until it finishes.
    Exceptions raised by the awaitable propagate to the caller.
    """

    async def _await_it():
        return await coro

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(_await_it())
    # Already inside a loop (rare for the worker): run in a fresh loop on a thread.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
        return ex.submit(asyncio.run, _await_it()).result()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _call_dataset_getter(obj, method_name: str, *, seed: int, n: int = -1):
    """Invoke the verifiers dataset getter ``obj.<method_name>``, binding ``n``/``seed`` if declared.

    verifiers exposes get_dataset/get_eval_dataset as get_X(n=-1, seed=0), but some published
    envs declare them WITHOUT defaults. A bare no-arg call then raised TypeError, which used
    to be swallowed into an empty dataset (a paid run over no data). Whichever of ``n``
    (default -1 = all rows; the adapter does its own fixed subset selection, and callers pass
    a positive cap to avoid materializing a huge split) and ``seed`` the signature names is
    passed by keyword.

    Returns None when the attribute is missing or not callable. Any exception from the
    getter itself propagates (fail loudly) instead of silently emptying the split.
    """
    getter = getattr(obj, method_name, None)
    if not callable(getter):
        return None
    try:
        declared = inspect.signature(getter).parameters
    except (TypeError, ValueError):
        declared = {}
    bound = {key: value for key, value in (("n", n), ("seed", seed)) if key in declared}
    return getter(**bound)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _rows_to_list(ds) -> list[dict]:
    """Materialize a dataset-like iterable into a list of rows, as plain dicts when possible.

    ``None`` yields ``[]``. The rows are collected into a list first, iterating ``ds`` exactly
    once, and then each is converted with ``dict(row)``. If any row is not dict-convertible,
    the collected rows are returned unconverted.

    Collecting first matters for one-shot iterables (generators, streaming datasets).
    Converting while iterating and then falling back to ``list(ds)`` on failure would
    re-iterate a partly consumed iterator and silently drop rows.
    """
    if ds is None:
        return []
    rows = list(ds)
    try:
        return [dict(row) for row in rows]
    except Exception:
        # Best-effort: keep the raw rows when any of them is not mapping-like.
        return rows
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _flatten_rubric(rubric) -> list[tuple]:
|
|
100
|
+
"""Collect ``(func, weight)`` pairs from a rubric, recursing into ``RubricGroup``.
|
|
101
|
+
|
|
102
|
+
verifiers composes rubrics (e.g. a ``RubricGroup`` wrapping a ``MathRubric`` plus a
|
|
103
|
+
``MultiTurnMonitorRubric``); the real reward funcs live on the *nested* rubrics while the
|
|
104
|
+
group's own ``funcs`` is empty. Flattening finds them all.
|
|
105
|
+
"""
|
|
106
|
+
funcs = list(getattr(rubric, "funcs", None) or getattr(rubric, "reward_funcs", None) or [])
|
|
107
|
+
weights = list(
|
|
108
|
+
getattr(rubric, "weights", None) or getattr(rubric, "reward_weights", None) or []
|
|
109
|
+
)
|
|
110
|
+
if len(weights) < len(funcs):
|
|
111
|
+
weights += [1.0] * (len(funcs) - len(weights))
|
|
112
|
+
pairs = list(zip(funcs, weights, strict=False))
|
|
113
|
+
for sub in getattr(rubric, "rubrics", None) or []:
|
|
114
|
+
pairs.extend(_flatten_rubric(sub))
|
|
115
|
+
return pairs
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _find_judge_rubric(rubric):
    """Return the first judge rubric found depth-first in a rubric tree, or None.

    A node matches if it is an instance of ``verifiers.JudgeRubric`` (when verifiers is
    importable and exposes that class). It also matches by duck typing if it has a callable
    ``judge`` plus a ``judge_client`` attribute. Used for judge-arg injection.
    """
    if rubric is None:
        return None
    try:
        import verifiers as vf

        judge_cls = getattr(vf, "JudgeRubric", None)
    except ImportError:
        judge_cls = None
    if (judge_cls is not None and isinstance(rubric, judge_cls)) or (
        # Duck-type fallback: anything exposing a `judge` method + a judge_client attr.
        callable(getattr(rubric, "judge", None)) and hasattr(rubric, "judge_client")
    ):
        return rubric
    children = getattr(rubric, "rubrics", None) or []
    # map() is lazy, so the search stops at the first child subtree that yields a match.
    matches = (found for found in map(_find_judge_rubric, children) if found is not None)
    return next(matches, None)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _judge_kwargs(judge_rubric) -> dict:
    """Map each judge kwarg name to the matching attribute of ``judge_rubric``.

    Returns ``{}`` when there is no judge rubric. Otherwise every name in
    ``_JUDGE_KWARG_NAMES`` is present, with None for any attribute the rubric lacks.
    """
    kwargs: dict = {}
    if judge_rubric is not None:
        for kwarg_name in _JUDGE_KWARG_NAMES:
            kwargs[kwarg_name] = getattr(judge_rubric, kwarg_name, None)
    return kwargs
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _invoke_reward(func, available: dict) -> float:
    """Call a verifiers reward func with only the kwargs it declares, awaiting it if async.

    If ``func`` takes ``**kwargs`` or its signature cannot be inspected, every entry of
    ``available`` is passed. An awaitable result is driven to completion via ``_run_async``.
    A falsy result (e.g. None) becomes 0.0, and anything else goes through ``float``.

    Exceptions PROPAGATE. ``scores_breakdown`` invokes this for *weighted* reward funcs, so an
    exception here means a real (weighted) reward func genuinely failed. Examples: a
    JudgeRubric judge raising on an API/rate-limit error, or a parse error on row data.
    Swallowing it as 0.0 would silently train/score on an all-zero signal and waste a paid
    run, so we fail loudly instead. (``scores_breakdown`` skips unweighted monitor funcs
    entirely.)
    """
    try:
        declared = inspect.signature(func).parameters
    except (TypeError, ValueError):
        declared = None
    if declared is None or any(param.kind == param.VAR_KEYWORD for param in declared.values()):
        call_kwargs = dict(available)
    else:
        call_kwargs = {name: value for name, value in available.items() if name in declared}
    outcome = func(**call_kwargs)
    if inspect.isawaitable(outcome):
        outcome = _run_async(outcome)
    return float(outcome or 0.0)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _substring_answer_score(completion: str, example: dict) -> float:
    """Fallback reward for an env with no rubric.

    Returns 1.0 iff the example's ``answer``, rendered as text, is a non-empty substring of
    ``completion``; otherwise 0.0. Only a missing or None answer counts as absent.
    Falsy-but-meaningful answers such as ``0`` are compared as ``"0"`` instead of being
    collapsed to ``""``, which would score every completion 0.0. A None completion is treated
    as ``""``.
    """
    raw = example.get("answer")
    answer = "" if raw is None else str(raw)
    return 1.0 if answer and answer in (completion or "") else 0.0
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _unique_key(name: str, existing) -> str:
    """Return ``name`` if unused, else the first of ``name_1``, ``name_2``, ... not in ``existing``.

    The code probes for an exact unused key, so two rubric funcs that share a name both
    survive. A prefix/length heuristic could recompute a suffix that collides with an
    already-recorded key (e.g. ``score`` vs ``score_detail``) and silently overwrite a scorer.
    """
    candidate = name
    suffix = 0
    while candidate in existing:
        suffix += 1
        candidate = f"{name}_{suffix}"
    return candidate
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _is_multi_turn(vf_env) -> bool:
    """True for a tool/multi-turn verifiers env (NOT a plain SingleTurnEnv).

    Returns False when verifiers cannot be imported, or when the relevant classes are not
    exposed by it.
    """
    try:
        import verifiers as vf
    except ImportError:
        return False
    tool_cls = getattr(vf, "ToolEnv", None)
    multi_cls = getattr(vf, "MultiTurnEnv", None)
    single_cls = getattr(vf, "SingleTurnEnv", None)
    if tool_cls is not None and isinstance(vf_env, tool_cls):
        return True
    if multi_cls is None or not isinstance(vf_env, multi_cls):
        return False
    # SingleTurnEnv subclasses MultiTurnEnv in verifiers, so it must be exempted explicitly.
    if single_cls is not None and isinstance(vf_env, single_cls):
        return False
    return True
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _is_tool_env(vf_env) -> bool:
    """True for a verifiers ``ToolEnv`` or any subclass of it (Stateful/Sandbox/Python).

    Tool envs expose Python tool callables. The worker hands those to TRL's
    ``GRPOTrainer(tools=...)`` so TRL drives the tool-call loop natively (it owns generation,
    tool execution, and assistant-only token masking). A *pure* ``MultiTurnEnv`` (env turns
    are arbitrary content, e.g. a simulated user) is multi-turn but NOT a tool env, and takes
    the ``rollout_func`` path instead. Returns False if verifiers (or its ``ToolEnv``) is
    unavailable.
    """
    try:
        import verifiers as vf
    except ImportError:
        return False
    tool_cls = getattr(vf, "ToolEnv", None)
    if tool_cls is None:
        return False
    return isinstance(vf_env, tool_cls)
|
flash/envs/base.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Small, serializable environment interface for SFT/RL jobs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Protocol
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Environment(Protocol):
    """Structural interface for an SFT/RL environment.

    As a ``typing.Protocol``, any object with an ``id`` attribute and these methods conforms;
    no inheritance is needed.
    """

    id: str

    def dataset(self) -> list[dict]:
        """Training rows. This is the only split used here; evaluation happens on the
        serving side."""
        ...

    def prompt_messages(self, example: dict) -> list[dict]:
        """The chat messages the model is prompted with for ``example``."""
        ...

    def sft_completion(self, example: dict) -> list[dict]:
        """Gold completion messages appended after the prompt for one SFT example. This is
        either a multi-turn trajectory or a single assistant turn."""
        ...

    def reward(self, completion: str, example: dict, state: dict | None = None) -> float:
        """Scalar RL reward assigned to ``completion`` for ``example``."""
        ...

    def grade(self, completion: str, example: dict, state: dict | None = None) -> bool:
        """Boolean correctness check that ``reward`` may build on."""
        ...
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _field_text(example: dict, key: str) -> str:
    """Return ``example[key]`` as text: ``""`` when missing or None, else ``str(value)``.

    ``str(value or "")`` would collapse falsy-but-meaningful values (``0``, ``0.0``,
    ``False``) to ``""``; this keeps them.
    """
    value = example.get(key)
    return "" if value is None else str(value)


@dataclass
class BaseEnvironment:
    """Default single-turn environment: prompt from ``input``, target/grading from ``output``.

    Subclasses must implement ``dataset``. Missing or None fields read as ``""``.
    Falsy-but-real values such as ``0`` keep their text form (``"0"``), so a row whose gold
    output is ``0`` can still be graded correct.
    """

    id: str

    def dataset(self) -> list[dict]:
        """Training rows; must be provided by a subclass."""
        raise NotImplementedError

    def prompt_messages(self, example: dict) -> list[dict]:
        """A single user turn whose content is the row's ``input``."""
        return [{"role": "user", "content": _field_text(example, "input")}]

    def sft_completion(self, example: dict) -> list[dict]:
        # Single-turn default: one target assistant turn from the record's scalar ``output``.
        # FreesoloEnvironment overrides this to support multi-turn target trajectories via the
        # freesolo-sdk (``Environment.sft_completion`` -> ``datasets.target_messages``).
        return [{"role": "assistant", "content": _field_text(example, "output")}]

    def reward(self, completion: str, example: dict, state: dict | None = None) -> float:
        """1.0 when ``grade`` accepts the completion, else 0.0."""
        return 1.0 if self.grade(completion, example, state) else 0.0

    def grade(self, completion: str, example: dict, state: dict | None = None) -> bool:
        """True iff the stripped gold ``output`` is a non-empty substring of ``completion``."""
        gold = _field_text(example, "output").strip()
        # A missing/empty output must NOT grade every completion correct (`"" in x` is
        # always True) — treat it as unscorable -> incorrect.
        return bool(gold) and gold in (completion or "")
|
flash/envs/registry.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Environment registry used by specs, worker, CLI, and server.
|
|
2
|
+
|
|
3
|
+
Every managed run names a Freesolo SDK environment by Hub slug.
|
|
4
|
+
The canonical generated environment entrypoint is ``environment.py:load_environment``.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
from .._fileio import read_json_or_empty, secure_json_write
|
|
13
|
+
from .base import Environment
|
|
14
|
+
|
|
15
|
+
# Manifest of local Freesolo environment ids (written by `flash env install`).
# FLASH_ENVS_MANIFEST overrides the location. An unset OR empty value falls back to
# ~/.flash/envs.json; otherwise an empty string would become Path("") == ".", the current
# directory, and manifest reads/writes would target a directory. A leading "~" in the
# override is expanded to the user's home directory.
INSTALLED_MANIFEST = Path(
    os.environ.get("FLASH_ENVS_MANIFEST") or str(Path.home() / ".flash" / "envs.json")
).expanduser()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_installed_manifest() -> dict:
    """Return the installed-environment manifest read from ``INSTALLED_MANIFEST``.

    Reading is delegated to ``read_json_or_empty``; presumably that yields ``{}`` when the
    file is missing or unreadable (verify against ``flash._fileio``).
    """
    manifest_path = INSTALLED_MANIFEST
    return read_json_or_empty(manifest_path)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def list_installed_environments() -> list[str]:
    """Sorted ids of the Freesolo environments recorded via `flash env install`."""
    env_ids = list(load_installed_manifest())
    env_ids.sort()
    return env_ids
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def record_installed_env(env_id: str, package: str, extras: dict | None = None) -> None:
    """Add or replace the manifest entry for ``env_id`` and persist the manifest.

    The entry is ``{"package": package}`` merged with ``extras``. Keys from ``extras`` are
    applied last, so they win on conflict.
    """
    manifest = load_installed_manifest()
    entry = {"package": package, **(extras or {})}
    manifest[env_id] = entry
    # Entries can hold a credentialed --extra-index-url, so persist with private permissions.
    secure_json_write(INSTALLED_MANIFEST, manifest)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def worker_pip_for_env(env_id: str) -> list[str]:
    """Return the pip requirements the GPU worker needs to run a Freesolo environment.

    ``env_id`` is currently unused: every environment needs the same dependency set. A fresh
    list is returned on each call, so callers may mutate it.
    """
    requirements = ["freesolo"]
    return requirements
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_environment(
    env_id: str, params: dict | None = None, resolved_sha: str | None = None
) -> Environment:
    """Load a Freesolo SDK environment and wrap it in Flash's protocol.

    Args:
        env_id: the environment's Hub slug (e.g. ``"your-name/your-env"``).
        params: optional user ``[environment.params]``, forwarded verbatim as keyword
            arguments to the SDK loader.
        resolved_sha: the optional resolve-once hint, i.e. the control-plane-pinned commit
            sha for the env's GitHub ref. None/"" preserves today's behavior, where the
            adapter resolves the ref itself.

    Raises:
        ValueError: if ``env_id`` is empty or whitespace-only. This is checked BEFORE the
            adapter is imported, so a misconfigured spec fails fast and cheaply.
    """
    if not env_id or not str(env_id).strip():
        raise ValueError(
            "no environment specified: set [environment] id to the id returned by "
            "`flash env push --name <name>` (for example 'your-name/your-env')"
        )
    params = params or {}
    from .adapter import load_freesolo_environment

    # User [environment.params] are freeform and forwarded verbatim to the SDK loader. The
    # control-plane resolve-once pin is passed out-of-band as a POSITIONAL-ONLY argument, so a
    # user param of ANY name (even "pinned_sha"/"resolved_sha") lands in **params and reaches
    # the SDK unchanged; it can never bind to or disable the pin. None/"" keeps today's
    # behavior.
    return load_freesolo_environment(env_id, resolved_sha or None, **params)
|
flash/mcp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""MCP integration package."""
|
flash/mcp/server.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Minimal stdio MCP-style bridge for coding agents.
|
|
2
|
+
|
|
3
|
+
This intentionally avoids a hard dependency on a specific MCP SDK while exposing
|
|
4
|
+
the stable JSON tools that agents need. Requests are newline-delimited JSON:
|
|
5
|
+
{"tool": "list_models", "args": {...}}.
|
|
6
|
+
|
|
7
|
+
Run-lifecycle tools call the managed Flash control plane with the same stored
|
|
8
|
+
credentials as the CLI (`flash login`); dry-run validation stays local.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import sys
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
|
|
17
|
+
from flash.catalog import public_model_rows
|
|
18
|
+
from flash.client import client_from_config
|
|
19
|
+
from flash.client.runtime_secrets import runtime_secrets_from_local_env
|
|
20
|
+
from flash.client.specs import spec_payload
|
|
21
|
+
from flash.schema import spec_from_dict
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def list_models(args: dict) -> dict:
    """Tool ``list_models``: the public model catalog. ``args`` is accepted but unused."""
    rows = public_model_rows()
    return {"models": rows}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def create_train_run(args: dict) -> dict:
    """Tool ``create_training_run``: build a spec from ``args`` and submit it.

    With a truthy ``dry_run`` the spec is only validated locally and echoed back. Otherwise
    it is sent to the control plane together with the runtime secrets named by
    ``spec.environment.secrets``, read from the local environment.
    """
    spec = spec_from_dict(args, run_id=args.get("run_id"))
    if not args.get("dry_run"):
        client = client_from_config()
        payload = spec_payload(spec)
        secrets = runtime_secrets_from_local_env(keys=spec.environment.secrets)
        return client.create_run(payload, runtime_secrets=secrets)
    # Fully local: validate without credentials, a server, or a GPU.
    return {"run_id": spec.run_id, "state": "dry_run", "spec": spec.to_dict()}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_run_status(args: dict) -> dict:
    """Tool ``get_run_status``: the control plane's record for ``args["run_id"]``."""
    client = client_from_config()
    return client.get_run(args["run_id"])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_run_logs(args: dict) -> dict:
    """Tool ``get_run_logs``: one page of logs for ``args["run_id"]``.

    ``args["offset"]`` (optional) is where the page starts. Missing, null or empty means 0,
    since JSON clients commonly send ``"offset": null`` for "from the start" and
    ``int(None)`` would raise. Returns the run id plus the page's ``logs``, ``offset`` and
    ``state``.
    """
    run_id = args["run_id"]
    offset = int(args.get("offset") or 0)
    page = client_from_config().get_logs(run_id, offset=offset)
    return {"run_id": run_id, **{key: page[key] for key in ("logs", "offset", "state")}}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def deploy_adapter_tool(args: dict) -> dict:
    """Tool ``deploy_adapter``: deploy the adapter of ``args["run_id"]`` (honours ``dry_run``)."""
    client = client_from_config()
    return client.deploy(args["run_id"], dry_run=bool(args.get("dry_run", False)))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Dispatch table: the ``tool`` value accepted on stdin -> its handler.
TOOLS: dict[str, Callable[[dict], dict]] = dict(
    list_models=list_models,
    create_training_run=create_train_run,
    get_run_status=get_run_status,
    get_run_logs=get_run_logs,
    deploy_adapter=deploy_adapter_tool,
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def handle(payload: dict) -> dict:
    """Dispatch one decoded request ``{"tool": <name>, "args": {...}}`` to its tool handler.

    A missing, null or empty ``args`` is passed to the tool as ``{}``.

    Raises:
        ValueError: if the payload is not a JSON object, names an unknown tool, or carries
            ``args`` that is not a JSON object. Checking up front gives the agent a clear
            message instead of an AttributeError/TypeError from deep inside a tool. Examples
            are a JSON array payload, or an unhashable ``tool`` value.
    """
    if not isinstance(payload, dict):
        raise ValueError(f"request must be a JSON object, got {type(payload).__name__}")
    tool = payload.get("tool")
    if not isinstance(tool, str) or tool not in TOOLS:
        raise ValueError(f"unknown tool {tool!r}; choose one of {sorted(TOOLS)}")
    args = payload.get("args") or {}
    if not isinstance(args, dict):
        raise ValueError(f"'args' must be a JSON object, got {type(args).__name__}")
    return TOOLS[tool](args)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def main() -> int:
    """Serve newline-delimited JSON requests from stdin until EOF; return 0.

    Blank lines are skipped. Every other line is parsed and dispatched through ``handle``,
    and exactly one JSON line is written to stdout per request: ``{"ok": true, "result":
    ...}`` or ``{"ok": false, "error": "..."}``. Any failure becomes an error response, so
    one bad request never kills the bridge. That covers bad JSON, an unknown tool, a tool
    raising, and a result that cannot be JSON-encoded (serialization happens inside the
    ``try``).
    """
    for raw_line in sys.stdin:
        line = raw_line.strip()
        if not line:
            continue
        try:
            out = json.dumps({"ok": True, "result": handle(json.loads(line))})
        except Exception as exc:
            out = json.dumps({"ok": False, "error": str(exc)})
        print(out, flush=True)
    return 0
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if __name__ == "__main__":
    # Script entry point: serve stdin until EOF and exit with main()'s status code.
    sys.exit(main())
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Pluggable GPU substrates.
|
|
2
|
+
|
|
3
|
+
The training worker (``flash.engine.worker``) reads a JobSpec from the environment, pulls code
|
|
4
|
+
from the HF dataset repo, and streams artifacts/heartbeats/metrics back to it. The provider
|
|
5
|
+
owns pricing, provisioning, polling, cancellation, and teardown.
|
|
6
|
+
|
|
7
|
+
runpod serverless Flash endpoints (always on)
|
|
8
|
+
lambda Lambda Cloud GPU instances (instance-based complement; iff LAMBDA_API_KEY set)
|
|
9
|
+
hyperstack Hyperstack GPU VMs (instance-based complement; iff HYPERSTACK_API_KEY set)
|
|
10
|
+
|
|
11
|
+
This module is the registry: ``get_provider(name)`` / ``PROVIDER_NAMES``.
|
|
12
|
+
``allocator.allocate`` iterates the active provider list below; ``available_providers`` narrows
|
|
13
|
+
it to the ones configured on THIS control plane (Lambda/Hyperstack are opt-in via their operator
|
|
14
|
+
keys, so a box without them silently behaves exactly as the RunPod-only setup).
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from functools import cache
|
|
20
|
+
|
|
21
|
+
from flash.providers.base import Provider
|
|
22
|
+
|
|
23
|
+
# Registered provider names. The order doubles as the price tie-break preference: RunPod
# wins ties, then Lambda, then Hyperstack.
PROVIDER_NAMES: tuple[str, ...] = (
    "runpod",
    "lambda",
    "hyperstack",
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_provider(name: str) -> Provider:
    """Return the ``Provider`` singleton registered under ``name``.

    Matching ignores case and surrounding whitespace, and a None/empty name is treated as
    ``""``. Unknown names raise ``KeyError``.
    """
    # Normalize BEFORE the cache so "RunPod"/"runpod"/" runpod " share one cache entry.
    cache_key = (name or "").strip().lower()
    return _get_provider(cache_key)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@cache
def _get_provider(key: str) -> Provider:
    """Import and return the ``PROVIDER`` object for an already-normalized ``key``.

    Each provider package is imported only when its key is requested, and ``@cache``
    memoizes the result per key. Unknown keys raise ``KeyError`` listing the known names.
    """
    if key == "runpod":
        from flash.providers.runpod import PROVIDER
    elif key == "lambda":
        from flash.providers.lambdalabs import PROVIDER
    elif key == "hyperstack":
        from flash.providers.hyperstack import PROVIDER
    else:
        raise KeyError(f"unknown provider {key!r} (known: {', '.join(PROVIDER_NAMES)})")
    return PROVIDER
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def available_providers() -> tuple[str, ...]:
    """Names of the providers usable from this control plane right now, in registry order.

    A provider is usable when ``is_configured()`` reports its credentials are present. RunPod
    is always on; Lambda/Hyperstack join only when their operator keys are set.
    """

    def _configured(provider_name: str) -> bool:
        return get_provider(provider_name).is_configured()

    return tuple(filter(_configured, PROVIDER_NAMES))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def configured_providers() -> list[Provider]:
    """``Provider`` objects for every name in ``available_providers()``, in the same order."""
    return list(map(get_provider, available_providers()))
|
flash/providers/_auth.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Shared operator-credential helpers for the GPU providers.
|
|
2
|
+
|
|
3
|
+
Every provider authenticates the same way: a single API key read ONLY from an
|
|
4
|
+
environment variable on the control-plane host (never config files, never shipped to
|
|
5
|
+
workers). The per-provider ``auth.py`` modules wrap these with their own env-var name
|
|
6
|
+
and error message.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_provider_key(env_var: str) -> str | None:
    """Return the provider API key from ``env_var`` (operator configuration), or None.

    Surrounding whitespace is stripped, and an unset, empty or whitespace-only value yields
    None. Without this, a blank-but-set variable would make the provider look configured,
    and stray whitespace/newlines would be sent inside auth headers.
    """
    value = os.environ.get(env_var, "").strip()
    return value or None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def ensure_provider_auth(env_var: str, missing_message: str) -> str:
    """Return the API key configured in ``env_var``.

    Raises:
        RuntimeError: carrying ``missing_message`` when ``load_provider_key`` finds no key.
    """
    key = load_provider_key(env_var)
    if key:
        return key
    raise RuntimeError(missing_message)
|