freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,636 @@
|
|
|
1
|
+
"""CLI command handlers for the managed Flash service.
|
|
2
|
+
|
|
3
|
+
Every run-lifecycle command is a thin HTTP call to the Flash control plane —
|
|
4
|
+
users authenticate with their freesolo API key (`flash login` verifies it against
|
|
5
|
+
the freesolo backend), never with provider credentials. Config parsing/validation
|
|
6
|
+
and `--dry-run` stay fully local.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
from flash import __version__
|
|
18
|
+
from flash._logging import get_logger
|
|
19
|
+
from flash.catalog import public_model_rows
|
|
20
|
+
from flash.client import (
|
|
21
|
+
ApiClient,
|
|
22
|
+
ClientError,
|
|
23
|
+
client_from_config,
|
|
24
|
+
save_credentials,
|
|
25
|
+
verify_freesolo_key,
|
|
26
|
+
)
|
|
27
|
+
from flash.client.config import load_credentials
|
|
28
|
+
from flash.client.runtime_secrets import runtime_secrets_from_local_env
|
|
29
|
+
from flash.client.specs import spec_payload
|
|
30
|
+
from flash.cost.spec import runconfig_from_spec
|
|
31
|
+
from flash.runner import TERMINAL_STATES, new_run_id
|
|
32
|
+
from flash.schema import ConfigError, spec_from_file
|
|
33
|
+
|
|
34
|
+
from . import render
|
|
35
|
+
from .training_doc import TRAINING_MD
|
|
36
|
+
|
|
37
|
+
# Module-level logger used by the command handlers in this file.
logger = get_logger("flash.cli.main")


# Exceptions that represent expected user/config errors: report them as a clean one-line
# message instead of a Python traceback (use --debug to see the full trace).
_USER_ERRORS = (
    ConfigError,
    ClientError,
    FileNotFoundError,
    ValueError,
)

# Run states after which nothing more will happen (polling can stop).
_CLI_DONE_STATES = TERMINAL_STATES | {"deployed"}
# Final states for which a followed run exits 0; any other final state exits 1.
_OK_STATES = {"done", "dry_run", "deployed"}
# Characters the log-follow spinner cycles through, one per redraw.
_SPINNER_FRAMES = "|/-\\"
# Target delay in seconds between spinner redraws while waiting for the next poll.
_SPINNER_TICK_SECONDS = 0.1
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class _LogFollowSpinner:
    """Single-line progress indicator written to stderr while a run's logs are followed.

    Whether to draw is decided once, at construction, from ``sys.stderr.isatty()``.
    When stderr is not a TTY, ``render`` and ``clear`` do nothing.
    """

    def __init__(self, run_id: str):
        self._enabled = sys.stderr.isatty()
        self._run_id = run_id
        self._frame = 0  # number of frames drawn so far; selects the next glyph
        self._last_len = 0  # length of the last message drawn, used for padding/erasing
        self._active = False  # True while a spinner line is currently on screen

    @property
    def enabled(self) -> bool:
        """True when the spinner will actually draw (stderr was a TTY at construction)."""
        return self._enabled

    def render(self, state: str) -> None:
        """Redraw the status line for ``state`` and advance the animation by one frame."""
        if not self._enabled:
            return
        glyph = _SPINNER_FRAMES[self._frame % len(_SPINNER_FRAMES)]
        self._frame += 1
        line = f"{glyph} following logs for {self._run_id} ({state})"
        line_len = len(line)
        # Trailing spaces overwrite whatever a longer previous line left behind.
        filler = " " * max(0, self._last_len - line_len)
        sys.stderr.write("\r" + line + filler)
        sys.stderr.flush()
        self._last_len = line_len
        self._active = True

    def clear(self) -> None:
        """Blank out the spinner line, if one is showing, and return the cursor to column 0."""
        if self._enabled and self._active:
            sys.stderr.write("\r" + " " * self._last_len + "\r")
            sys.stderr.flush()
            self._active = False
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _sleep_with_spinner(interval: float, spinner: _LogFollowSpinner, state: str) -> None:
    """Wait ``interval`` seconds, animating ``spinner`` when it is enabled.

    A non-positive interval returns immediately. With a disabled spinner this is one plain
    sleep. Otherwise the wait is cut into equal slices of roughly ``_SPINNER_TICK_SECONDS``
    and the spinner is redrawn before each slice.
    """
    if interval <= 0:
        return
    if spinner.enabled:
        slices = max(1, int(interval / _SPINNER_TICK_SECONDS))
        per_slice = interval / slices
        for _ in range(slices):
            spinner.render(state)
            time.sleep(per_slice)
        return
    time.sleep(interval)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def cmd_version(args) -> int:
    """Print the flash version (themed when styling is on, plain otherwise) and return 0."""
    text = render.version(__version__) if render.styled() else f"flash {__version__}"
    print(text)
    return 0
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def cmd_login(args) -> int:
    """Verify a freesolo API key, store it, and print the identity it belongs to.

    The key is taken from ``--api-key`` or, failing that, ``FREESOLO_API_KEY``. Returns 0
    on success. Returns 1 when no key was given or verification raised ``ClientError``;
    with ``--debug`` that error is re-raised instead of being reported.
    """
    # Login is handled by the freesolo backend (not the flash control plane): the user
    # supplies the freesolo API key they created at freesolo.co/sign-in, and we verify it
    # against freesolo before storing it. The same key authenticates flash's control plane.
    env_api_key = os.environ.get("FREESOLO_API_KEY")
    api_key = args.api_key or env_api_key
    try:
        if not api_key:
            raise ClientError(
                "no API key provided: pass `--api-key <key>` or set FREESOLO_API_KEY. "
                "Create or copy a key at https://freesolo.co/sign-in."
            )
        verify_freesolo_key(api_key, base_url=getattr(args, "freesolo_url", None))
    except ClientError as exc:
        # Login failed (no key, a rejected key, or an unreachable backend): say so plainly.
        # With --debug the error propagates so the caller can show the full traceback.
        if getattr(args, "debug", False):
            raise
        print(render.login_failed(str(exc)), file=sys.stderr)
        return 1

    api_url = args.api_url or load_credentials()[0]
    # save_credentials clears the stored url when it's the default, so logging into the
    # default plane also drops a stale custom url from a previous custom-URL login.
    save_credentials(api_key, api_url=api_url)

    env_overrides_saved_key = (
        args.api_key and env_api_key and env_api_key != args.api_key
    )
    if env_overrides_saved_key:
        print(
            "warning: FREESOLO_API_KEY is set and will override this saved login for future "
            "commands; unset FREESOLO_API_KEY to use the saved key.",
            file=sys.stderr,
        )

    # Show the identity right away so a second `flash whoami` is not needed, and never echo
    # the key itself. The lookup is best-effort: the key is already verified and stored, so
    # a failed lookup (None) must not turn a successful login into a failure.
    identity = _identity_or_none(api_key, api_url)
    print(render.login_ok(identity))
    return 0
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# The identity card shown after login is nonessential, so its lookup gets a short timeout:
# a slow control plane must not make a successful login look like it is hanging.
_IDENTITY_LOOKUP_TIMEOUT_S = 5.0


def _identity_or_none(api_key: str, api_url: str) -> dict | None:
    """Return ``ApiClient(...).me()`` for the given key/url, or None if the lookup fails.

    ``ClientError``, ``OSError`` and ``ValueError`` are swallowed on purpose; any other
    exception propagates.
    """
    # Build the client from the key/url that were just verified and stored rather than
    # from `client_from_config()`: an ambient FREESOLO_API_KEY would otherwise win over
    # the file and render the wrong identity.
    try:
        client = ApiClient(api_url, api_key, timeout=_IDENTITY_LOOKUP_TIMEOUT_S)
        return client.me()
    except (ClientError, OSError, ValueError):
        return None
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def cmd_whoami(args) -> int:
    """Print the identity reported by ``me()`` on the configured client and return 0."""
    identity = client_from_config().me()
    print(render.whoami(identity))
    return 0
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# Source text written verbatim to `environment.py` by `cmd_env_setup` when that file does
# not exist yet. Everything between the triple quotes is file content, not code executed
# by this module.
# NOTE(review): the scaffolded `exact_match_reward` uses a substring test
# (`expected in response_text`), so an expected "4" also matches a reply of "14".
_STARTER_ENV_PY = '''\
"""Starter Freesolo environment.

Edit datasets/train.jsonl and the reward code, then upload with
`flash env push --name my-env .`.

A managed run should use the returned [environment] id from
`flash env push --name my-env .`.

This starter keeps a tiny smoke-test dataset in datasets/train.jsonl. Replace it
with your real training rows before a real run.
"""

from __future__ import annotations

import json
from pathlib import Path

from freesolo.datasets.types import TaskExample
from freesolo.environments import EnvironmentSingleTurn, RewardResult


DEFAULT_DATASET_PATH = Path(__file__).parent / "datasets" / "train.jsonl"


def load_jsonl(path: str | Path):
    rows = []
    with Path(path).open() as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def exact_match_reward(example: TaskExample, response_text: str) -> RewardResult:
    expected = str(example.output or "").strip()
    score = 1.0 if expected and expected in response_text else 0.0
    return RewardResult(score=score, threshold=1.0)


class StarterEnv(EnvironmentSingleTurn):
    dataset = load_jsonl(DEFAULT_DATASET_PATH)

    def build_prompt_messages(self, example: TaskExample, prompt_text: str):
        return [{"role": "user", "content": example.input}]

    def score_response(self, example: TaskExample, response_text: str) -> RewardResult:
        return exact_match_reward(example, response_text)


def load_environment(dataset_path: str | None = None, **kwargs) -> StarterEnv:
    env = StarterEnv()
    if dataset_path:
        env.dataset = load_jsonl(dataset_path)
    return env
'''

# Two-row JSONL dataset written to `datasets/train.jsonl` by `cmd_env_setup` when that
# file does not exist yet; each row has the `input`/`output` keys the starter env reads.
_STARTER_DATASET_JSONL = """\
{"input":"What is 2 + 2?","output":"4"}
{"input":"What is 3 + 5?","output":"8"}
"""
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _write_if_missing(path: Path, text: str) -> None:
    """Create ``path`` holding ``text`` (encoded as UTF-8) unless it already exists."""
    if not path.exists():
        # Explicit UTF-8 for every scaffold file: write_text's default is the locale
        # encoding, which would raise UnicodeEncodeError on non-ASCII text (TRAINING_MD
        # has em dashes, ·, √, ≥, ≈) under a non-UTF-8 locale.
        path.write_text(text, encoding="utf-8")


def cmd_env_setup(args) -> int:
    """Scaffold a starter project in the current directory and return 0.

    Ensures ``configs/`` and ``datasets/`` exist, then creates ``datasets/train.jsonl``,
    ``environment.py``, ``configs/rl.toml``, ``configs/sft.toml`` and ``TRAINING.md``.
    A file that already exists is left untouched, so re-running never overwrites edits.
    The printed summary lists all five paths whether or not they were just created.
    """
    Path("configs").mkdir(exist_ok=True)
    Path("datasets").mkdir(exist_ok=True)
    _write_if_missing(Path("datasets/train.jsonl"), _STARTER_DATASET_JSONL)
    _write_if_missing(Path("environment.py"), _STARTER_ENV_PY)

    env_comment = (
        "# Environment: upload this project folder with\n"
        "# `flash env push --name my-env .`, then paste the returned id below.\n"
        "# If the environment reads secrets with os.environ, list only the env var names here.\n"
        "# Values are read from your shell or .env at submit time and are not stored in the spec.\n"
        "[environment]\n"
        'id = ""\n\n'
        '# secrets = ["SERPAPI_API_KEY"]\n\n'
    )
    managed_note = (
        "# GPU and the HF artifact repo are managed automatically by the platform: the GPU is\n"
        "# the cheapest fitting class across providers, and each run gets its own artifact repo.\n"
    )

    def config_toml(algorithm: str, budget_line: str) -> str:
        # The rl and sft templates differ only in the algorithm name and the first
        # [train] entry (steps vs epochs); everything else is shared.
        return (
            'model = "Qwen/Qwen3.5-4B"\n'
            f'algorithm = "{algorithm}"\n\n'
            f"{env_comment}"
            "[train]\n"
            f"{budget_line}\n"
            "lora_rank = 32\n"
            "seeds = [0]\n"
            f"{managed_note}"
        )

    _write_if_missing(Path("configs/rl.toml"), config_toml("grpo", "steps = 150"))
    _write_if_missing(Path("configs/sft.toml"), config_toml("sft", "epochs = 1"))
    # TRAINING.md is the playbook for the AI agent driving these runs: how to design the
    # reward, what to read, and how to decide a run actually improved (not just finished).
    _write_if_missing(Path("TRAINING.md"), TRAINING_MD)

    scaffolded = [
        "environment.py",
        "datasets/train.jsonl",
        "configs/rl.toml",
        "configs/sft.toml",
        "TRAINING.md",
    ]
    if render.styled():
        print(render.env_setup(scaffolded))
    else:
        print(f"ensured {', '.join(scaffolded)}")
    return 0
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def cmd_models(args) -> int:
    """Print the public model catalog: a themed table when styled, else one id per line."""
    catalog = public_model_rows()
    if not render.styled():
        for entry in catalog:
            print(entry["id"])
    else:
        print(render.models_table(catalog))
    return 0
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def cmd_gpus(args) -> int:
    """Print each GPU class with its VRAM and RunPod hourly rate, ordered by `hourly_usd`.

    Only ``GPU_INFO`` entries with a truthy ``enum_member`` are listed. Always returns 0.
    """
    from flash.providers.base import GPU_INFO
    from flash.providers.runpod.pricing import static_rates as runpod_static_rates

    rates = runpod_static_rates()
    classes = [gpu for gpu in GPU_INFO.values() if gpu.enum_member]
    classes.sort(key=lambda gpu: gpu.hourly_usd)
    tip = (
        "Tip: GPU class selection is fully automatic — the submit-time allocator always picks the\n"
        "cheapest validated RunPod class that fits the model, so you don't pin a GPU type."
    )
    if render.styled():
        table_rows = [(gpu.name, gpu.vram_gb, rates.get(gpu.name)) for gpu in classes]
        print(render.gpus_table(table_rows, tip))
        return 0

    print(f"{'gpu':<16}{'vram':>6}{'runpod$/hr':>11}")
    for gpu in classes:
        rate = rates.get(gpu.name)
        # A missing (or zero) rate is shown as a right-aligned dash.
        rate_cell = f"{rate:>10.2f}" if rate else f"{'-':>10}"
        print(f"{gpu.name:<16}{gpu.vram_gb:>5}G{rate_cell:>11}")
    print(f"\n{tip}")
    return 0
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def cmd_env_list(args) -> int:
    """List installed environments and local environment sources, then return 0.

    Local sources are ``.`` (when ``./environment.py`` exists) plus entries under
    ``./environments``: a folder counts when it holds ``environment.py`` or
    ``<folder-name-with-dashes-as-underscores>.py``, and a bare ``*.py`` file counts as is.
    Entries whose name starts with ``__`` are ignored.
    """
    from flash.envs.registry import list_installed_environments

    installed = list_installed_environments()
    sources: list[str] = []
    if Path("environment.py").is_file():
        sources.append(".")

    env_root = Path("environments")
    # Prefer publishing folders. Single-file modules remain supported for small smoke tests.
    for entry in env_root.iterdir() if env_root.is_dir() else ():
        if entry.name.startswith("__"):
            continue
        if entry.is_dir():
            module_file = entry / (entry.name.replace("-", "_") + ".py")
            if (entry / "environment.py").is_file() or module_file.is_file():
                sources.append(f"environments/{entry.name}")
        elif entry.suffix == ".py":
            sources.append(f"environments/{entry.name}")

    ordered_sources = sorted(sources)
    # Pick one rendering up front so the themed panel and the plain lines never both print.
    if render.styled():
        print(render.env_list(list(installed), ordered_sources))
        return 0
    if installed:
        print("installed environments:")
        for env_id in installed:
            print(f" {env_id}")
    if ordered_sources:
        print("local env sources (publish with `flash env push --name <name> <path>`):")
        for source in ordered_sources:
            print(f" {source}")
    return 0
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _cmd_train_cost(args) -> int:
    """`flash train --cost`: print the pre-flight cost estimate for the config; no submit.

    The config is parsed with ``run_id=None``, converted with ``runconfig_from_spec`` and
    passed to ``estimate_cost``. Output is a themed panel when styled, otherwise
    ``estimate.breakdown()``. Returns 0.
    """
    from flash.cost import estimate_cost

    parsed = spec_from_file(
        args.config,
        run_id=None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )
    estimate = estimate_cost(runconfig_from_spec(parsed))
    text = render.cost_panel(estimate) if render.styled() else estimate.breakdown()
    print(text)
    return 0
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def cmd_train(args) -> int:
    """Validate a training config and, unless told otherwise, submit it and follow the run.

    * ``--cost``: delegate to ``_cmd_train_cost`` (estimate only).
    * ``--dry-run``: parse and validate locally, print the spec, return 0; no client is made.
    * ``--background``: submit, print the returned status, return 0.
    * otherwise: submit, then follow logs and return ``_follow_run``'s exit code.
    """
    if getattr(args, "cost", False):
        return _cmd_train_cost(args)

    spec = spec_from_file(
        args.config,
        # A run id is generated locally only for a dry run.
        run_id=new_run_id() if args.dry_run else None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )

    if args.dry_run:
        # Fully local: validate the id-based config without credentials, a server, or a GPU.
        preview = {"run_id": spec.run_id, "state": "dry_run", "spec": spec.to_dict()}
        if render.styled():
            panel = render.object_panel(
                "train", preview, "dry run — validated locally, not submitted"
            )
            print(panel)
        else:
            print(json.dumps(preview, indent=2))
        return 0

    client = client_from_config()
    body = spec_payload(spec)
    secrets = runtime_secrets_from_local_env(args.config, keys=spec.environment.secrets)
    status = client.create_run(body, runtime_secrets=secrets)
    run_id = status["run_id"]
    logger.info(
        "submitted run %s: model=%s algorithm=%s gpu=%s seeds=%s",
        run_id,
        spec.model,
        spec.algorithm,
        spec.gpu.type,
        list(spec.train.seeds),
    )

    if args.background:
        if render.styled():
            print(render.object_panel("train", status, "submitted (running in background)"))
        else:
            print(json.dumps(status, indent=2))
        return 0

    # The submission notice goes to stderr, separate from the log stream on stdout.
    if render.styled():
        notice = render.submitted(run_id)
    else:
        notice = (
            f"run {run_id} submitted; following logs "
            f"(Ctrl-C detaches, `flash status {run_id} --follow` resumes)"
        )
    print(notice, file=sys.stderr)
    return _follow_run(client, run_id)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _poll_logs(client: ApiClient, run_id: str, interval: float) -> str:
    """Stream offset-paged logs until the run reaches a terminal state; return that state.

    Each iteration fetches one page from ``client.get_logs`` at the current offset, prints
    any log text to stdout, and stops once ``page["state"]`` is in ``_CLI_DONE_STATES``.
    Between polls it waits ``interval`` seconds via ``_sleep_with_spinner``.
    """
    offset = 0
    spinner = _LogFollowSpinner(run_id)
    try:
        while True:
            page = client.get_logs(run_id, offset=offset)
            if page["logs"]:
                # Erase the spinner line first so log text never lands on top of it.
                spinner.clear()
                print(page["logs"], end="", flush=True)
                # NOTE(review): the source this was recovered from had lost its indentation;
                # the offset update is assumed to sit inside this `if` — confirm upstream.
                offset = page["offset"]
            if page["state"] in _CLI_DONE_STATES:
                spinner.clear()
                return page["state"]
            _sleep_with_spinner(interval, spinner, page["state"])
    finally:
        # Also runs when an exception (e.g. Ctrl-C) leaves the loop, so no spinner remains.
        spinner.clear()
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def _follow_run(client: ApiClient, run_id: str) -> int:
    """Follow a run's logs to a terminal state, print its final status, and return an exit code.

    Polls every 2 seconds. Returns 0 when the final state is in ``_OK_STATES``, else 1.
    """
    final_state = _poll_logs(client, run_id, interval=2.0)
    status = client.get_run(run_id)
    if render.styled():
        summary = render.run_status(status)
    else:
        summary = json.dumps(status, indent=2)
    print(summary)
    return 1 if final_state not in _OK_STATES else 0
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def cmd_status(args) -> int:
    """Print a run's status, optionally preceded by its logs.

    ``--follow`` hands off to ``_follow_run`` and returns its exit code. ``--logs`` prints
    the orchestrator log and then each non-empty worker output section before the status.
    Without ``--follow`` the return value is always 0.
    """
    client = client_from_config()
    run_id = args.run_id
    if getattr(args, "follow", False):
        return _follow_run(client, run_id)

    if getattr(args, "logs", False):
        emitted = False
        orchestrator_log = client.get_logs(run_id).get("logs", "")
        if orchestrator_log:
            # Make sure the log ends with a newline.
            print(orchestrator_log, end="" if orchestrator_log.endswith("\n") else "\n")
            emitted = True
        # Always append the real train-subprocess output (the orchestrator log can't carry
        # it); the server fetches console_/error_<phase>.txt from HF with the operator token.
        worker_output = client.get_worker_output(run_id) or {}
        for section, body in worker_output.items():
            if not body:
                continue
            # Separate sections with a blank line, but not before the first thing printed
            # (an empty orchestrator log would otherwise leave a leading blank line).
            prefix = "\n" if emitted else ""
            print(f"{prefix}----- {section} -----")
            print(body, end="")
            if not body.endswith("\n"):
                print()
            emitted = True

    status = client.get_run(run_id)
    summary = render.run_status(status) if render.styled() else json.dumps(status, indent=2)
    print(summary)
    return 0
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def cmd_runs(args) -> int:
    """List the caller's runs and return 0.

    Styled output delegates to ``render.runs_table``. Plain output prints one fixed-width
    row per run, most recently updated first. Null-valued fields in a run record
    (``spec``, ``remote``, ``gpu``, ``cost_usd``, ``updated_at``) are treated like
    missing ones instead of crashing the listing.
    """
    runs = client_from_config().list_runs()
    if not runs:
        if render.styled():
            print(render.empty("runs", "0 runs", "no runs yet — submit one with `flash train`"))
        else:
            print("no runs yet")
        return 0
    if render.styled():
        print(render.runs_table(runs))
        return 0

    print(f"{'RUN_ID':<32} {'STATE':<11} {'ALGO':<5} {'COST($)':>8} {'GPU':<22} MODEL")
    # `or 0` rather than a .get() default: a default only covers an absent key, and a key
    # present with a null value would make the sort comparison raise TypeError.
    for r in sorted(runs, key=lambda r: r.get("updated_at") or 0, reverse=True):
        spec = r.get("spec") or {}
        spec_gpu = spec.get("gpu") or {}
        model = spec.get("model", "")
        algorithm = str(spec.get("algorithm") or "-").upper()
        remote = r.get("remote") or {}
        # the remote handle knows what actually ran; the spec is the parse-time pick
        provider = remote.get("provider") or (
            "runpod" if remote else spec_gpu.get("provider", "")
        )
        gpu = remote.get("gpu") or spec_gpu.get("type", "")
        where = f"{gpu}@{provider}" if provider else gpu
        # A null cost cannot be formatted with `.4f`; show it as zero instead.
        cost = r.get("cost_usd") or 0.0
        print(
            f"{r['run_id']:<32} {r['state']:<11} {algorithm:<5} "
            f"{cost:>8.4f} {where:<22} {model}"
        )
    return 0
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def cmd_cancel(args) -> int:
    """Cancel a run and print its id with the state from the cancel response; returns 0."""
    run_id = args.run_id
    response = client_from_config().cancel_run(run_id)
    summary = {"run_id": run_id, "state": response["state"]}
    if render.styled():
        text = render.object_panel("cancel", summary)
    else:
        text = json.dumps(summary, indent=2)
    print(text)
    return 0
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def cmd_checkpoints(args) -> int:
    """List a run's deployable checkpoints on stdout, with hints on stderr; returns 0.

    When there are none, only an explanatory message is written to stderr.
    """
    run_id = args.run_id
    found = client_from_config().checkpoints(run_id)
    if found:
        for ckpt in found:
            print(f"step {ckpt['step']:>6} {ckpt['repo_id']}:{ckpt['subfolder']}")
        hint = f"\ndeploy one with `flash deploy {run_id} --step <STEP>`."
    else:
        hint = (
            f"no deployable checkpoints for {run_id} yet "
            "(RL streams one per save interval; SFT-only runs have none)."
        )
    # Hints go to stderr so stdout carries only the checkpoint rows.
    print(hint, file=sys.stderr)
    return 0
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def cmd_deploy(args) -> int:
    """Deploy a run (optionally one checkpoint step), print the result, and return 0.

    ``--step`` is passed through when present on ``args`` and is None otherwise. A billing
    note with the matching undeploy command is written to stderr afterwards.
    """
    run_id = args.run_id
    deployment = client_from_config().deploy(
        run_id,
        dry_run=args.dry_run,
        step=getattr(args, "step", None),
    )
    if render.styled():
        text = render.object_panel("deploy", deployment)
    else:
        text = json.dumps(deployment, indent=2)
    print(text)
    note = (
        "note: serving is billed per token only; use "
        f"`flash undeploy {run_id}` to deregister the adapter."
    )
    print(note, file=sys.stderr)
    return 0
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def cmd_undeploy(args) -> int:
    """Undeploy a run and print the response (themed panel or indented JSON); returns 0."""
    outcome = client_from_config().undeploy(args.run_id)
    if render.styled():
        text = render.object_panel("undeploy", outcome)
    else:
        text = json.dumps(outcome, indent=2)
    print(text)
    return 0
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def cmd_deployments(args) -> int:
    """List active deployments and return 0.

    Styled output delegates to ``render.deployments_table``. Plain output prints one
    fixed-width row per deployment. A missing or null ``gpu`` is shown as ``?`` and a
    missing or null ``endpoint_name`` as an empty string.
    """
    rows = client_from_config().deployments()
    if not rows:
        if render.styled():
            print(render.empty("deployments", "0 active", "no active deployments"))
        else:
            print("no active deployments")
        return 0
    if render.styled():
        print(render.deployments_table(rows))
        return 0

    print(f"{'RUN_ID':<32} {'GPU':<9} ENDPOINT")
    for r in rows:
        d = r.get("deployment") or {}
        # `or` rather than a .get() default: a key present with a null value would make
        # the `:<9` format raise TypeError, or print the literal text "None".
        gpu = d.get("gpu") or "?"
        endpoint = d.get("endpoint_name") or ""
        print(f"{r['run_id']:<32} {gpu:<9} {endpoint}")
    return 0
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
def cmd_chat(args) -> int:
    """Send one user message to a run and print the reply; returns 0.

    When the client has a ``chat_stream`` attribute, chunks are printed as they arrive and
    a final newline is added if anything was printed. Otherwise one ``chat`` call is made
    and the first choice's message content is printed.
    """
    client = client_from_config()
    messages = [{"role": "user", "content": args.message}]
    # Only the speaker label is themed; the reply text itself stays plain so a piped
    # transcript is byte-for-byte the model's words.
    if render.styled():
        print(render.chat_label())

    stream = getattr(client, "chat_stream", None)
    if stream is None:
        reply = client.chat(
            args.run_id,
            messages=messages,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )
        print(reply["choices"][0]["message"]["content"])
        return 0

    emitted = False
    chunks = stream(
        args.run_id,
        messages=messages,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    for piece in chunks:
        print(piece, end="", flush=True)
        emitted = True
    if emitted:
        # Finish the streamed line; a stream that yielded nothing prints nothing.
        print()
    return 0
|