rnow 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnow/__init__.py +5 -0
- rnow/__main__.py +7 -0
- rnow/cli/__init__.py +6 -0
- rnow/cli/auth.py +67 -0
- rnow/cli/blob.py +98 -0
- rnow/cli/commands.py +2311 -0
- rnow/cli/common.py +28 -0
- rnow/cli/cube.py +255 -0
- rnow/cli/main.py +49 -0
- rnow/cli/test.py +728 -0
- rnow/cli/token_count.py +295 -0
- rnow/core/__init__.py +33 -0
- rnow/core/reward.py +333 -0
- rnow/core/tool.py +494 -0
- rnow/models.py +295 -0
- rnow/templates/deepseek-aha/config.yml +26 -0
- rnow/templates/deepseek-aha/rewards.py +36 -0
- rnow/templates/deepseek-aha/train.jsonl +1000 -0
- rnow/templates/mcp-tavily/config.yml +29 -0
- rnow/templates/mcp-tavily/requirements.txt +1 -0
- rnow/templates/mcp-tavily/rewards.py +25 -0
- rnow/templates/mcp-tavily/train.jsonl +500 -0
- rnow/templates/new/config.yml +26 -0
- rnow/templates/new/requirements.txt +1 -0
- rnow/templates/new/rewards.py +0 -0
- rnow/templates/new/train.jsonl +0 -0
- rnow/templates/rl-nextjs/config.yml +27 -0
- rnow/templates/rl-nextjs/requirements.txt +2 -0
- rnow/templates/rl-nextjs/rewards.py +446 -0
- rnow/templates/rl-nextjs/train.jsonl +1000 -0
- rnow/templates/rl-single/config.yml +27 -0
- rnow/templates/rl-single/requirements.txt +1 -0
- rnow/templates/rl-single/rewards.py +14 -0
- rnow/templates/rl-single/train.jsonl +1000 -0
- rnow/templates/rl-tools/config.yml +27 -0
- rnow/templates/rl-tools/env.py +38 -0
- rnow/templates/rl-tools/requirements.txt +3 -0
- rnow/templates/rl-tools/rewards.py +25 -0
- rnow/templates/rl-tools/train.jsonl +500 -0
- rnow/templates/sft/config.yml +20 -0
- rnow/templates/sft/train.jsonl +100 -0
- rnow/templates/tutorial-reward/config.yml +27 -0
- rnow/templates/tutorial-reward/requirements.txt +1 -0
- rnow/templates/tutorial-reward/rewards.py +15 -0
- rnow/templates/tutorial-reward/train.jsonl +1000 -0
- rnow/templates/tutorial-tool/config.yml +27 -0
- rnow/templates/tutorial-tool/env.py +7 -0
- rnow/templates/tutorial-tool/requirements.txt +3 -0
- rnow/templates/tutorial-tool/rewards.py +7 -0
- rnow/templates/tutorial-tool/train.jsonl +1266 -0
- rnow-0.2.4.dist-info/METADATA +135 -0
- rnow-0.2.4.dist-info/RECORD +56 -0
- rnow-0.2.4.dist-info/WHEEL +5 -0
- rnow-0.2.4.dist-info/entry_points.txt +2 -0
- rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
- rnow-0.2.4.dist-info/top_level.txt +1 -0
rnow/cli/commands.py
ADDED
|
@@ -0,0 +1,2311 @@
|
|
|
1
|
+
# rnow/cli/commands.py
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
import json
|
|
5
|
+
import sys
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
import uuid
|
|
9
|
+
import webbrowser
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import click
|
|
13
|
+
import requests
|
|
14
|
+
import yaml
|
|
15
|
+
from pydantic import ValidationError
|
|
16
|
+
|
|
17
|
+
# ReinforceNow teal: #14B8A6
|
|
18
|
+
TEAL_RGB = (20, 184, 166)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Spinner:
    """Animated single-line spinner for CLI feedback.

    A daemon thread repeatedly redraws ``<frame> <message>`` on the current
    line of ``sys.stdout`` until :meth:`stop` is called, then clears the line.
    The spinner can also be used as a context manager::

        with Spinner("Working..."):
            do_work()
    """

    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    # Seconds each frame stays on screen.
    INTERVAL = 0.08

    def __init__(self, message: str = ""):
        self.message = message
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

    def _spin(self):
        """Thread body: draw frames until the stop event is set, then clear the line."""
        for frame in itertools.cycle(self.FRAMES):
            if self._stop_event.is_set():
                break
            sys.stdout.write(f"\r\033[K{frame} {self.message}")
            sys.stdout.flush()
            # Waiting on the event (instead of sleeping) lets stop() interrupt
            # the pause immediately rather than after a full frame interval.
            self._stop_event.wait(self.INTERVAL)
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()

    def start(self):
        """Start the animation; a no-op if it is already running."""
        if self._thread is not None and self._thread.is_alive():
            # A second start() would otherwise orphan the running thread.
            return
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def stop(self):
        """Stop the animation and clear the line. Safe to call repeatedly."""
        self._stop_event.set()
        thread, self._thread = self._thread, None
        if thread is not None:
            thread.join(timeout=0.5)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        return False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
from rnow import models
|
|
53
|
+
from rnow.cli import auth
|
|
54
|
+
from rnow.cli.blob import MAX_INLINE_BYTES, maybe_upload_to_blob
|
|
55
|
+
from rnow.cli.common import get_active_organization, require_auth
|
|
56
|
+
from rnow.cli.cube import CubeSpinner
|
|
57
|
+
|
|
58
|
+
CONFIG_DOCS_URL = "https://reinforcenow.ai/docs/cli-reference/configuration"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def format_validation_error(e: ValidationError) -> str:
    """
    Render a Pydantic ValidationError as a multi-line, user-facing report.

    Each error contributes a "Field" line, an optional "Got" line showing the
    offending input, an "Error" line and, for some error types, a "Hint" line.
    The report ends with a link to the configuration docs.
    """
    report = ["", click.style("✗ Invalid config.yml", fg="red", bold=True), ""]

    for err in e.errors():
        message = err["msg"]
        kind = err["type"]
        field = ".".join(map(str, err["loc"]))

        # Root-level validators carry no location; infer one from the message.
        if kind == "value_error" and not field:
            if "qlora_rank" in message:
                field = "model.qlora_rank"
            elif "batch_size" in message:
                field = "data"

        report.append(f" Field: {click.style(field or '(root)', bold=True)}")

        # Echo the offending input, except for large dicts (whole config sections).
        if "input" in err:
            given = err["input"]
            is_large_dict = isinstance(given, dict) and len(given) > 3
            if not is_large_dict:
                if isinstance(given, str) and len(given) > 50:
                    report.append(f" Got: {repr(given[:50] + '...')}")
                else:
                    report.append(f" Got: {repr(given)}")

        if kind == "extra_forbidden":
            report.append(" Error: Unknown field (typo?)")
        elif kind == "missing":
            report.append(" Error: Required field is missing")
        elif kind == "value_error" and "batch_size * group_size" in message:
            report.append(f" Error: {message}")
            report.append(" Hint: Reduce batch_size or group_size to stay within the 2048 limit")
        elif kind == "value_error" and "qlora_rank" in message:
            # Drop the "Value error, " prefix from the message.
            report.append(f" Error: {message.replace('Value error, ', '')}")
            report.append(
                " Hint: Different models have different max LoRA ranks (32, 64, or 128)"
            )
        else:
            # Every remaining type shows the raw message.
            report.append(f" Error: {message}")
            if kind == "less_than_equal":
                if "batch_size" in field:
                    report.append(" Hint: Maximum batch_size is 32")
                elif "group_size" in field:
                    report.append(" Hint: Maximum group_size is 64")

        report.append("")

    report.append(f" See: {click.style(CONFIG_DOCS_URL, fg=TEAL_RGB, underline=True)}")
    report.append("")

    return "\n".join(report)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def get_rewards_referenced_in_jsonl(path: Path) -> set[str]:
    """
    Extract all reward names referenced in train.jsonl.

    Scans the entire file. Blank lines, lines that are not valid JSON, records
    that are not objects, and 'rewards' values that are not lists are skipped.
    Only string entries are collected, so one malformed entry (e.g. a nested
    object) cannot abort the scan or pollute the result.

    Args:
        path: Path to the JSONL file.

    Returns:
        Set of reward names referenced in the 'rewards' field across all
        samples. If the file cannot be read or decoded, whatever was collected
        before the failure is returned (possibly an empty set).
    """
    rewards: set[str] = set()

    try:
        with open(path, encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue

                try:
                    record = json.loads(stripped)
                except (ValueError, RecursionError):
                    # Malformed line: format validation reports these elsewhere.
                    continue

                if not isinstance(record, dict):
                    continue
                record_rewards = record.get("rewards")
                if isinstance(record_rewards, list):
                    # Filter to strings: unhashable items would raise TypeError
                    # in set.update and previously ended the whole scan silently.
                    rewards.update(r for r in record_rewards if isinstance(r, str))
    except (OSError, UnicodeError):
        # Best effort: an unreadable file yields what was gathered so far.
        pass

    return rewards
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def validate_reward_references(train_jsonl_path: Path, rewards_py_path: Path) -> list[str]:
    """
    Check that every reward name used in train.jsonl is defined in rewards.py.

    Returns:
        One message per undefined reward (sorted by name), followed by a single
        line listing the rewards that are defined, or noting that none were
        found. Empty when everything resolves, or when the reward module cannot
        be imported (validation is skipped in that case).
    """
    try:
        from rnow.core.reward import get_reward_names_from_file
    except ImportError:
        # Reward tooling unavailable: skip the check.
        return []

    defined = get_reward_names_from_file(rewards_py_path)
    referenced = get_rewards_referenced_in_jsonl(train_jsonl_path)

    undefined = referenced - defined
    if not undefined:
        return []

    messages = [
        f"Reward '{name}' is referenced in train.jsonl but not defined in rewards.py"
        for name in sorted(undefined)
    ]

    if defined:
        messages.append(f" Available rewards in rewards.py: {', '.join(sorted(defined))}")
    else:
        messages.append(" No @reward functions found in rewards.py")

    return messages
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# Import token counting utilities from dedicated module
|
|
201
|
+
from rnow.cli.token_count import (
|
|
202
|
+
get_max_prompt_tokens,
|
|
203
|
+
get_tokenizer_for_model,
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def validate_max_tokens_for_context(
    max_tokens: int, max_prompt_tokens: int, context_window: int = models.MAX_CONTEXT_WINDOW
) -> tuple[str | None, int]:
    """
    Check that the generation budget plus the longest prompt fits the context window.

    Returns:
        (error_message, recommended_max_tokens). The recommendation is always
        ``context_window - max_prompt_tokens``. The error is None when
        ``max_tokens + max_prompt_tokens`` does not exceed ``context_window``.
    """
    remaining = context_window - max_prompt_tokens
    needed = max_tokens + max_prompt_tokens

    if needed <= context_window:
        return None, remaining

    problem = f"max_tokens ({max_tokens:,}) + prompt ({max_prompt_tokens:,}) = {needed:,} > context window ({context_window:,})"
    return problem, remaining
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def get_tools_from_env_py(env_path: Path) -> list[dict]:
|
|
225
|
+
"""
|
|
226
|
+
Extract tool definitions from env.py as structured data.
|
|
227
|
+
Returns list of tool dicts with name, description, and schema.
|
|
228
|
+
"""
|
|
229
|
+
import ast
|
|
230
|
+
|
|
231
|
+
if not env_path.exists():
|
|
232
|
+
return []
|
|
233
|
+
|
|
234
|
+
try:
|
|
235
|
+
source = env_path.read_text()
|
|
236
|
+
tree = ast.parse(source)
|
|
237
|
+
except (SyntaxError, OSError):
|
|
238
|
+
return []
|
|
239
|
+
|
|
240
|
+
tools = []
|
|
241
|
+
for node in ast.walk(tree):
|
|
242
|
+
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
|
|
243
|
+
# Check if function has @tool decorator
|
|
244
|
+
is_tool = any(
|
|
245
|
+
(isinstance(d, ast.Name) and d.id == "tool")
|
|
246
|
+
or (
|
|
247
|
+
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "tool"
|
|
248
|
+
)
|
|
249
|
+
for d in node.decorator_list
|
|
250
|
+
)
|
|
251
|
+
if not is_tool:
|
|
252
|
+
continue
|
|
253
|
+
|
|
254
|
+
# Extract tool info
|
|
255
|
+
tool = {
|
|
256
|
+
"name": node.name,
|
|
257
|
+
"description": ast.get_docstring(node) or "",
|
|
258
|
+
"schema": {"type": "object", "properties": {}, "required": []},
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
# Add parameters to schema
|
|
262
|
+
for arg in node.args.args + node.args.kwonlyargs:
|
|
263
|
+
if arg.arg not in ("self", "cls"):
|
|
264
|
+
# Try to get type annotation
|
|
265
|
+
param_type = "string" # Default
|
|
266
|
+
if arg.annotation and isinstance(arg.annotation, ast.Name):
|
|
267
|
+
type_name = arg.annotation.id.lower()
|
|
268
|
+
if type_name in ("int", "integer"):
|
|
269
|
+
param_type = "integer"
|
|
270
|
+
elif type_name in ("float", "number"):
|
|
271
|
+
param_type = "number"
|
|
272
|
+
elif type_name in ("bool", "boolean"):
|
|
273
|
+
param_type = "boolean"
|
|
274
|
+
elif type_name in ("list", "array"):
|
|
275
|
+
param_type = "array"
|
|
276
|
+
elif type_name in ("dict", "object"):
|
|
277
|
+
param_type = "object"
|
|
278
|
+
|
|
279
|
+
tool["schema"]["properties"][arg.arg] = {"type": param_type}
|
|
280
|
+
|
|
281
|
+
# Check if it's a required arg (no default)
|
|
282
|
+
if arg in node.args.args:
|
|
283
|
+
idx = node.args.args.index(arg)
|
|
284
|
+
num_defaults = len(node.args.defaults)
|
|
285
|
+
num_args = len(node.args.args)
|
|
286
|
+
if idx < num_args - num_defaults:
|
|
287
|
+
tool["schema"]["required"].append(arg.arg)
|
|
288
|
+
|
|
289
|
+
tools.append(tool)
|
|
290
|
+
|
|
291
|
+
return tools
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def fetch_mcp_tool_schemas(
|
|
295
|
+
mcp_urls: list[str] | str | None, timeout: float = 15.0
|
|
296
|
+
) -> tuple[list[dict], str | None]:
|
|
297
|
+
"""
|
|
298
|
+
Fetch tool schemas from MCP servers.
|
|
299
|
+
|
|
300
|
+
Args:
|
|
301
|
+
mcp_urls: MCP server URL(s)
|
|
302
|
+
timeout: Connection timeout in seconds
|
|
303
|
+
|
|
304
|
+
Returns:
|
|
305
|
+
Tuple of (list of tool dicts, error_message or None)
|
|
306
|
+
Returns ([], error_message) if fetch fails.
|
|
307
|
+
"""
|
|
308
|
+
if not mcp_urls:
|
|
309
|
+
return [], None
|
|
310
|
+
|
|
311
|
+
urls = mcp_urls if isinstance(mcp_urls, list) else [mcp_urls]
|
|
312
|
+
all_tools = []
|
|
313
|
+
error_msg = None
|
|
314
|
+
|
|
315
|
+
try:
|
|
316
|
+
from fastmcp import Client
|
|
317
|
+
except ImportError:
|
|
318
|
+
return [], "fastmcp not installed"
|
|
319
|
+
|
|
320
|
+
import asyncio
|
|
321
|
+
|
|
322
|
+
async def fetch_tools():
|
|
323
|
+
nonlocal all_tools, error_msg
|
|
324
|
+
|
|
325
|
+
# Build FastMCP config
|
|
326
|
+
fastmcp_config = {"mcpServers": {}}
|
|
327
|
+
for i, url in enumerate(urls):
|
|
328
|
+
server_name = f"mcp_{i}"
|
|
329
|
+
fastmcp_config["mcpServers"][server_name] = {"url": url}
|
|
330
|
+
|
|
331
|
+
try:
|
|
332
|
+
client = Client(fastmcp_config)
|
|
333
|
+
async with client:
|
|
334
|
+
tools = await client.list_tools()
|
|
335
|
+
|
|
336
|
+
for tool in tools:
|
|
337
|
+
name = tool.name
|
|
338
|
+
description = getattr(tool, "description", "") or ""
|
|
339
|
+
|
|
340
|
+
# Get input schema
|
|
341
|
+
input_schema = {}
|
|
342
|
+
if hasattr(tool, "inputSchema"):
|
|
343
|
+
input_schema = tool.inputSchema
|
|
344
|
+
elif hasattr(tool, "input_schema"):
|
|
345
|
+
input_schema = tool.input_schema
|
|
346
|
+
|
|
347
|
+
all_tools.append(
|
|
348
|
+
{
|
|
349
|
+
"name": name,
|
|
350
|
+
"description": description,
|
|
351
|
+
"schema": input_schema,
|
|
352
|
+
}
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
except Exception as e:
|
|
356
|
+
error_msg = str(e)
|
|
357
|
+
return False
|
|
358
|
+
return True
|
|
359
|
+
|
|
360
|
+
# Run async fetch with timeout
|
|
361
|
+
try:
|
|
362
|
+
loop = asyncio.new_event_loop()
|
|
363
|
+
asyncio.set_event_loop(loop)
|
|
364
|
+
try:
|
|
365
|
+
success = loop.run_until_complete(asyncio.wait_for(fetch_tools(), timeout=timeout))
|
|
366
|
+
finally:
|
|
367
|
+
loop.close()
|
|
368
|
+
|
|
369
|
+
if not success:
|
|
370
|
+
return [], error_msg or "connection failed"
|
|
371
|
+
|
|
372
|
+
except asyncio.TimeoutError:
|
|
373
|
+
return [], f"timeout after {timeout}s"
|
|
374
|
+
except Exception as e:
|
|
375
|
+
return [], str(e)
|
|
376
|
+
|
|
377
|
+
return all_tools, None
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def validate_train_jsonl(
    path: Path, dataset_type: models.DatasetType, sample_size: int = 50
) -> list[str]:
    """
    Check the structure of train.jsonl on a sample of its records.

    Blank lines are ignored. Scanning stops once ``sample_size`` records with a
    usable 'messages' list have been examined, or after the fifth accumulated
    error when a line fails to parse as JSON.

    Returns:
        A list of error messages (empty if valid).
    """
    allowed_roles = ("system", "user", "assistant")
    problems = []

    try:
        with open(path, encoding="utf-8") as handle:
            examined = 0
            for line_num, raw in enumerate(handle, start=1):
                text = raw.strip()
                if not text:
                    continue

                try:
                    record = json.loads(text)
                except json.JSONDecodeError as exc:
                    problems.append(f"Line {line_num}: Invalid JSON - {exc.msg}")
                    if len(problems) >= 5:
                        problems.append("... (stopping after 5 errors)")
                        return problems
                    continue

                # Structural guards: a failure here skips the record without
                # counting it toward the sample.
                if not isinstance(record, dict):
                    problems.append(
                        f"Line {line_num}: Expected JSON object, got {type(record).__name__}"
                    )
                    continue

                if "messages" not in record:
                    problems.append(f"Line {line_num}: Missing required 'messages' field")
                    continue

                messages = record["messages"]
                if not isinstance(messages, list):
                    problems.append(f"Line {line_num}: 'messages' must be a list")
                    continue

                if not messages:
                    problems.append(f"Line {line_num}: 'messages' list is empty")
                    continue

                # Report only the first malformed message of each record.
                for position, message in enumerate(messages, start=1):
                    if not isinstance(message, dict):
                        problems.append(f"Line {line_num}: Message {position} must be an object")
                        break
                    if "role" not in message:
                        problems.append(f"Line {line_num}: Message {position} missing 'role'")
                        break
                    if "content" not in message:
                        problems.append(f"Line {line_num}: Message {position} missing 'content'")
                        break
                    role = message["role"]
                    if role not in allowed_roles:
                        problems.append(
                            f"Line {line_num}: Message {position} has invalid role '{role}' (expected: system, user, assistant)"
                        )
                        break

                # RL datasets must name their rewards on every record.
                if dataset_type == models.DatasetType.RL and "rewards" not in record:
                    problems.append(
                        f"Line {line_num}: Missing required 'rewards' field for RL dataset"
                    )

                # 'tools' is optional, but when present must be a list of strings.
                if "tools" in record:
                    tool_names = record["tools"]
                    if not isinstance(tool_names, list):
                        problems.append(f"Line {line_num}: 'tools' must be a list of tool names")
                    elif any(not isinstance(name, str) for name in tool_names):
                        problems.append(f"Line {line_num}: 'tools' must contain only strings")

                examined += 1
                if examined >= sample_size:
                    break

            # No record reached the end of the checks.
            if examined == 0:
                problems.append("File contains no valid JSON lines")

    except Exception as exc:
        problems.append(f"Failed to read file: {exc}")

    return problems
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
from functools import lru_cache
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
@lru_cache(maxsize=256)
def _pypi_requires_python(project: str, version: str | None = None) -> str | None:
    """
    Look up the `Requires-Python` specifier from PyPI's JSON API.

    Queries the release-specific endpoint when ``version`` is given, otherwise
    the project-level one. Results (including None) are memoized per
    (project, version) pair.

    Returns:
        The ``info.requires_python`` value from the response, or None if the
        request or parsing fails for any reason.
    """
    import urllib.request

    try:
        release_path = f"{project}/{version}" if version else project
        url = f"https://pypi.org/pypi/{release_path}/json"

        with urllib.request.urlopen(url, timeout=5) as response:
            payload = json.loads(response.read().decode())

        return payload.get("info", {}).get("requires_python")
    except Exception:
        # Network, HTTP, decoding and shape errors all mean "unknown".
        return None
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def check_pypi_python_compatibility(req, target_python: str = "3.11") -> str | None:
    """
    Report whether a requirement's PyPI metadata rules out the target Python.

    Args:
        req: A packaging.requirements.Requirement object
        target_python: Target Python version string (e.g., "3.11")

    Returns:
        An error string when the package's Requires-Python specifier excludes
        ``target_python``. None when it is compatible, when the metadata is
        unavailable or unparsable, or when ``packaging`` is not installed.
    """
    try:
        from packaging.specifiers import SpecifierSet
        from packaging.version import Version
    except ImportError:
        return None

    wanted = Version(target_python)

    # Honour an exact pin (foo==1.2.3) by querying that release's metadata.
    pinned = next((spec.version for spec in req.specifier if spec.operator == "=="), None)

    declared = _pypi_requires_python(req.name, pinned)
    if not declared:
        # Unknown compatibility is not treated as a failure.
        return None

    try:
        supported = wanted in SpecifierSet(declared)
    except Exception:
        return None

    if supported:
        return None

    label = f"{req.name}=={pinned}" if pinned else f"{req.name}"
    return (
        f"Package '{label}' requires Python "
        f"{declared}, which does not include Python {target_python}"
    )
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def validate_requirements_txt(path: Path, target_python: str = "3.11") -> list[str]:
    """
    Validate requirements.txt for format + Python compatibility.

    Checks:
    1. File is valid requirements.txt format (not TOML/other format)
    2. Each requirement line is parseable (comments, including inline
       ``pkg>=1.0  # note`` comments, and option lines are ignored)
    3. Environment markers that exclude target Python
    4. PyPI Requires-Python metadata for each package (once per project)

    Args:
        path: Path to the requirements file.
        target_python: Python version the requirements must support.

    Returns:
        A list of error/warning messages (empty if valid, or if the
        ``packaging`` library is unavailable for the detailed checks).
    """
    import re

    errors = []

    try:
        content = path.read_text(encoding="utf-8")
    except Exception as e:
        errors.append(f"Failed to read requirements.txt: {e}")
        return errors

    # Check if it's accidentally a TOML file (common mistake)
    stripped = content.strip()
    if stripped.startswith("[") and "]" in stripped.split("\n")[0]:
        errors.append("requirements.txt appears to be in TOML format, not pip requirements format")
        errors.append(
            "Hint: requirements.txt should have one package per line, e.g., 'requests>=2.28.0'"
        )
        return errors

    # Try to parse requirements using packaging library
    try:
        from packaging.markers import default_environment
        from packaging.requirements import InvalidRequirement, Requirement
        from packaging.version import Version
    except ImportError:
        # packaging not available, skip detailed validation
        return []

    try:
        target_version = Version(target_python)
    except Exception:
        errors.append(f"Invalid target Python version: {target_python}")
        return errors

    # Environment for evaluating markers like `python_version < "3.11"`:
    # the local environment with the Python version fields overridden.
    env = default_environment()
    env["python_version"] = f"{target_version.major}.{target_version.minor}"
    env["python_full_version"] = str(target_version)

    # pip's comment rule: '#' at line start or after whitespace starts a
    # comment. A '#' glued to preceding text (URL fragments such as '#egg=')
    # is kept.
    comment_re = re.compile(r"(^|\s+)#.*$")

    seen_projects: set[str] = set()

    for lineno, raw_line in enumerate(content.splitlines(), start=1):
        line = comment_re.sub("", raw_line).strip()

        # Skip blank/comment-only lines and options like -e, --index-url, etc.
        if not line or line.startswith("-"):
            continue

        try:
            req = Requirement(line)
        except InvalidRequirement as e:
            errors.append(f"Line {lineno}: Invalid requirement '{line}' - {e}")
            continue

        # 1) Marker-based incompatibility (e.g., `foo; python_version < "3.11"`)
        if req.marker is not None and not req.marker.evaluate(env):
            errors.append(
                f"Line {lineno}: Requirement '{line}' is excluded for Python "
                f"{target_python} due to marker '{req.marker}'"
            )

        # 2) PyPI Requires-Python compatibility check (once per project)
        project_key = req.name.lower()
        if project_key not in seen_projects:
            seen_projects.add(project_key)
            compat_msg = check_pypi_python_compatibility(req, target_python)
            if compat_msg:
                errors.append(f"Line {lineno}: {compat_msg}")

    return errors
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def get_thinking_mode_display(config: models.ProjectConfig) -> str:
    """Return the human-readable reasoning label for the configured model.

    GPT-OSS models map the thinking mode to a level (Off/Low/High, otherwise
    Medium). Hybrid Qwen3/DeepSeek models are On unless the mode is
    "disabled". Every other model is reported as "Reasoning Off".
    """
    rollout = config.rollout
    thinking_mode = rollout.thinking_mode if rollout else None
    model = config.model.path

    # GPT-OSS: reasoning models with levels.
    leveled_models = ("openai/gpt-oss-120b", "openai/gpt-oss-20b")
    # Hybrid models: Qwen3, DeepSeek.
    hybrid_models = (
        "Qwen/Qwen3-30B-A3B",
        "Qwen/Qwen3-32B",
        "Qwen/Qwen3-8B",
        "Qwen/Qwen3-30B-A3B-Base",
        "Qwen/Qwen3-8B-Base",
        "deepseek-ai/DeepSeek-V3.1",
        "deepseek-ai/DeepSeek-V3.1-Base",
    )

    if model in leveled_models:
        level_labels = {
            "disabled": "Reasoning Off",
            "easy": "Reasoning Low",
            "hard": "Reasoning High",
        }
        return level_labels.get(thinking_mode, "Reasoning Medium")

    if model in hybrid_models and thinking_mode != "disabled":
        return "Reasoning On"

    # Instruct models, base Llama models and anything else get the same label.
    return "Reasoning Off"
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
# Simple session for API calls
|
|
691
|
+
session = requests.Session()
|
|
692
|
+
session.headers["User-Agent"] = "ReinforceNow-CLI/1.0"
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def api_request(
    method: str,
    endpoint: str,
    base_url: str | None = None,
    authenticated: bool = True,
    **kwargs,
):
    """Make an API request through the shared module-level session.

    Args:
        method: Name of the session method to call (e.g. "get", "post").
        endpoint: Path appended verbatim to the base URL.
        base_url: API root; defaults to https://www.reinforcenow.ai/api.
        authenticated: When true, ``require_auth()`` is called first and the
            headers from ``auth.get_auth_headers()`` are merged into the request.
        **kwargs: Passed through to the session method.

    Returns:
        Whatever the session method returns for the request.
    """
    if authenticated:
        require_auth()
        # Work on a copy: updating the caller's dict in place would leave the
        # auth headers in an object the caller may reuse or log. `or {}` also
        # tolerates an explicit headers=None.
        headers = dict(kwargs.pop("headers", None) or {})
        headers.update(auth.get_auth_headers())
        kwargs["headers"] = headers

    url = f"{base_url or 'https://www.reinforcenow.ai/api'}{endpoint}"
    return getattr(session, method)(url, **kwargs)
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
# ========== Auth Commands ==========
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _login_poll_for_token(base_url, device):
    """Poll the device-token endpoint until a token is issued or the code expires.

    Returns the validated ``models.Token``, or None if ``device.expires_in``
    seconds pass without one. Raises click.ClickException on network errors,
    unexpected payloads, or any error other than a pending authorization.
    """
    interval = device.interval
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + device.expires_in

    while time.monotonic() < deadline:
        time.sleep(interval)

        try:
            resp = api_request(
                "post",
                "/auth/device/token",
                base_url,
                json={"device_code": device.device_code},
                authenticated=False,
            )
            data = resp.json()
        except requests.RequestException as e:
            raise click.ClickException(f"Network error: {e}") from e

        if resp.status_code == 200:
            try:
                return models.Token(**data)
            except ValidationError as e:
                raise click.ClickException(f"Invalid token response: {e}") from e

        try:
            error = models.TokenError(**data)
        except ValidationError:
            raise click.ClickException(f"Unexpected response: {data}") from None

        if error.error == "slow_down":
            # RFC 8628 section 3.5: still pending, but the polling interval
            # must grow by 5 seconds for this and all subsequent requests.
            interval += 5
            continue

        if error.error != "authorization_pending":
            raise click.ClickException(f"Authentication failed: {error.error}")

    return None


def _login_save_credentials(token) -> None:
    """Write the API key and organization id to the credentials file (mode 0600)."""
    import os

    auth.DATA_DIR.mkdir(parents=True, exist_ok=True)
    # Create the file owner-only from the start, so the token is never briefly
    # readable under the default umask before the chmod below.
    fd = os.open(auth.CREDS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        json.dump({"api_key": token.access_token, "organization_id": token.organization_id}, f)
    # A pre-existing file keeps its old mode on open; tighten it explicitly.
    auth.CREDS_FILE.chmod(0o600)


@click.command()
@click.option("--force", "-f", is_flag=True, help="Force new login even if already authenticated")
@click.pass_context
def login(ctx, force: bool):
    """Login to ReinforceNow platform.

    Uses OAuth device flow for authentication.
    """
    base_url = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")

    if not force and auth.is_authenticated():
        click.echo(click.style("✓ Already authenticated", fg="green"))
        click.echo("Use --force to re-authenticate")
        return

    # Get device code
    try:
        response = api_request(
            "post", "/auth/device/code", base_url, json={"client_id": "cli"}, authenticated=False
        )
        response.raise_for_status()
        device = models.DeviceCode(**response.json())
    except ValidationError as e:
        raise click.ClickException(f"Invalid response from server: {e}") from e
    except requests.RequestException as e:
        raise click.ClickException(f"Failed to initiate login: {e}") from e

    # Construct the full URL with user_code parameter
    verification_url = f"{device.verification_uri}?user_code={device.user_code}"

    click.echo(f"\n{click.style('Opening browser:', fg=TEAL_RGB)} {verification_url}")
    click.echo(
        f"{click.style('Enter code:', fg=TEAL_RGB)} {click.style(device.user_code, bold=True)}\n"
    )
    webbrowser.open(verification_url)

    # Poll for token with spinner. The spinner is stopped exactly once, before
    # any further output or error message is printed.
    spinner = Spinner("Waiting for authentication...")
    spinner.start()
    try:
        token = _login_poll_for_token(base_url, device)
    finally:
        spinner.stop()

    if token is None:
        raise click.ClickException("Authentication timed out")

    _login_save_credentials(token)
    click.echo(click.style("✓ Login successful!", fg="green", bold=True))
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
@click.command()
def logout():
    """Logout from ReinforceNow."""
    # The docstring above doubles as the command's --help text, so it is kept
    # short. All of the work is delegated to auth.logout(); this command adds
    # no behavior of its own.
    auth.logout()
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
@click.command()
@click.pass_context
def status(ctx):
    """Check authentication status and running jobs."""
    # Flow: report whether credentials are present, print the active
    # organization (if any), then list runs the API reports as running.
    # Raises click.ClickException when the user is not authenticated; a
    # failure to fetch the run list is printed but does not abort the command.
    if not auth.is_authenticated():
        click.echo(click.style("✗ Not authenticated", fg="red"))
        raise click.ClickException("Run 'rnow login' to authenticate")

    click.echo(click.style("✓ Authenticated", fg=TEAL_RGB))
    org_id = get_active_organization()
    if org_id:
        click.echo(f"Organization: {org_id}")

    base_url = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")
    click.echo()
    click.echo(click.style("Running jobs:", bold=True))
    try:
        response = api_request("get", "/runs?status=running", base_url)
        response.raise_for_status()
        data = response.json()
        running_runs = data.get("data", [])
        if running_runs:
            for run in running_runs:
                run_id = run.get("id", "unknown")
                # `"project": null` in the payload would make .get() return
                # None (the default only applies to a missing key), so fall
                # back to an empty mapping explicitly.
                project = run.get("project") or {}
                project_name = project.get("name", "Unknown project")
                phase = run.get("phase", "running")
                click.echo(f" • {click.style(run_id, fg=TEAL_RGB)} - {project_name} ({phase})")
        else:
            click.echo(" No running jobs")
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON response body on requests versions
        # whose JSON decode error is not a RequestException subclass.
        click.echo(click.style(f" Error fetching runs: {e}", fg="red"))
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
# ========== Org Commands ==========
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
def _interactive_org_selector(organizations: list, active_org_id: str | None) -> str | None:
    """Interactive organization selector using arrow keys.

    Draws a menu of ``organizations`` on stdout, puts the terminal in raw
    mode and lets the user move with the arrow keys or ``j``/``k``.

    Args:
        organizations: Objects exposing ``id``, ``name`` and ``role.value``.
        active_org_id: ID of the currently active organization; its row is
            marked with a check mark and pre-selected.

    Returns:
        The ``id`` of the chosen organization, or ``None`` when the user
        cancels (``q``, Ctrl+C, double Escape, end of input), when the list
        is empty, or when raw terminal control is unavailable (``termios``
        missing or raising ``termios.error``).
    """
    import sys

    def show_cursor() -> None:
        sys.stdout.write("\033[?25h")  # Show cursor
        sys.stdout.flush()

    # Nothing to choose from; also avoids a modulo-by-zero on key presses.
    if not organizations:
        return None

    # termios/tty only exist on Unix. The import is handled on its own so the
    # later `except termios.error` never has to evaluate an unbound name.
    try:
        import termios
        import tty
    except ImportError:
        # Fallback for non-Unix systems
        show_cursor()
        return None

    # Start on the active organization when it is in the list, else row 0.
    selected_idx = next(
        (i for i, org in enumerate(organizations) if org.id == active_org_id),
        0,
    )

    def render(current_idx: int) -> list[str]:
        """Build the menu lines, highlighting the row at ``current_idx``."""
        lines = [click.style("Select organization:", bold=True), ""]
        for i, org in enumerate(organizations):
            is_active = org.id == active_org_id
            marker = "✓ " if is_active else " "
            role = click.style(f" ({org.role.value})", dim=True)

            if i == current_idx:
                # Highlight selected row with teal
                prefix = click.style("› ", fg=TEAL_RGB, bold=True)
                name = click.style(f"{marker}{org.name}", fg=TEAL_RGB, bold=True)
            else:
                prefix = " "
                if is_active:
                    name = click.style(f"{marker}{org.name}", fg=TEAL_RGB)
                else:
                    name = f"{marker}{org.name}"
            lines.append(f"{prefix}{name}{role}")
        lines.append("")
        lines.append(click.style("↑/↓ to move, Enter to select, q to cancel", dim=True))
        return lines

    def leave_menu() -> None:
        """Restore cooked mode and move past the menu before returning."""
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
        sys.stdout.write("\033[?25h\n")
        sys.stdout.flush()

    try:
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)

        # Hide cursor and render initial output
        sys.stdout.write("\033[?25l")
        sys.stdout.flush()
        lines = render(selected_idx)
        line_count = len(lines)
        click.echo("\n".join(lines))

        try:
            tty.setraw(fd)
            while True:
                ch = sys.stdin.read(1)
                if ch == "\x1b":  # Escape sequence
                    ch2 = sys.stdin.read(1)
                    if ch2 == "[":
                        ch3 = sys.stdin.read(1)
                        if ch3 == "A":  # Up arrow
                            selected_idx = (selected_idx - 1) % len(organizations)
                        elif ch3 == "B":  # Down arrow
                            selected_idx = (selected_idx + 1) % len(organizations)
                    elif ch2 == "\x1b" or ch2 == "":  # Double escape or end of input
                        leave_menu()
                        return None
                elif ch == "k":  # Vim up
                    selected_idx = (selected_idx - 1) % len(organizations)
                elif ch == "j":  # Vim down
                    selected_idx = (selected_idx + 1) % len(organizations)
                elif ch == "\r" or ch == "\n":  # Enter
                    leave_menu()
                    return organizations[selected_idx].id
                elif ch == "q" or ch == "\x03" or ch == "":
                    # q, Ctrl+C, or end of input. read(1) returning "" means
                    # stdin is exhausted; without this the loop would redraw
                    # forever.
                    leave_menu()
                    return None

                # Move cursor up to beginning of our output and clear.
                # Raw mode is left temporarily so newlines render properly.
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
                sys.stdout.write(f"\033[{line_count}A")  # Move up
                sys.stdout.write("\033[J")  # Clear from cursor to end of screen
                lines = render(selected_idx)
                sys.stdout.write("\n".join(lines) + "\n")
                sys.stdout.flush()
                tty.setraw(fd)
        finally:
            # Runs on every exit path, including exceptions raised mid-loop.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
            show_cursor()
    except termios.error:
        # Terminal attributes could not be read or changed.
        show_cursor()
        return None
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
@click.command()
@click.argument("org_id", required=False)
@click.pass_context
def orgs(ctx, org_id: str | None):
    """Select active organization interactively or by ID.

    Without arguments, shows interactive selector.
    With ORG_ID, sets that organization as active directly.
    """
    require_auth()
    api_root = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")

    # The organization list is needed by both the direct-ID path and the
    # interactive path, so it is fetched up front.
    try:
        reply = api_request("get", "/auth/organizations", api_root)
        reply.raise_for_status()
        listing = models.Organizations(**reply.json())
    except ValidationError as e:
        raise click.ClickException(f"Invalid organization data: {e}")
    except requests.RequestException as e:
        raise click.ClickException(f"Failed to fetch organizations: {e}")

    available = listing.organizations

    # Direct selection: the given ID must belong to one of the user's orgs.
    if org_id:
        chosen = None
        for candidate in available:
            if candidate.id == org_id:
                chosen = candidate
                break

        if not chosen:
            click.echo(click.style(f"✗ Organization not found: {org_id}", fg="red"))
            click.echo()
            click.echo("Available organizations:")
            for candidate in available:
                click.echo(f" • {candidate.id} ({candidate.name})")
            raise click.ClickException("Invalid organization ID")

        auth.set_active_organization(org_id)
        click.echo(click.style(f"✓ Active organization set to: {chosen.name}", fg=TEAL_RGB))
        return

    if not available:
        click.echo(click.style("No organizations found", fg="yellow"))
        return

    # Interactive selection, starting from the locally stored active org.
    current_id = get_active_organization()
    picked_id = _interactive_org_selector(available, current_id)

    # A falsy result means the selector was cancelled or unavailable.
    if not picked_id:
        return

    if picked_id == current_id:
        click.echo()
        click.echo(click.style("Organization unchanged", dim=True))
        return

    auth.set_active_organization(picked_id)
    # Show the organization's name, falling back to its ID if it is not found.
    display_name = picked_id
    for candidate in available:
        if candidate.id == picked_id:
            display_name = candidate.name
            break
    click.echo()
    click.echo(click.style(f"✓ Switched to: {display_name}", fg=TEAL_RGB))
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
# ========== Project Commands ==========
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
@click.command()
@click.option(
    "--template",
    "-t",
    type=click.Choice(
        [
            "start",
            "new",
            "blank",
            "sft",
            "rl-single",
            "rl-nextjs",
            "rl-tools",
            "mcp-tavily",
            "deepseek-aha",
            "tutorial-reward",
            "tutorial-tool",
        ]
    ),
    default="start",
    help="Project template to use",
)
@click.option("--name", "-n", help="Project name (will prompt if not provided)")
def init(template: str, name: str | None):
    """Initialize a new ReinforceNow project."""
    # Steps: prompt for project/dataset names, copy the chosen template's
    # files into the current directory (after confirming any overwrites or
    # removals), then write fresh IDs and names into config.yml.
    require_auth()

    import shutil
    from pathlib import Path

    from prompt_toolkit import prompt as pt_prompt
    from prompt_toolkit.formatted_text import HTML

    def styled_prompt(question: str, default: str) -> str:
        """Next.js style prompt with placeholder that disappears on typing."""
        result = pt_prompt(
            HTML(f"<b>{question}</b> <gray>›</gray> "),
            placeholder=HTML(f"<gray>{default}</gray>"),
        )
        return result.strip() or default

    # Map "start" to "rl-single"
    actual_template = "rl-single" if template == "start" else template

    # Default project names based on template
    template_default_names = {
        "rl-single": "rl-project",
        "rl-tools": "rl-tools-project",
        "rl-nextjs": "nextjs-project",
        "mcp-tavily": "mcp-tavily-project",
        "sft": "sft-project",
        "tutorial-reward": "tutorial-reward",
        "tutorial-tool": "tutorial-tool",
        "deepseek-aha": "deepseek-aha",
        "new": "new-project",
        "blank": "my-project",
    }
    default_project_name = template_default_names.get(actual_template, "my-project")

    # Project name prompt (skipped when --name was given)
    project_name = (
        name if name else styled_prompt("What is your project named?", default_project_name)
    )

    # Dataset name prompt
    dataset_name = styled_prompt("What is your dataset named?", "train")

    # The project is created in the current working directory
    project_dir = Path(".")

    # Copy template files if template is specified (all except blank)
    if actual_template != "blank":
        template_dir = Path(__file__).parent.parent / "templates" / actual_template
        if template_dir.exists():
            # Get list of files to copy from template
            files_to_copy = [f for f in template_dir.iterdir() if f.is_file()]
            template_file_names = {f.name for f in files_to_copy}

            # Define template-managed files (files that templates can provide)
            managed_files = {
                "config.yml",
                "train.jsonl",
                "rewards.py",
                "requirements.txt",
                "env.py",
                "README.md",
            }

            # Find template-managed files that exist but aren't in the new
            # template. Sorted so the confirmation message is stable between
            # runs (set iteration order is not).
            extra_files = [
                fname
                for fname in sorted(managed_files)
                if (project_dir / fname).exists() and fname not in template_file_names
            ]

            # Check if any files will be overwritten
            existing_files = [f.name for f in files_to_copy if (project_dir / f.name).exists()]

            # Show concise warning and confirm
            if extra_files or existing_files:
                all_affected = extra_files + existing_files
                click.echo(
                    click.style("Files to modify:", bold=True)
                    + click.style(f" {', '.join(all_affected)}", dim=True)
                )
                confirm_prompt = (
                    click.style("Continue?", bold=True)
                    + " ("
                    + click.style("yes", dim=True)
                    + "/no)"
                )
                if not click.confirm(
                    confirm_prompt, default=True, show_default=False, prompt_suffix=" "
                ):
                    raise click.Abort()

            # Remove extra template files (silently)
            for fname in extra_files:
                (project_dir / fname).unlink()

            # Copy all template files to current directory (silently)
            for file in files_to_copy:
                dest_file = project_dir / file.name
                shutil.copy2(file, dest_file)
        else:
            click.echo(
                click.style("Template not found:", bold=True)
                + click.style(f" {template}, using blank template", dim=True)
            )

    # Generate new IDs
    project_id = str(uuid.uuid4())
    dataset_id = str(uuid.uuid4())
    org_id = get_active_organization()

    # Update config.yml with actual IDs
    config_path = project_dir / "config.yml"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            # safe_load returns None for an empty document; treat that as an
            # empty mapping instead of failing on item assignment below.
            config_data = yaml.safe_load(f) or {}
        if not isinstance(config_data, dict):
            raise click.ClickException("config.yml must contain a YAML mapping at the top level")

        # Update IDs and name
        config_data["project_id"] = project_id
        config_data["project_name"] = project_name
        config_data["dataset_id"] = dataset_id
        config_data["dataset_name"] = dataset_name
        config_data["organization_id"] = org_id

        # Reorder keys to ensure proper field ordering in output
        key_order = [
            "project_id",
            "project_name",
            "dataset_id",
            "dataset_name",
            "dataset_type",
            "organization_id",
            "data",
            "model",
            "algorithm",
            "rollout",
            "trainer",
        ]
        ordered_config = {k: config_data[k] for k in key_order if k in config_data}
        # Add any remaining keys not in the order list
        for k in config_data:
            if k not in ordered_config:
                ordered_config[k] = config_data[k]

        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(ordered_config, f, default_flow_style=False, sort_keys=False)
    else:
        # Create new config for blank template
        config = models.ProjectConfig(
            project_id=project_id,
            project_name=project_name,
            dataset_id=dataset_id,
            dataset_name=dataset_name,
            dataset_type=models.DatasetType.RL,
            organization_id=org_id,
            data=models.DataConfig(batch_size=2, group_size=16),
            model=models.ModelConfig(path="Qwen/Qwen3-8B"),
            trainer=models.TrainerConfig(num_epochs=30),
        )

        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(
                config.model_dump(mode="json", exclude_none=True),
                f,
                default_flow_style=False,
                sort_keys=False,
            )

    click.echo(click.style(f"\n✓ Created local project: {project_name}", fg=TEAL_RGB))
    click.echo()
    click.echo(click.style("Next steps:", bold=True))
    click.echo(f" 1. Edit {click.style('train.jsonl', underline=True)} with your training data")
    click.echo(
        f" 2. Edit {click.style('rewards.py', underline=True)} and {click.style('env.py', underline=True)} with your reward and tool functions"
    )
    click.echo(f" 3. Run {click.style('rnow run', fg=TEAL_RGB, bold=True)} to start training")
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
def parse_override(override: str) -> tuple[list[str], any]:
|
|
1222
|
+
"""
|
|
1223
|
+
Parse a single override string like 'algorithm.adv_estimator=grpo'.
|
|
1224
|
+
|
|
1225
|
+
Returns:
|
|
1226
|
+
Tuple of (key_path, value) where key_path is a list of nested keys.
|
|
1227
|
+
"""
|
|
1228
|
+
if "=" not in override:
|
|
1229
|
+
raise click.ClickException(
|
|
1230
|
+
f"Invalid override '{override}'. Use format: key=value or nested.key=value"
|
|
1231
|
+
)
|
|
1232
|
+
|
|
1233
|
+
key, value = override.split("=", 1)
|
|
1234
|
+
key_path = key.strip().split(".")
|
|
1235
|
+
|
|
1236
|
+
# Try to parse value as JSON (for numbers, bools, lists)
|
|
1237
|
+
value = value.strip()
|
|
1238
|
+
if value.lower() == "true":
|
|
1239
|
+
return key_path, True
|
|
1240
|
+
elif value.lower() == "false":
|
|
1241
|
+
return key_path, False
|
|
1242
|
+
elif value.lower() == "null" or value.lower() == "none":
|
|
1243
|
+
return key_path, None
|
|
1244
|
+
|
|
1245
|
+
# Try numeric
|
|
1246
|
+
try:
|
|
1247
|
+
if "." in value:
|
|
1248
|
+
return key_path, float(value)
|
|
1249
|
+
else:
|
|
1250
|
+
return key_path, int(value)
|
|
1251
|
+
except ValueError:
|
|
1252
|
+
pass
|
|
1253
|
+
|
|
1254
|
+
# Keep as string
|
|
1255
|
+
return key_path, value
|
|
1256
|
+
|
|
1257
|
+
|
|
1258
|
+
def apply_overrides(config_data: dict, overrides: tuple[str, ...]) -> dict:
    """
    Write each CLI override into the (possibly nested) config dictionary.

    Every override is parsed with parse_override, missing intermediate
    levels are created as empty dicts, and one line describing the change
    is echoed per override. The dictionary is modified in place.

    Args:
        config_data: The loaded config dictionary
        overrides: Override strings such as ('algorithm.adv_estimator=grpo', 'model.path=Qwen/Qwen3-4B')

    Returns:
        The same config_data object, after modification

    Raises:
        click.ClickException: If an intermediate key exists but is not a dict.
    """
    for spec in overrides:
        path, new_value = parse_override(spec)
        dotted = ".".join(path)
        *parents, leaf = path

        # Descend to the dict that will hold the final key.
        node = config_data
        for segment in parents:
            child = node.setdefault(segment, {})
            if not isinstance(child, dict):
                raise click.ClickException(
                    f"Cannot override '{dotted}': '{segment}' is not a nested object"
                )
            node = child

        # Remember the previous value so the change can be reported.
        previous = node.get(leaf, "<not set>")
        node[leaf] = new_value

        click.echo(f" Override: {click.style(dotted, fg=TEAL_RGB)} = {new_value} (was: {previous})")

    return config_data
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
def _submit_single_run(
    ctx,
    dir: Path,
    config_data: dict,
    base_url: str,
    name: str | None,
    debug: bool,
    model_override: str | None,
    epochs: int | None,
    batch_size: int | None,
    lr: float | None,
    overrides: tuple[str, ...],
) -> dict | None:
    """
    Validate a project directory and submit it as a single training run.

    Steps: apply CLI overrides to the config, validate it, check the required
    project files, check the context window when rollout settings exist,
    upload an oversized train.jsonl separately, then POST the files to
    ``{base_url}/training/submit`` and read the streamed reply.

    Args:
        ctx: Click context (not used in this function).
        dir: Directory containing the project files.
        config_data: Parsed config.yml contents; overrides are applied to it.
        base_url: API base URL.
        name: Optional run name, sent as ``run_name``.
        debug: When true, ``debug=true`` is sent with the submission.
        model_override: Not used here; per the original comment the model
            override is already set in ``config_data`` by the caller.
        epochs: Shorthand for the ``trainer.num_epochs`` override.
        batch_size: Shorthand for the ``data.batch_size`` override.
        lr: Shorthand for the ``trainer.learning_rate`` override.
        overrides: Explicit ``key=value`` overrides, applied after the
            shorthand ones so they take precedence.

    Returns:
        Dict with ``run_id`` and ``run_url`` (either may be None when the
        reply stream did not contain a URL).

    Raises:
        click.ClickException: On any validation, upload or submission failure.
    """
    import contextlib
    import tempfile

    # Build combined overrides from shorthand options + explicit overrides
    all_overrides = list(overrides)

    # Add shorthand options as overrides (skip model override - already set in
    # config_data). Compared against None so that an explicit 0 reaches config
    # validation instead of being silently dropped.
    if epochs is not None:
        all_overrides.insert(0, f"trainer.num_epochs={epochs}")
    if batch_size is not None:
        all_overrides.insert(0, f"data.batch_size={batch_size}")
    if lr is not None:
        all_overrides.insert(0, f"trainer.learning_rate={lr}")

    # Apply CLI overrides before validation
    if all_overrides:
        config_data = apply_overrides(config_data, tuple(all_overrides))

    # Now validate the config with overrides applied
    try:
        config = models.ProjectConfig(**config_data)
    except ValidationError as e:
        raise click.ClickException(format_validation_error(e))

    if not config.organization_id:
        config.organization_id = get_active_organization()

    # Validate required files
    required_files = {"train.jsonl": dir / "train.jsonl"}
    if config.dataset_type == models.DatasetType.RL:
        required_files["rewards.py"] = dir / "rewards.py"

    for file_name, path in required_files.items():
        if not path.exists():
            raise click.ClickException(f"Missing required file: {file_name}")
        elif path.stat().st_size == 0:
            raise click.ClickException(f"Empty file: {file_name}")

    # Validate train.jsonl format
    train_jsonl_path = dir / "train.jsonl"
    if train_jsonl_path.exists() and train_jsonl_path.stat().st_size > 0:
        jsonl_errors = validate_train_jsonl(train_jsonl_path, config.dataset_type)
        if jsonl_errors:
            raise click.ClickException(f"Invalid train.jsonl: {jsonl_errors[0]}")

    # Validate rewards.py if present
    if config.dataset_type == models.DatasetType.RL:
        rewards_path = dir / "rewards.py"
        if rewards_path.exists():
            try:
                from rnow.core.reward import validate_rewards_file

                errors = validate_rewards_file(rewards_path)
                if errors:
                    raise click.ClickException(f"Invalid rewards.py: {errors[0]}")
            except ImportError:
                # Best effort: this validation is skipped on ImportError.
                pass

            ref_errors = validate_reward_references(train_jsonl_path, rewards_path)
            if ref_errors:
                raise click.ClickException(f"Reward mismatch: {ref_errors[0]}")

    # Validate context window (prompt + tools + max_tokens)
    if config.rollout:
        model_path = config.model.path if config.model else ""
        model_name = model_path.split("/")[-1] if "/" in model_path else model_path

        click.echo(f" [{model_name}] Validating context window...")

        # Collect all tools
        all_tools = []

        # Get tools from env.py
        env_path = dir / "env.py"
        env_tools = get_tools_from_env_py(env_path)
        all_tools.extend(env_tools)

        # Fetch MCP tools
        mcp_urls = config.rollout.mcp_url
        if mcp_urls:
            mcp_tools, mcp_error = fetch_mcp_tool_schemas(mcp_urls, timeout=15.0)
            if mcp_error:
                raise click.ClickException(
                    f"Failed to fetch MCP tools for {model_name}: {mcp_error}"
                )
            all_tools.extend(mcp_tools)
            click.echo(f" [{model_name}] MCP tools: {len(mcp_tools)} tools")

        # Count tokens with proper format (includes Harmony rendering for gpt-oss)
        total_prompt_tokens = get_max_prompt_tokens(train_jsonl_path, all_tools, model_path)

        click.echo(
            f" [{model_name}] Total: {total_prompt_tokens:,} + {config.rollout.max_tokens:,} = {total_prompt_tokens + config.rollout.max_tokens:,} / {models.MAX_CONTEXT_WINDOW:,}"
        )

        context_error, recommended = validate_max_tokens_for_context(
            config.rollout.max_tokens, total_prompt_tokens
        )
        if context_error:
            raise click.ClickException(
                f"Context window exceeded for {model_name}: "
                f"~{total_prompt_tokens:,} prompt+tools + {config.rollout.max_tokens:,} max_tokens "
                f"> {models.MAX_CONTEXT_WINDOW:,}. Set max_tokens to {recommended:,} or less."
            )

    # Check if train.jsonl needs blob upload
    train_path = dir / "train.jsonl"
    train_size = train_path.stat().st_size
    dataset_url = None

    if train_size > MAX_INLINE_BYTES:
        try:
            _, blob_info = maybe_upload_to_blob(base_url, train_path, config.dataset_id)
            if blob_info:
                dataset_url = blob_info.get("url")
        except Exception as e:
            raise click.ClickException(f"Failed to upload large dataset: {e}")

    run_url = None
    run_id = None
    error_msg = None
    tmp_config_path = None

    try:
        # The ExitStack owns every upload handle, so all of them are closed
        # even when a later open() or the request setup raises. It exits
        # before the `finally` below, i.e. handles close before the temp
        # file is unlinked.
        with contextlib.ExitStack() as open_files:
            # Add config file - a temporary one holding the modified config
            with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as tmp:
                tmp_config_path = Path(tmp.name)
                yaml.dump(config_data, tmp, default_flow_style=False, sort_keys=False)

            def upload_part(upload_name: str, path: Path) -> tuple:
                """Open ``path`` for upload under form field ``upload_name`` with '.' -> '_'."""
                handle = open_files.enter_context(open(path, "rb"))
                return (
                    upload_name.replace(".", "_"),
                    (upload_name, handle, "application/octet-stream"),
                )

            files = [upload_part("config.yml", tmp_config_path)]

            # Add required files (skip train.jsonl if uploaded to blob)
            for file_name, path in required_files.items():
                if file_name == "train.jsonl" and dataset_url:
                    continue
                files.append(upload_part(file_name, path))

            # Add optional files
            for file_name in ("env.py", "requirements.txt"):
                path = dir / file_name
                if path.exists():
                    files.append(upload_part(file_name, path))

            headers = auth.get_auth_headers()
            # Dropped so it does not override the multipart content type of
            # the file upload.
            headers.pop("Content-Type", None)

            submit_data = {
                "project_id": config.project_id,
                "dataset_id": config.dataset_id,
                "organization_id": config.organization_id,
            }
            if name:
                submit_data["run_name"] = name
            if dataset_url:
                submit_data["dataset_url"] = dataset_url
            if debug:
                submit_data["debug"] = "true"

            try:
                response = session.post(
                    f"{base_url}/training/submit",
                    data=submit_data,
                    files=files,
                    headers=headers,
                    stream=True,
                )

                if response.status_code != 200:
                    error_msg = f"Training submission failed: {response.text}"
                else:
                    response.encoding = "utf-8"
                    # Only lines prefixed with "data: " are inspected; each
                    # either carries the run URL or reports an error.
                    for line in response.iter_lines(decode_unicode=True):
                        if line and line.startswith("data: "):
                            msg = line[6:]
                            if "View:" in msg:
                                run_url = msg.split("View:")[-1].strip()
                                if run_url:
                                    run_id = run_url.rstrip("/").split("/")[-1]
                            elif "http" in msg and "View" not in msg:
                                run_url = msg.split()[-1].strip()
                                if run_url:
                                    run_id = run_url.rstrip("/").split("/")[-1]
                            elif msg.startswith("❌") or "Error" in msg or "failed" in msg.lower():
                                error_msg = msg

            except Exception as e:
                error_msg = f"Request failed: {e}"
    finally:
        # Clean up temp config file
        if tmp_config_path is not None:
            tmp_config_path.unlink(missing_ok=True)

    if error_msg:
        raise click.ClickException(error_msg)

    return {"run_id": run_id, "run_url": run_url}
|
|
1518
|
+
|
|
1519
|
+
|
|
1520
|
+
@click.command()
@click.option(
    "--dir",
    "-d",
    default=".",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
    help="Directory containing project files (default: current directory)",
)
@click.option(
    "--name", "-n", default=None, help="Custom name for the training run (default: auto-generated)"
)
@click.option(
    "--debug",
    is_flag=True,
    default=False,
    help="Debug mode: upload files but don't start training job",
)
@click.option(
    "--model",
    "-m",
    default=None,
    help="Override model path (e.g., Qwen/Qwen3-4B)",
)
@click.option(
    "--epochs",
    "-e",
    default=None,
    type=int,
    help="Override number of training epochs",
)
@click.option(
    "--batch-size",
    "-b",
    default=None,
    type=int,
    help="Override batch size (1-32)",
)
@click.option(
    "--lr",
    "--learning-rate",
    default=None,
    type=float,
    help="Override learning rate",
)
@click.argument("overrides", nargs=-1)
@click.pass_context
def run(
    ctx,
    dir: Path,
    name: str,
    debug: bool,
    model: str,
    epochs: int,
    batch_size: int,
    lr: float,
    overrides: tuple[str, ...],
):
    """Submit project for training on ReinforceNow platform.

    You can override any config.yml setting by passing key=value arguments:

    \b
    Examples:
      rnow run model.path=Qwen/Qwen3-4B
      rnow run algorithm.adv_estimator=grpo trainer.learning_rate=0.0002
      rnow run data.batch_size=8 data.group_size=16 trainer.num_epochs=5
      rnow run rollout.max_turns=3 rollout.max_tokens=4096

    \b
    Common overrides:
      model.path               Model to train (e.g., Qwen/Qwen3-8B, Qwen/Qwen3-4B)
      model.qlora_rank         LoRA rank (default: 32)
      data.batch_size          Batch size (1-32)
      data.group_size          Rollouts per prompt for RL (1-64)
      trainer.num_epochs       Number of training epochs
      trainer.learning_rate    Learning rate (default: 0.0001)
      algorithm.adv_estimator  Advantage estimator: grpo, gae, reinforce
      algorithm.loss_fn        Loss function: ppo, importance_sampling
      rollout.max_turns        Max conversation turns for RL
      rollout.max_tokens       Max tokens per generation
      rollout.thinking_mode    Reasoning mode: disabled, easy, medium, hard

    \b
    Multi-model training:
      If model.path is a list in config.yml, a separate run will be submitted
      for each model in the list.

    \b
      Example config.yml:
        model:
          path:
            - Qwen/Qwen3-8B
            - Qwen/Qwen3-4B
            - meta-llama/Llama-3.1-8B-Instruct
    """
    # NOTE: the docstring above is rendered by click as the `--help` text, so
    # maintainer notes live in comments instead.
    #
    # Flow:
    #   1. Load config.yml (preferred) or config.json from `dir`.
    #   2. If model.path lists several models, submit one run per model through
    #      _submit_single_run() and print a summary.
    #   3. Otherwise apply CLI overrides, validate the config and project files
    #      locally, upload everything to `{base_url}/training/submit`, and parse
    #      the streamed "data: ..." lines for the run URL / errors.
    #
    # Raises click.ClickException for every user-facing failure.
    import contextlib
    import io

    require_auth()
    base_url = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")

    # Load and validate config
    config_yml = dir / "config.yml"
    config_json = dir / "config.json"

    # First load raw config data
    config_data = None
    if config_yml.exists():
        try:
            with open(config_yml, encoding="utf-8") as f:
                config_data = yaml.safe_load(f)
        except FileNotFoundError as e:
            raise click.ClickException(f"Config file not found in {dir}") from e
        except yaml.YAMLError as e:
            raise click.ClickException(f"Invalid YAML in config file: {e}") from e
    elif config_json.exists():
        try:
            with open(config_json, encoding="utf-8") as f:
                config_data = json.load(f)
        except json.JSONDecodeError as e:
            raise click.ClickException(f"Invalid JSON in config file: {e}") from e
    else:
        raise click.ClickException(f"No config.yml or config.json found in {dir}")

    # An empty file parses to None and a top-level list/scalar is not a config;
    # fail with a readable message instead of an AttributeError further down.
    if not isinstance(config_data, dict):
        raise click.ClickException(
            "Config file is empty or does not contain a mapping of settings"
        )

    # Check if model.path is a list (multi-model training)
    model_section = config_data.get("model")
    if not isinstance(model_section, dict):
        # Missing or malformed section; schema validation below reports details.
        model_section = {}
    model_paths = model_section.get("path")
    if isinstance(model_paths, list) and len(model_paths) > 1:
        # Validate qlora_rank for each model before starting any runs
        qlora_rank = model_section.get("qlora_rank", 32)
        for model_path in model_paths:
            max_rank = models.get_max_lora_rank(model_path)
            if qlora_rank > max_rank:
                model_name = model_path.rsplit("/", 1)[-1]
                raise click.ClickException(
                    f"qlora_rank {qlora_rank} exceeds maximum {max_rank} for model {model_name}. "
                    f"Set qlora_rank to {max_rank} or lower to train all models."
                )

        # Each entry is (model_path, result dict or None, error text or None).
        results = []
        for model_path in model_paths:
            # Create a copy of config_data with single model path
            single_config = json.loads(json.dumps(config_data))  # Deep copy
            single_config["model"]["path"] = model_path

            # Submit this model's run; one failure must not abort the others.
            try:
                run_result = _submit_single_run(
                    ctx=ctx,
                    dir=dir,
                    config_data=single_config,
                    base_url=base_url,
                    name=name,
                    debug=debug,
                    model_override=model,
                    epochs=epochs,
                    batch_size=batch_size,
                    lr=lr,
                    overrides=overrides,
                )
                results.append((model_path, run_result, None))
            except click.ClickException as e:
                results.append((model_path, None, str(e)))

        # Show run IDs
        click.echo(click.style("Run IDs:", bold=True))
        for model_path, result, error in results:
            model_name = model_path.rsplit("/", 1)[-1]
            if result and result.get("run_id"):
                click.echo(f" {model_name}: {result['run_id']}")
            else:
                # Surface the collected reason so a failure is actionable.
                reason = f" ({error})" if error else ""
                click.echo(f" {model_name}: {click.style('failed', fg='red')}{reason}")

        # Show run URLs
        successful = [r for r in results if r[1] is not None]
        if successful:
            click.echo()
            click.echo(click.style("Run URLs:", bold=True))
            for model_path, result, _ in successful:
                model_name = model_path.rsplit("/", 1)[-1]
                if result and result.get("run_url"):
                    click.echo(f" {model_name}: {click.style(result['run_url'], fg=TEAL_RGB)}")

        return

    # Single model training - continue with normal flow
    # Build combined overrides from shorthand options + explicit overrides.
    # Shorthands are inserted in front so explicit key=value arguments,
    # applied later, take precedence.
    all_overrides = list(overrides)

    # Add shorthand options as overrides. Numeric options are compared against
    # None so an explicit 0 reaches validation instead of being dropped.
    if model:
        all_overrides.insert(0, f"model.path={model}")
    if epochs is not None:
        all_overrides.insert(0, f"trainer.num_epochs={epochs}")
    if batch_size is not None:
        all_overrides.insert(0, f"data.batch_size={batch_size}")
    if lr is not None:
        all_overrides.insert(0, f"trainer.learning_rate={lr}")

    # Apply CLI overrides before validation
    if all_overrides:
        click.echo(click.style("Applying config overrides:", bold=True))
        config_data = apply_overrides(config_data, tuple(all_overrides))
        click.echo()

    # Now validate the config with overrides applied
    try:
        config = models.ProjectConfig(**config_data)
    except ValidationError as e:
        click.echo(format_validation_error(e))
        if all_overrides:
            click.echo(
                click.style("\nHint: One of your overrides may have an invalid value.", fg="yellow")
            )
        raise click.ClickException("Please fix config before submitting") from e

    if not config.organization_id:
        config.organization_id = get_active_organization()

    # Validate required files (all in the same directory now)
    required_files = {
        "train.jsonl": dir / "train.jsonl",
    }

    # Only require rewards.py for RL datasets
    if config.dataset_type == models.DatasetType.RL:
        required_files["rewards.py"] = dir / "rewards.py"

    missing_files = []
    empty_files = []
    for file_name, path in required_files.items():
        if not path.exists():
            missing_files.append(f" • {file_name} at {path}")
        elif path.stat().st_size == 0:
            if file_name == "train.jsonl":
                empty_files.append(
                    f" • {file_name} is empty - please add training examples (one JSON object per line)"
                )
            elif file_name == "rewards.py":
                empty_files.append(
                    f" • {file_name} is empty - please implement your reward function"
                )

    if missing_files:
        click.echo(click.style("✗ Required files missing:", fg="red", bold=True))
        for file_msg in missing_files:
            click.echo(file_msg)
        raise click.ClickException("Missing required files for training submission")

    if empty_files:
        click.echo(click.style("✗ Empty files detected:", fg="red", bold=True))
        for file_msg in empty_files:
            click.echo(file_msg)
        raise click.ClickException("Please add content to empty files before submitting")

    # Validate train.jsonl format (sample first 50 lines)
    train_jsonl_path = dir / "train.jsonl"
    if train_jsonl_path.exists() and train_jsonl_path.stat().st_size > 0:
        jsonl_errors = validate_train_jsonl(train_jsonl_path, config.dataset_type)
        if jsonl_errors:
            click.echo(click.style("✗ Invalid train.jsonl format:", fg="red", bold=True))
            for err in jsonl_errors:
                click.echo(f" • {err}")
            raise click.ClickException("Please fix train.jsonl format before submitting")

    # Validate max_tokens vs prompt size for RL (including tool definitions)
    if config.dataset_type == models.DatasetType.RL and config.rollout:
        # Get model path for accurate tokenization
        model_path = config.model.path if config.model else ""

        # Try to load tokenizer (show message if loading). The "\r" plus
        # trailing padding overwrites the "Loading tokenizer..." status text.
        if model_path:
            click.echo(click.style("Loading tokenizer...", dim=True), nl=False)
            tokenizer_info = get_tokenizer_for_model(model_path)
            if tokenizer_info:
                tokenizer_type = tokenizer_info[0]
                label = "Harmony" if tokenizer_type == "harmony" else "HuggingFace"
                click.echo(
                    "\r"
                    + click.style("Tokenizer: ", fg=TEAL_RGB)
                    + f"{label} ({model_path})"
                    + " " * 10
                )
            else:
                click.echo(
                    "\r"
                    + click.style("Tokenizer: ", fg="yellow")
                    + "not available, using estimates"
                    + " " * 10
                )

        # Collect all tools
        all_tools = []

        # Get tools from env.py
        env_path = dir / "env.py"
        env_tools = get_tools_from_env_py(env_path)
        all_tools.extend(env_tools)

        # Fetch MCP tool schemas (with progress indicator)
        mcp_urls = config.rollout.mcp_url if config.rollout else None
        if mcp_urls:
            # Check if fastmcp is installed
            try:
                import fastmcp  # noqa: F401
            except ImportError as e:
                click.echo()
                click.echo(click.style("✗ MCP support requires fastmcp", fg="red", bold=True))
                click.echo()
                click.echo(" Your config.yml uses mcp_url, but fastmcp is not installed")
                click.echo(" in the same Python environment as rnow.")
                click.echo()
                click.echo(f" rnow is running from: {click.style(sys.executable, dim=True)}")
                click.echo()
                click.echo(" Install it with:")
                click.echo(
                    click.style(f" {sys.executable} -m pip install fastmcp", fg=TEAL_RGB)
                )
                click.echo()
                raise click.ClickException("Missing dependency: fastmcp") from e

            click.echo(click.style("Fetching MCP tools...", dim=True), nl=False)
            mcp_tools, mcp_error = fetch_mcp_tool_schemas(mcp_urls, timeout=15.0)
            if mcp_error:
                click.echo(
                    "\r" + click.style("MCP: ", fg="red") + f"failed ({mcp_error})" + " " * 20
                )
                raise click.ClickException(f"Failed to fetch MCP tools: {mcp_error}")

            all_tools.extend(mcp_tools)
            click.echo(
                "\r" + click.style("MCP: ", fg=TEAL_RGB) + f"{len(mcp_tools)} tools" + " " * 20
            )

        # Count tokens with proper format (includes Harmony rendering for gpt-oss)
        total_prompt_tokens = get_max_prompt_tokens(train_jsonl_path, all_tools, model_path)

        # Show context window usage
        if total_prompt_tokens > 0:
            is_gpt_oss = "gpt-oss" in model_path.lower()
            format_note = " (Harmony format)" if is_gpt_oss else ""
            click.echo(
                click.style("Context: ", fg=TEAL_RGB)
                + f"~{total_prompt_tokens:,} prompt+tools{format_note}"
                + f" + {config.rollout.max_tokens:,} max_tokens"
                + f" = ~{total_prompt_tokens + config.rollout.max_tokens:,}"
                + f" / {models.MAX_CONTEXT_WINDOW:,}"
            )
            context_error, recommended = validate_max_tokens_for_context(
                config.rollout.max_tokens, total_prompt_tokens
            )
            if context_error:
                click.echo()
                click.echo(click.style("✗ Context window exceeded", fg="red", bold=True))
                click.echo()
                click.echo(
                    f" Total prompt context (with tools): ~{total_prompt_tokens:,} tokens."
                )
                if is_gpt_oss:
                    click.echo(
                        " Note: gpt-oss uses Harmony format which includes system overhead."
                    )
                click.echo(
                    f" With max_tokens={config.rollout.max_tokens:,}, the total exceeds"
                )
                click.echo(f" the {models.MAX_CONTEXT_WINDOW:,} token context window.")
                click.echo()
                click.echo(
                    click.style(" Fix:", bold=True)
                    + f" Set rollout.max_tokens to {recommended:,} or less"
                )
                click.echo()
                click.echo(click.style(" In config.yml:", dim=True))
                click.echo(click.style(" rollout:", dim=True))
                click.echo(f" max_tokens: {click.style(str(recommended), fg=TEAL_RGB)}")
                click.echo()
                raise click.ClickException("max_tokens + prompt length exceeds context window")

    # Validate requirements.txt if present (check format and Python 3.11 compatibility)
    requirements_path = dir / "requirements.txt"
    if requirements_path.exists() and requirements_path.stat().st_size > 0:
        req_errors = validate_requirements_txt(requirements_path, target_python="3.11")
        if req_errors:
            click.echo(click.style("✗ Invalid requirements.txt:", fg="red", bold=True))
            for err in req_errors:
                click.echo(f" • {err}")
            raise click.ClickException("Please fix requirements.txt before submitting")

    # Validate rewards.py if present (check signature on @reward functions)
    if config.dataset_type == models.DatasetType.RL:
        rewards_path = dir / "rewards.py"
        if rewards_path.exists():
            try:
                from rnow.core.reward import validate_rewards_file

                errors = validate_rewards_file(rewards_path)
                if errors:
                    click.echo(click.style("✗ Invalid rewards.py:", fg="red", bold=True))
                    for err in errors:
                        click.echo(f" • {err}")
                    raise click.ClickException("Please fix rewards.py before submitting")
            except ImportError:
                pass  # Skip validation if module not available

            # Validate that rewards referenced in train.jsonl exist in rewards.py
            ref_errors = validate_reward_references(train_jsonl_path, rewards_path)
            if ref_errors:
                click.echo(click.style("✗ Reward mismatch:", fg="red", bold=True))
                for err in ref_errors:
                    click.echo(f" • {err}")
                raise click.ClickException(
                    "Please ensure reward names in train.jsonl match functions in rewards.py"
                )

    # Validate env.py if present (check for docstrings on @tool functions)
    env_path = dir / "env.py"
    has_env_py = env_path.exists() and env_path.stat().st_size > 0
    if has_env_py:
        try:
            from rnow.core.tool import validate_tools_file

            errors = validate_tools_file(env_path)
            if errors:
                click.echo(click.style("✗ Invalid env.py:", fg="red", bold=True))
                for err in errors:
                    click.echo(f" • {err}")
                raise click.ClickException("Please fix env.py before submitting")
        except ImportError:
            pass  # Skip validation if module not available

    # Check for MCP URL(s) in config
    has_mcp_url = config.rollout is not None and config.rollout.mcp_url is not None
    mcp_url_count = 0
    if has_mcp_url:
        mcp_url = config.rollout.mcp_url
        mcp_url_count = len(mcp_url) if isinstance(mcp_url, list) else 1

    # Show tool sources message
    server_text = f"{mcp_url_count} server(s)" if mcp_url_count > 1 else "1 server"
    if has_env_py and has_mcp_url:
        click.echo(
            click.style("Tools: ", fg=TEAL_RGB) + f"Using MCP ({server_text}) and env.py tools"
        )
    elif has_mcp_url:
        click.echo(click.style("Tools: ", fg=TEAL_RGB) + f"Using MCP ({server_text})")
    elif has_env_py:
        click.echo(click.style("Tools: ", fg=TEAL_RGB) + "Using env.py tools")

    # Start cube spinner early
    spinner = CubeSpinner()

    # Check if train.jsonl needs blob upload (>4MB)
    train_path = dir / "train.jsonl"
    train_size = train_path.stat().st_size
    dataset_url = None

    if train_size > MAX_INLINE_BYTES:
        spinner.start()
        try:
            _, blob_info = maybe_upload_to_blob(base_url, train_path, config.dataset_id)
            if blob_info:
                dataset_url = blob_info.get("url")
        except Exception as e:
            spinner.stop()
            raise click.ClickException(f"Failed to upload large dataset: {e}") from e

    run_url = None
    run_id = None
    error_msg = None
    resolved_base_model = None
    resolved_finetuned_model = None

    # The ExitStack owns every handle opened for the multipart upload, so they
    # are closed even if building the request fails before it is sent.
    with contextlib.ExitStack() as open_handles:
        files = []

        def attach(field_name, file_name, handle):
            """Register an upload part and tie the handle's lifetime to the stack."""
            open_handles.enter_context(handle)
            files.append((field_name, (file_name, handle, "application/octet-stream")))

        # Add config file. With CLI overrides in effect, upload the overridden
        # config rather than the untouched file on disk; otherwise the server
        # would train with settings that differ from what was validated here.
        if config_yml.exists():
            if all_overrides:
                rendered = yaml.safe_dump(config_data, sort_keys=False).encode("utf-8")
                attach("config_yml", "config.yml", io.BytesIO(rendered))
            else:
                attach("config_yml", "config.yml", open(config_yml, "rb"))
        elif config_json.exists():
            if all_overrides:
                rendered = json.dumps(config_data).encode("utf-8")
                attach("config_json", "config.json", io.BytesIO(rendered))
            else:
                attach("config_json", "config.json", open(config_json, "rb"))

        # Add required files (skip train.jsonl if uploaded to blob)
        for file_name, path in required_files.items():
            if file_name == "train.jsonl" and dataset_url:
                # Skip - already uploaded to blob
                continue
            attach(file_name.replace(".", "_"), file_name, open(path, "rb"))

        # Add optional files (all in the same directory now)
        for file_name in ("env.py", "requirements.txt"):
            path = dir / file_name
            if path.exists():
                attach(file_name.replace(".", "_"), file_name, open(path, "rb"))

        # For multipart, we need to omit Content-Type so requests sets the boundary
        headers = auth.get_auth_headers()
        headers.pop("Content-Type", None)

        submit_data = {
            "project_id": config.project_id,
            "dataset_id": config.dataset_id,
            "organization_id": config.organization_id,
        }
        # Include custom run name if provided
        if name:
            submit_data["run_name"] = name

        # Add dataset URL if uploaded to blob
        if dataset_url:
            submit_data["dataset_url"] = dataset_url

        # Add debug flag if set
        if debug:
            submit_data["debug"] = "true"

        # Start cube spinner if not already running (for small files)
        if not spinner.running:
            spinner.start()

        try:
            response = session.post(
                f"{base_url}/training/submit",
                data=submit_data,
                files=files,
                headers=headers,
                stream=True,
            )

            if response.status_code != 200:
                error_msg = f"Training submission failed: {response.text}"
            else:
                response.encoding = "utf-8"

                # The server streams progress as "data: <message>" lines.
                for line in response.iter_lines(decode_unicode=True):
                    if line and line.startswith("data: "):
                        msg = line[6:]

                        if "View:" in msg:
                            run_url = msg.split("View:")[-1].strip()
                            # Extract run_id from URL (last path segment)
                            if run_url:
                                run_id = run_url.rstrip("/").split("/")[-1]
                        elif "http" in msg and "View" not in msg:
                            run_url = msg.split()[-1].strip()
                            if run_url:
                                run_id = run_url.rstrip("/").split("/")[-1]
                        elif msg.startswith("❌") or "Error" in msg or "failed" in msg.lower():
                            error_msg = msg
                        # Capture resolved model info from server
                        elif "Resuming from finetuned model:" in msg:
                            resolved_finetuned_model = msg.split(
                                "Resuming from finetuned model:"
                            )[-1].strip()
                        elif "Base model:" in msg:
                            resolved_base_model = msg.split("Base model:")[-1].strip()

        except Exception as e:
            # Reported through the common error path below so the spinner is
            # cleared before the message is printed.
            error_msg = f"Request failed: {e}"

    # Show result
    if error_msg:
        spinner.stop()  # Clear cube on error
        click.echo(click.style(f"✗ {error_msg}", fg="red", bold=True))
        raise click.ClickException("Training submission failed")

    # Stop spinner but keep cube visible on success
    spinner.stop(keep_visible=True)

    # Get display values
    model_path = config.model.path if config.model else "Qwen/Qwen3-8B"
    thinking_mode = get_thinking_mode_display(config)

    # Build model display string
    thinking_styled = click.style(thinking_mode, fg=TEAL_RGB)
    if resolved_finetuned_model and resolved_base_model:
        # Resuming from a finetuned model - show both
        model_display = f"{resolved_finetuned_model} ({thinking_styled})"
        base_model_display = f" Base: {resolved_base_model}"
    else:
        # Fresh training from base model
        model_display = f"{model_path} ({thinking_styled})"
        base_model_display = None

    # Output completion messages below the cube
    click.echo(f"Run started successfully {click.style('✅', fg=TEAL_RGB)}")
    click.echo(f" Project: {config.project_name}")
    click.echo(f" Model: {model_display}")
    if base_model_display:
        click.echo(base_model_display)
    if run_id:
        click.echo(f" Run ID: {run_id}")
    if run_url:
        click.echo("\nView your experiment here:")
        click.echo(click.style(run_url, fg=TEAL_RGB))
|
|
2129
|
+
|
|
2130
|
+
|
|
2131
|
+
@click.command()
@click.argument("model_id", required=True)
@click.option(
    "--output",
    "-o",
    type=click.Path(path_type=Path),
    default=None,
    help="Output directory for extracted checkpoint (default: ./<model_name>/)",
)
@click.option(
    "--keep-archive", is_flag=True, default=False, help="Keep the tar archive after extraction"
)
@click.pass_context
def download(ctx, model_id: str, output: Path, keep_archive: bool):
    """Download a trained model by ID.

    Downloads and extracts the model checkpoint (LoRA adapter weights).
    The checkpoint is downloaded as a tar archive and automatically extracted.

    Examples:

    \b
    # Download and extract to ./My_Model/
    rnow download abc123

    \b
    # Download and extract to ./models/
    rnow download abc123 --output ./models/

    \b
    # Keep the tar archive after extraction
    rnow download abc123 --keep-archive
    """
    # NOTE: the docstring above is rendered by click as the `--help` text, so
    # maintainer notes live in comments instead.
    #
    # Flow: ask the API for a download URL, stream the tar archive to a temp
    # file, validate and extract it into `output`, then keep or delete the
    # archive. Raises click.ClickException for every user-facing failure.
    import shutil
    import tarfile
    import tempfile
    import urllib.request

    base_url = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")

    # Get download URL from API
    click.echo(f"Fetching download URL for model: {model_id}...")

    try:
        response = api_request("get", f"/models/{model_id}/download", base_url)

        if response.status_code == 404:
            raise click.ClickException("Model not found or no file available for download")
        elif response.status_code == 403:
            raise click.ClickException(
                "Access denied. You don't have permission to download this model."
            )
        elif response.status_code == 400:
            data = response.json()
            raise click.ClickException(data.get("message", "Cannot download this model"))

        response.raise_for_status()
        data = response.json()

    except requests.RequestException as e:
        raise click.ClickException(f"Failed to get download URL: {e}") from e

    download_url = data.get("downloadUrl")
    # `or` fallbacks also cover fields that are present but null.
    model_name = data.get("modelName") or model_id
    model_size = int(data.get("modelSize") or 0)

    if not download_url:
        raise click.ClickException("No download URL returned from server")

    # Determine output directory
    if output is None:
        # Clean model name for directory
        safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in model_name)
        output = Path(safe_name)

    # Create output directory
    output.mkdir(parents=True, exist_ok=True)

    # Download the file
    click.echo(f"Downloading {model_name}...")

    if model_size > 0:
        size_mb = model_size / (1024 * 1024)
        click.echo(f" Size: {size_mb:.1f} MB")

    # Use a temp file for the archive. Only the path is needed here; the file
    # is reopened for writing below and removed manually on failure.
    with tempfile.NamedTemporaryFile(suffix=".tar", delete=False) as tmp_file:
        archive_path = Path(tmp_file.name)

    extract_root = output.resolve()

    def stays_inside_output(candidate: Path) -> bool:
        """Return True if `candidate` resolves to a location under `output`."""
        resolved = candidate.resolve()
        return resolved == extract_root or extract_root in resolved.parents

    try:
        with (
            urllib.request.urlopen(download_url, timeout=30) as response,
            open(archive_path, "wb") as f,
        ):
            total_size = int(response.headers.get("Content-Length", model_size))

            if total_size > 0:
                # Download with progress bar
                with click.progressbar(
                    length=total_size,
                    label="Downloading",
                    show_pos=True,
                    show_percent=True,
                    fill_char=click.style("#", fg=(20, 184, 166)),
                ) as bar:
                    while chunk := response.read(8192):
                        f.write(chunk)
                        bar.update(len(chunk))
            else:
                # click.progressbar() rejects an unknown length, so stream the
                # body without a bar when the size is not known.
                shutil.copyfileobj(response, f)

        # Extract the archive
        click.echo("Extracting checkpoint...")
        # Prefer the stdlib "data" extraction filter where this Python has it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        try:
            with tarfile.open(archive_path, "r") as tar:
                members = tar.getmembers()

                # The archive comes from the network: refuse entries (or link
                # targets) that would land outside the output directory. All
                # members are checked before anything is written.
                for member in members:
                    member_path = extract_root / member.name
                    unsafe = not stays_inside_output(member_path)
                    if not unsafe and member.issym():
                        unsafe = not stays_inside_output(member_path.parent / member.linkname)
                    elif not unsafe and member.islnk():
                        unsafe = not stays_inside_output(extract_root / member.linkname)
                    if unsafe:
                        raise click.ClickException(
                            f"Refusing to extract unsafe path from archive: {member.name}"
                        )

                with click.progressbar(
                    members,
                    label="Extracting",
                    show_percent=True,
                    show_pos=True,
                    fill_char=click.style("#", fg=(20, 184, 166)),
                ) as bar:
                    for member in bar:
                        tar.extract(member, path=output, **extract_kwargs)
        except tarfile.TarError as e:
            raise click.ClickException(f"Failed to extract archive: {e}") from e

        # Optionally keep the archive
        if keep_archive:
            final_archive = output / f"{output.name}.tar"
            shutil.move(str(archive_path), str(final_archive))
            click.echo(f" Archive saved to: {final_archive}")
        else:
            archive_path.unlink(missing_ok=True)

    except click.ClickException:
        # Already a user-facing message; clean up and pass it through rather
        # than re-wrapping it as "Download failed: ...".
        archive_path.unlink(missing_ok=True)
        raise
    except Exception as e:
        # Covers network errors (URLError and friends) and anything unexpected.
        archive_path.unlink(missing_ok=True)
        raise click.ClickException(f"Download failed: {e}") from e

    click.echo(
        click.style(
            f"\n✓ Model downloaded and extracted to: {output}/", fg=(20, 184, 166), bold=True
        )
    )

    # List extracted files
    files = list(output.iterdir())
    if files:
        click.echo("\nExtracted files:")
        for f in files[:10]:  # Show first 10 files
            click.echo(f" • {f.name}")
        if len(files) > 10:
            click.echo(f" ... and {len(files) - 10} more files")
|
|
2285
|
+
|
|
2286
|
+
|
|
2287
|
+
@click.command()
@click.argument("run_id", required=True)
@click.confirmation_option(prompt="Are you sure you want to stop this training run?")
@click.pass_context
def stop(ctx, run_id: str):
    """Stop an active training run.

    Requires the RUN_ID obtained from 'rnow run' command.
    """
    # NOTE: the docstring above is rendered by click as the `--help` text.
    #
    # Posts {"run_id": ...} to /training/stop and, on success, prints the
    # duration and charge when the response includes them. Request failures
    # are reported as click.ClickException.
    api_root = ctx.obj.get("api_url", "https://www.reinforcenow.ai/api")

    click.echo(f"Stopping training run: {run_id}...")
    try:
        reply = api_request("post", "/training/stop", api_root, json={"run_id": run_id})
        reply.raise_for_status()
        payload = reply.json()
    except requests.RequestException as exc:
        raise click.ClickException(f"Failed to stop training: {exc}")

    click.echo(click.style(f"✓ Training run stopped: {run_id}", fg="green"))

    # Optional summary fields: printed only when present and truthy.
    minutes = payload.get("duration_minutes")
    if minutes:
        click.echo(f" Duration: {minutes:.1f} minutes")

    charge = payload.get("charged_amount")
    if charge:
        click.echo(f" Charged: ${charge:.2f}")
|