freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,636 @@
1
+ """CLI command handlers for the managed Flash service.
2
+
3
+ Every run-lifecycle command is a thin HTTP call to the Flash control plane —
4
+ users authenticate with their freesolo API key (`flash login` verifies it against
5
+ the freesolo backend), never with provider credentials. Config parsing/validation
6
+ and `--dry-run` stay fully local.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import os
13
+ import sys
14
+ import time
15
+ from pathlib import Path
16
+
17
+ from flash import __version__
18
+ from flash._logging import get_logger
19
+ from flash.catalog import public_model_rows
20
+ from flash.client import (
21
+ ApiClient,
22
+ ClientError,
23
+ client_from_config,
24
+ save_credentials,
25
+ verify_freesolo_key,
26
+ )
27
+ from flash.client.config import load_credentials
28
+ from flash.client.runtime_secrets import runtime_secrets_from_local_env
29
+ from flash.client.specs import spec_payload
30
+ from flash.cost.spec import runconfig_from_spec
31
+ from flash.runner import TERMINAL_STATES, new_run_id
32
+ from flash.schema import ConfigError, spec_from_file
33
+
34
+ from . import render
35
+ from .training_doc import TRAINING_MD
36
+
37
# Logger shared by the CLI command handlers in this module.
logger = get_logger("flash.cli.main")


# Exceptions that represent expected user/config errors: report them as a clean one-line
# message instead of a Python traceback (use --debug to see the full trace).
_USER_ERRORS = (
    ConfigError,
    ClientError,
    FileNotFoundError,
    ValueError,
)

# Run states after which nothing more will happen (polling can stop).
_CLI_DONE_STATES = TERMINAL_STATES | {"deployed"}
# Final states for which `_follow_run` returns exit code 0; any other final state returns 1.
_OK_STATES = {"done", "dry_run", "deployed"}
# Characters cycled through by `_LogFollowSpinner.render`, one per redraw.
_SPINNER_FRAMES = "|/-\\"
# Target number of seconds between spinner redraws in `_sleep_with_spinner`.
_SPINNER_TICK_SECONDS = 0.1
54
+
55
+
56
+ class _LogFollowSpinner:
57
+ def __init__(self, run_id: str):
58
+ self._run_id = run_id
59
+ self._frame = 0
60
+ self._last_len = 0
61
+ self._active = False
62
+ self._enabled = sys.stderr.isatty()
63
+
64
+ @property
65
+ def enabled(self) -> bool:
66
+ return self._enabled
67
+
68
+ def render(self, state: str) -> None:
69
+ if not self._enabled:
70
+ return
71
+ frame = _SPINNER_FRAMES[self._frame % len(_SPINNER_FRAMES)]
72
+ self._frame += 1
73
+ message = f"{frame} following logs for {self._run_id} ({state})"
74
+ padding = " " * max(0, self._last_len - len(message))
75
+ sys.stderr.write(f"\r{message}{padding}")
76
+ sys.stderr.flush()
77
+ self._last_len = len(message)
78
+ self._active = True
79
+
80
+ def clear(self) -> None:
81
+ if not (self._enabled and self._active):
82
+ return
83
+ sys.stderr.write(f"\r{' ' * self._last_len}\r")
84
+ sys.stderr.flush()
85
+ self._active = False
86
+
87
+
88
+ def _sleep_with_spinner(interval: float, spinner: _LogFollowSpinner, state: str) -> None:
89
+ if interval <= 0:
90
+ return
91
+ if not spinner.enabled:
92
+ time.sleep(interval)
93
+ return
94
+ ticks = max(1, int(interval / _SPINNER_TICK_SECONDS))
95
+ sleep_for = interval / ticks
96
+ for _ in range(ticks):
97
+ spinner.render(state)
98
+ time.sleep(sleep_for)
99
+
100
+
101
def cmd_version(args) -> int:
    """Print the flash version (themed when styled output is on) and return 0."""
    line = render.version(__version__) if render.styled() else f"flash {__version__}"
    print(line)
    return 0
107
+
108
+
109
def cmd_login(args) -> int:
    """Verify a freesolo API key, store it, and print the identity it belongs to.

    The key is taken from `--api-key`, falling back to FREESOLO_API_KEY. Returns 0 on
    success. When no key is available or verification raises ClientError, prints a
    login-failed message to stderr and returns 1 (with `--debug` the ClientError is
    re-raised instead).
    """
    # Login goes through the freesolo backend, not the flash control plane: the user supplies
    # the key created at freesolo.co/sign-in and it is verified before being stored. The same
    # key then authenticates against flash's control plane.
    ambient_key = os.environ.get("FREESOLO_API_KEY")
    try:
        api_key = args.api_key or ambient_key
        if not api_key:
            raise ClientError(
                "no API key provided: pass `--api-key <key>` or set FREESOLO_API_KEY. "
                "Create or copy a key at https://freesolo.co/sign-in."
            )
        verify_freesolo_key(api_key, base_url=getattr(args, "freesolo_url", None))
    except ClientError as exc:
        # No key, a rejected key, or an unreachable backend: report it plainly. With `--debug`
        # the error propagates so the top-level handler can show the full traceback.
        if not getattr(args, "debug", False):
            print(render.login_failed(str(exc)), file=sys.stderr)
            return 1
        raise

    api_url = args.api_url or load_credentials()[0]
    # save_credentials clears the stored url when it's the default, so logging into the
    # default plane also drops a stale custom url left by an earlier custom-URL login.
    save_credentials(api_key, api_url=api_url)

    env_overrides_saved = bool(args.api_key and ambient_key and ambient_key != args.api_key)
    if env_overrides_saved:
        print(
            "warning: FREESOLO_API_KEY is set and will override this saved login for future "
            "commands; unset FREESOLO_API_KEY to use the saved key.",
            file=sys.stderr,
        )

    # Show the identity straight away (never the key itself). The lookup is best-effort: the
    # key is already verified and stored, so a control-plane hiccup must not fail the login.
    identity = _identity_or_none(api_key, api_url)
    print(render.login_ok(identity))
    return 0
146
+
147
+
148
# The identity card shown after login is nonessential, so its lookup uses a short timeout: a
# control-plane hiccup must not make a successful login appear to hang.
_IDENTITY_LOOKUP_TIMEOUT_S = 5.0


def _identity_or_none(api_key: str, api_url: str) -> dict | None:
    """Return `ApiClient.me()` for the given key and url, or None if the lookup fails.

    ClientError, OSError and ValueError are swallowed on purpose; this lookup is best-effort.
    """
    # Build the client from the key/url that were just verified and stored rather than from
    # `client_from_config()`: an ambient FREESOLO_API_KEY would otherwise win over the file
    # and render the wrong identity.
    try:
        identity = ApiClient(api_url, api_key, timeout=_IDENTITY_LOOKUP_TIMEOUT_S).me()
    except (ClientError, OSError, ValueError):
        identity = None
    return identity
160
+
161
+
162
def cmd_whoami(args) -> int:
    """Print the identity the configured client resolves to and return 0."""
    identity = client_from_config().me()
    print(render.whoami(identity))
    return 0
165
+
166
+
167
# Contents of the starter `environment.py` that `cmd_env_setup` writes when that file does not
# exist yet. The text below is emitted verbatim, so edits here change the scaffolded file.
_STARTER_ENV_PY = '''\
"""Starter Freesolo environment.

Edit datasets/train.jsonl and the reward code, then upload with
`flash env push --name my-env .`.

A managed run should use the returned [environment] id from
`flash env push --name my-env .`.

This starter keeps a tiny smoke-test dataset in datasets/train.jsonl. Replace it
with your real training rows before a real run.
"""

from __future__ import annotations

import json
from pathlib import Path

from freesolo.datasets.types import TaskExample
from freesolo.environments import EnvironmentSingleTurn, RewardResult


DEFAULT_DATASET_PATH = Path(__file__).parent / "datasets" / "train.jsonl"


def load_jsonl(path: str | Path):
    rows = []
    with Path(path).open() as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def exact_match_reward(example: TaskExample, response_text: str) -> RewardResult:
    expected = str(example.output or "").strip()
    score = 1.0 if expected and expected in response_text else 0.0
    return RewardResult(score=score, threshold=1.0)


class StarterEnv(EnvironmentSingleTurn):
    dataset = load_jsonl(DEFAULT_DATASET_PATH)

    def build_prompt_messages(self, example: TaskExample, prompt_text: str):
        return [{"role": "user", "content": example.input}]

    def score_response(self, example: TaskExample, response_text: str) -> RewardResult:
        return exact_match_reward(example, response_text)


def load_environment(dataset_path: str | None = None, **kwargs) -> StarterEnv:
    env = StarterEnv()
    if dataset_path:
        env.dataset = load_jsonl(dataset_path)
    return env
'''
224
+
225
# Two-row starter dataset that `cmd_env_setup` writes to datasets/train.jsonl when that file
# does not exist yet: one JSON object per line with "input" and "output" keys.
_STARTER_DATASET_JSONL = """\
{"input":"What is 2 + 2?","output":"4"}
{"input":"What is 3 + 5?","output":"8"}
"""
229
+
230
+
231
def cmd_env_setup(args) -> int:
    """Scaffold a starter project in the current directory and return 0.

    Ensures `configs/` and `datasets/` exist, then writes the starter dataset,
    `environment.py`, `configs/rl.toml`, `configs/sft.toml` and `TRAINING.md`. A file that
    already exists is left as it is.
    """
    for folder in ("configs", "datasets"):
        Path(folder).mkdir(exist_ok=True)

    def write_if_missing(relpath: str, text: str, **write_kwargs) -> None:
        # Existing files are kept untouched; only absent ones are created.
        target = Path(relpath)
        if not target.exists():
            target.write_text(text, **write_kwargs)

    env_comment = "".join(
        [
            "# Environment: upload this project folder with\n",
            "# `flash env push --name my-env .`, then paste the returned id below.\n",
            "# If the environment reads secrets with os.environ, list only the env var names here.\n",
            "# Values are read from your shell or .env at submit time and are not stored in the spec.\n",
            "[environment]\n",
            'id = ""\n\n',
            '# secrets = ["SERPAPI_API_KEY"]\n\n',
        ]
    )
    managed_note = (
        "# GPU and the HF artifact repo are managed automatically by the platform: the GPU is\n"
        "# the cheapest fitting class across providers, and each run gets its own artifact repo.\n"
    )

    def config_toml(algorithm: str, budget_line: str) -> str:
        # The RL and SFT starter configs differ only in the algorithm and the budget line.
        return "".join(
            [
                'model = "Qwen/Qwen3.5-4B"\n',
                f'algorithm = "{algorithm}"\n\n',
                env_comment,
                "[train]\n",
                budget_line,
                "lora_rank = 32\n",
                "seeds = [0]\n",
                managed_note,
            ]
        )

    write_if_missing("datasets/train.jsonl", _STARTER_DATASET_JSONL)
    write_if_missing("environment.py", _STARTER_ENV_PY)
    write_if_missing("configs/rl.toml", config_toml("grpo", "steps = 150\n"))
    write_if_missing("configs/sft.toml", config_toml("sft", "epochs = 1\n"))
    # TRAINING.md is the playbook for the AI agent driving these runs: how to design the
    # reward, what to read, and how to decide a run actually improved (not just finished).
    # Explicit UTF-8: TRAINING_MD has non-ASCII (em dashes, ·, √, ≥, ≈), which would raise
    # UnicodeEncodeError under a non-UTF-8 locale with write_text's default.
    write_if_missing("TRAINING.md", TRAINING_MD, encoding="utf-8")

    scaffolded = [
        "environment.py",
        "datasets/train.jsonl",
        "configs/rl.toml",
        "configs/sft.toml",
        "TRAINING.md",
    ]
    if render.styled():
        print(render.env_setup(scaffolded))
    else:
        print(f"ensured {', '.join(scaffolded)}")
    return 0
294
+
295
+
296
def cmd_models(args) -> int:
    """Print the public model catalog: a themed table when styled, else one id per line."""
    catalog = public_model_rows()
    if not render.styled():
        for entry in catalog:
            print(entry["id"])
        return 0
    print(render.models_table(catalog))
    return 0
304
+
305
+
306
def cmd_gpus(args) -> int:
    """List RunPod GPU classes, VRAM, and $/hr."""
    from flash.providers.base import GPU_INFO
    from flash.providers.runpod.pricing import static_rates as runpod_static_rates

    rates = runpod_static_rates()
    # Only classes with an `enum_member` are listed, ordered by ascending `hourly_usd`.
    listed = [info for info in GPU_INFO.values() if info.enum_member]
    listed.sort(key=lambda g: g.hourly_usd)
    tip = (
        "Tip: GPU class selection is fully automatic — the submit-time allocator always picks the\n"
        "cheapest validated RunPod class that fits the model, so you don't pin a GPU type."
    )

    if render.styled():
        table_rows = [(info.name, info.vram_gb, rates.get(info.name)) for info in listed]
        print(render.gpus_table(table_rows, tip))
        return 0

    print(f"{'gpu':<16}{'vram':>6}{'runpod$/hr':>11}")
    for info in listed:
        rate = rates.get(info.name)
        # A missing (or zero) rate renders as a right-aligned dash.
        rate_cell = f"{rate:>10.2f}" if rate else f"{'-':>10}"
        print(f"{info.name:<16}{info.vram_gb:>5}G{rate_cell:>11}")
    print(f"\n{tip}")
    return 0
333
+
334
+
335
def cmd_env_list(args) -> int:
    """List installed environments plus local environment sources under the current directory.

    Local sources are `.` (when `environment.py` exists here) and entries of `environments/`:
    a folder containing `environment.py` or `<folder_name_with_underscores>.py`, or a plain
    `.py` file. Entries whose name starts with `__` are ignored. Always returns 0.
    """
    from flash.envs.registry import list_installed_environments

    installed = list_installed_environments()
    sources: list[str] = ["."] if Path("environment.py").is_file() else []

    env_root = Path("environments")
    if env_root.is_dir():
        # Prefer publishing folders. Single-file modules remain supported for small smoke tests.
        for entry in env_root.iterdir():
            name = entry.name
            if name.startswith("__"):
                continue
            if entry.is_dir():
                candidates = (entry / "environment.py", entry / f"{name.replace('-', '_')}.py")
                is_source = any(candidate.is_file() for candidate in candidates)
            else:
                is_source = entry.suffix == ".py"
            if is_source:
                sources.append(f"environments/{name}")

    # Decide the rendering up front so the themed panel and the legacy lines never both print.
    if render.styled():
        print(render.env_list(list(installed), sorted(sources)))
        return 0

    if installed:
        print("installed environments:")
        for env_id in installed:
            print(f" {env_id}")
    if sources:
        print("local env sources (publish with `flash env push --name <name> <path>`):")
        for source in sorted(sources):
            print(f" {source}")
    return 0
369
+
370
+
371
def _cmd_train_cost(args) -> int:
    """`flash train --cost`: print the pre-flight USD cost for the config and exit (no submit).

    Per the original notes: the estimate is catalog-only and deterministic; an uncapped SFT
    run tries to count the env's train split and falls back to a default example count (with
    a warning) when the environment isn't importable here."""
    from flash.cost import estimate_cost

    spec = spec_from_file(
        args.config,
        run_id=None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )
    run_config = runconfig_from_spec(spec)
    estimate = estimate_cost(run_config)
    report = render.cost_panel(estimate) if render.styled() else estimate.breakdown()
    print(report)
    return 0
391
+
392
+
393
def cmd_train(args) -> int:
    """Validate a training config and, unless asked otherwise, submit it and follow the run.

    `--cost` delegates to `_cmd_train_cost`. `--dry-run` prints the locally validated spec
    and returns 0 without contacting the control plane. `--background` submits, prints the
    returned status, and returns 0. Otherwise the run is submitted and its logs are followed;
    the result of `_follow_run` is returned.
    """
    if getattr(args, "cost", False):
        return _cmd_train_cost(args)

    def show(payload, caption: str) -> None:
        # Themed panel on styled output, indented JSON otherwise.
        if render.styled():
            print(render.object_panel("train", payload, caption))
        else:
            print(json.dumps(payload, indent=2))

    spec = spec_from_file(
        args.config,
        run_id=new_run_id() if args.dry_run else None,
        overrides=args.overrides,
        extra_configs=args.extra_configs,
    )

    if args.dry_run:
        # Fully local: validate the id-based config without credentials, a server, or a GPU.
        show(
            {"run_id": spec.run_id, "state": "dry_run", "spec": spec.to_dict()},
            "dry run — validated locally, not submitted",
        )
        return 0

    client = client_from_config()
    body = spec_payload(spec)
    secrets = runtime_secrets_from_local_env(args.config, keys=spec.environment.secrets)
    status = client.create_run(body, runtime_secrets=secrets)
    run_id = status["run_id"]
    logger.info(
        "submitted run %s: model=%s algorithm=%s gpu=%s seeds=%s",
        run_id,
        spec.model,
        spec.algorithm,
        spec.gpu.type,
        list(spec.train.seeds),
    )

    if args.background:
        show(status, "submitted (running in background)")
        return 0

    if render.styled():
        notice = render.submitted(run_id)
    else:
        notice = (
            f"run {run_id} submitted; following logs "
            f"(Ctrl-C detaches, `flash status {run_id} --follow` resumes)"
        )
    print(notice, file=sys.stderr)
    return _follow_run(client, run_id)
441
+
442
+
443
def _poll_logs(client: ApiClient, run_id: str, interval: float) -> str:
    """Stream offset-paged logs until the run reaches a terminal state; return that state."""
    spinner = _LogFollowSpinner(run_id)
    next_offset = 0
    try:
        while True:
            page = client.get_logs(run_id, offset=next_offset)
            chunk = page["logs"]
            if chunk:
                # Remove the spinner line first so log text never lands on top of it.
                spinner.clear()
                print(chunk, end="", flush=True)
                next_offset = page["offset"]
            state = page["state"]
            if state in _CLI_DONE_STATES:
                break
            _sleep_with_spinner(interval, spinner, state)
        return state
    finally:
        # Runs on normal return and on any exception raised inside the loop.
        spinner.clear()
460
+
461
+
462
def _follow_run(client: ApiClient, run_id: str) -> int:
    """Poll logs until the run reaches a terminal state, then print the final status.

    Returns 0 when the final state is in `_OK_STATES`, otherwise 1.
    """
    final_state = _poll_logs(client, run_id, interval=2.0)
    status = client.get_run(run_id)
    summary = render.run_status(status) if render.styled() else json.dumps(status, indent=2)
    print(summary)
    if final_state in _OK_STATES:
        return 0
    return 1
471
+
472
+
473
def cmd_status(args) -> int:
    """Show a run's status, optionally following it (`--follow`) or dumping logs (`--logs`).

    With `--follow` this returns whatever `_follow_run` returns. Otherwise it prints the
    orchestrator log and worker output sections when `--logs` is set, then the run status,
    and returns 0.
    """
    client = client_from_config()
    if getattr(args, "follow", False):
        return _follow_run(client, args.run_id)

    if getattr(args, "logs", False):
        wrote_section = False
        log_text = client.get_logs(args.run_id).get("logs", "")
        if log_text:
            # Make sure the log ends with exactly the newline it needs.
            print(log_text, end="" if log_text.endswith("\n") else "\n")
            wrote_section = True
        # Always append the real train-subprocess output (the orchestrator log can't carry it);
        # the server fetches console_/error_<phase>.txt from HF with the operator token.
        worker_output = client.get_worker_output(args.run_id) or {}
        for name, text in worker_output.items():
            if text:
                # A blank line separates sections, but never precedes the first thing printed
                # (an empty orchestrator log would otherwise leave a leading blank line).
                header = f"----- {name} -----"
                print(f"\n{header}" if wrote_section else header)
                print(text, end="" if text.endswith("\n") else "\n")
                wrote_section = True

    status = client.get_run(args.run_id)
    summary = render.run_status(status) if render.styled() else json.dumps(status, indent=2)
    print(summary)
    return 0
502
+
503
+
504
def cmd_runs(args) -> int:
    """List the caller's runs, most recently updated first, and return 0.

    Styled output delegates to `render.runs_table`; plain output prints one fixed-width row
    per run. Fields that come back as JSON null (`updated_at`, `cost_usd`, the gpu type or
    provider, the model) are treated like missing fields, so a partially populated run
    record cannot crash the listing.
    """
    runs = client_from_config().list_runs()
    if not runs:
        if render.styled():
            print(render.empty("runs", "0 runs", "no runs yet — submit one with `flash train`"))
        else:
            print("no runs yet")
        return 0
    if render.styled():
        print(render.runs_table(runs))
        return 0
    print(f"{'RUN_ID':<32} {'STATE':<11} {'ALGO':<5} {'COST($)':>8} {'GPU':<22} MODEL")
    # `.get(key) or default` rather than `.get(key, default)`: the default of dict.get only
    # applies to a missing key, and None cannot be compared with numbers or formatted with a
    # width/precision spec.
    for r in sorted(runs, key=lambda r: r.get("updated_at") or 0, reverse=True):
        spec = r.get("spec") or {}
        spec_gpu = spec.get("gpu") or {}
        model = spec.get("model") or ""
        algorithm = str(spec.get("algorithm") or "-").upper()
        remote = r.get("remote") or {}
        # the remote handle knows what actually ran; the spec is the parse-time pick
        provider = remote.get("provider") or (
            "runpod" if remote else (spec_gpu.get("provider") or "")
        )
        gpu = remote.get("gpu") or spec_gpu.get("type") or ""
        where = f"{gpu}@{provider}" if provider else gpu
        cost = r.get("cost_usd") or 0.0
        print(
            f"{r['run_id']:<32} {r['state']:<11} {algorithm:<5} "
            f"{cost:>8.4f} {where:<22} {model}"
        )
    return 0
532
+
533
+
534
def cmd_cancel(args) -> int:
    """Request cancellation of a run and print its run id with the state reported back."""
    client = client_from_config()
    status = client.cancel_run(args.run_id)
    payload = {"run_id": args.run_id, "state": status["state"]}
    if render.styled():
        text = render.object_panel("cancel", payload)
    else:
        text = json.dumps(payload, indent=2)
    print(text)
    return 0
542
+
543
+
544
def cmd_checkpoints(args) -> int:
    """Print a run's deployable checkpoints on stdout, with hints on stderr; return 0."""
    checkpoints = client_from_config().checkpoints(args.run_id)
    if checkpoints:
        for ckpt in checkpoints:
            print(f"step {ckpt['step']:>6} {ckpt['repo_id']}:{ckpt['subfolder']}")
        hint = f"\ndeploy one with `flash deploy {args.run_id} --step <STEP>`."
    else:
        hint = (
            f"no deployable checkpoints for {args.run_id} yet "
            "(RL streams one per save interval; SFT-only runs have none)."
        )
    # Hints go to stderr so stdout carries only the checkpoint rows.
    print(hint, file=sys.stderr)
    return 0
560
+
561
+
562
def cmd_deploy(args) -> int:
    """Deploy a run (optionally a specific `--step`) and print the result; return 0."""
    client = client_from_config()
    dep = client.deploy(
        args.run_id,
        dry_run=args.dry_run,
        step=getattr(args, "step", None),
    )
    text = render.object_panel("deploy", dep) if render.styled() else json.dumps(dep, indent=2)
    print(text)
    # The billing note goes to stderr, keeping stdout limited to the deployment payload.
    note = (
        "note: serving is billed per token only; use "
        f"`flash undeploy {args.run_id}` to deregister the adapter."
    )
    print(note, file=sys.stderr)
    return 0
578
+
579
+
580
def cmd_undeploy(args) -> int:
    """Undeploy a run and print what the control plane returned; return 0."""
    client = client_from_config()
    result = client.undeploy(args.run_id)
    if render.styled():
        text = render.object_panel("undeploy", result)
    else:
        text = json.dumps(result, indent=2)
    print(text)
    return 0
587
+
588
+
589
def cmd_deployments(args) -> int:
    """List active deployments and return 0.

    Styled output delegates to `render.deployments_table`; plain output prints one
    fixed-width row per deployment. A `gpu` or `endpoint_name` that is missing or JSON null
    renders as "?" / an empty string instead of raising or printing "None".
    """
    rows = client_from_config().deployments()
    if not rows:
        if render.styled():
            print(render.empty("deployments", "0 active", "no active deployments"))
        else:
            print("no active deployments")
        return 0
    if render.styled():
        print(render.deployments_table(rows))
        return 0
    print(f"{'RUN_ID':<32} {'GPU':<9} ENDPOINT")
    for r in rows:
        d = r.get("deployment") or {}
        # `.get(key) or default`: the default of dict.get only applies to a missing key, and
        # None cannot be formatted with a width spec such as `:<9`.
        gpu = str(d.get("gpu") or "?")
        endpoint = d.get("endpoint_name") or ""
        print(f"{r['run_id']:<32} {gpu:<9} {endpoint}")
    return 0
605
+
606
+
607
def cmd_chat(args) -> int:
    """Send one user message to a deployed run and print the reply; return 0.

    Uses the client's `chat_stream` when it has one, printing chunks as they arrive;
    otherwise falls back to a single `chat` call and prints the first choice's content.
    """
    client = client_from_config()
    messages = [{"role": "user", "content": args.message}]
    # A faint speaker label on a TTY; the reply text itself stays plain so a piped transcript
    # is byte-for-byte the model's words.
    if render.styled():
        print(render.chat_label())

    stream = getattr(client, "chat_stream", None)
    if stream is None:
        resp = client.chat(
            args.run_id,
            messages=messages,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )
        print(resp["choices"][0]["message"]["content"])
        return 0

    chunks = stream(
        args.run_id,
        messages=messages,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    wrote = False
    for chunk in chunks:
        print(chunk, end="", flush=True)
        wrote = True
    # Finish the line only when something was actually streamed.
    if wrote:
        print()
    return 0