rnow 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnow/__init__.py +5 -0
- rnow/__main__.py +7 -0
- rnow/cli/__init__.py +6 -0
- rnow/cli/auth.py +67 -0
- rnow/cli/blob.py +98 -0
- rnow/cli/commands.py +2311 -0
- rnow/cli/common.py +28 -0
- rnow/cli/cube.py +255 -0
- rnow/cli/main.py +49 -0
- rnow/cli/test.py +728 -0
- rnow/cli/token_count.py +295 -0
- rnow/core/__init__.py +33 -0
- rnow/core/reward.py +333 -0
- rnow/core/tool.py +494 -0
- rnow/models.py +295 -0
- rnow/templates/deepseek-aha/config.yml +26 -0
- rnow/templates/deepseek-aha/rewards.py +36 -0
- rnow/templates/deepseek-aha/train.jsonl +1000 -0
- rnow/templates/mcp-tavily/config.yml +29 -0
- rnow/templates/mcp-tavily/requirements.txt +1 -0
- rnow/templates/mcp-tavily/rewards.py +25 -0
- rnow/templates/mcp-tavily/train.jsonl +500 -0
- rnow/templates/new/config.yml +26 -0
- rnow/templates/new/requirements.txt +1 -0
- rnow/templates/new/rewards.py +0 -0
- rnow/templates/new/train.jsonl +0 -0
- rnow/templates/rl-nextjs/config.yml +27 -0
- rnow/templates/rl-nextjs/requirements.txt +2 -0
- rnow/templates/rl-nextjs/rewards.py +446 -0
- rnow/templates/rl-nextjs/train.jsonl +1000 -0
- rnow/templates/rl-single/config.yml +27 -0
- rnow/templates/rl-single/requirements.txt +1 -0
- rnow/templates/rl-single/rewards.py +14 -0
- rnow/templates/rl-single/train.jsonl +1000 -0
- rnow/templates/rl-tools/config.yml +27 -0
- rnow/templates/rl-tools/env.py +38 -0
- rnow/templates/rl-tools/requirements.txt +3 -0
- rnow/templates/rl-tools/rewards.py +25 -0
- rnow/templates/rl-tools/train.jsonl +500 -0
- rnow/templates/sft/config.yml +20 -0
- rnow/templates/sft/train.jsonl +100 -0
- rnow/templates/tutorial-reward/config.yml +27 -0
- rnow/templates/tutorial-reward/requirements.txt +1 -0
- rnow/templates/tutorial-reward/rewards.py +15 -0
- rnow/templates/tutorial-reward/train.jsonl +1000 -0
- rnow/templates/tutorial-tool/config.yml +27 -0
- rnow/templates/tutorial-tool/env.py +7 -0
- rnow/templates/tutorial-tool/requirements.txt +3 -0
- rnow/templates/tutorial-tool/rewards.py +7 -0
- rnow/templates/tutorial-tool/train.jsonl +1266 -0
- rnow-0.2.4.dist-info/METADATA +135 -0
- rnow-0.2.4.dist-info/RECORD +56 -0
- rnow-0.2.4.dist-info/WHEEL +5 -0
- rnow-0.2.4.dist-info/entry_points.txt +2 -0
- rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
- rnow-0.2.4.dist-info/top_level.txt +1 -0
rnow/cli/test.py
ADDED
|
@@ -0,0 +1,728 @@
|
|
|
1
|
+
# rnow/cli/test.py
|
|
2
|
+
"""
|
|
3
|
+
Test command for running RL rollouts locally.
|
|
4
|
+
|
|
5
|
+
Requires authentication for billing.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import itertools
|
|
10
|
+
import json
|
|
11
|
+
import random
|
|
12
|
+
import re
|
|
13
|
+
import signal
|
|
14
|
+
import sys
|
|
15
|
+
import threading
|
|
16
|
+
import time
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from string import Template
|
|
20
|
+
|
|
21
|
+
import click
|
|
22
|
+
import httpx
|
|
23
|
+
import yaml
|
|
24
|
+
|
|
25
|
+
# Global flag for graceful shutdown
|
|
26
|
+
_shutdown_requested = False
|
|
27
|
+
|
|
28
|
+
from rnow.cli.auth import get_auth_headers
|
|
29
|
+
from rnow.cli.commands import get_thinking_mode_display
|
|
30
|
+
|
|
31
|
+
# ReinforceNow teal: #14B8A6 as RGB tuple for click.style()
|
|
32
|
+
TEAL_RGB = (20, 184, 166)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Spinner:
    """Animated one-line terminal spinner whose status text can change while running.

    A daemon thread repeatedly redraws ``<frame> <message>`` on the current
    stdout line (every 80 ms) until :meth:`stop` is called or the module-level
    shutdown flag is set, and finally blanks the line.
    """

    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    def __init__(self, message: str = ""):
        self.message = message
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()

    def update(self, message: str):
        """Replace the status text shown next to the spinner frame."""
        with self._lock:
            self.message = message

    def _spin(self):
        """Redraw loop executed on the background thread."""
        frames = itertools.cycle(self.FRAMES)
        while not (self._stop_event.is_set() or _shutdown_requested):
            current_frame = next(frames)
            with self._lock:
                text = self.message
            # "\r\033[K" returns to column 0 and erases the rest of the line.
            sys.stdout.write(f"\r\033[K{current_frame} {text}")
            sys.stdout.flush()
            time.sleep(0.08)
        # Leave a clean line behind once spinning ends.
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()

    def start(self):
        """Begin animating on a fresh daemon thread."""
        self._stop_event.clear()
        worker = threading.Thread(target=self._spin, daemon=True)
        self._thread = worker
        worker.start()

    def stop(self):
        """Ask the animation thread to finish and wait at most 0.5 s for it."""
        self._stop_event.set()
        worker = self._thread
        if worker:
            worker.join(timeout=0.5)  # Don't wait forever
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
from rnow.cli.common import require_auth
|
|
77
|
+
from rnow.core.reward import REWARD_REGISTRY, clear_reward_registry, compute_total_reward
|
|
78
|
+
from rnow.core.tool import TOOL_REGISTRY, clear_tool_registry
|
|
79
|
+
from rnow.models import ProjectConfig, RewardArgs
|
|
80
|
+
|
|
81
|
+
DEFAULT_API_URL = "https://www.reinforcenow.ai"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class ModelCompleter:
    """
    Completer that handles tokenization and calls Next.js API.
    Requires authentication for billing.

    Chat messages are rendered to prompt tokens locally with a
    ``tinker_cookbook`` renderer, POSTed to ``{api_base}/api/rnow/sample``, and
    the returned tokens are parsed back into a message. A backend-issued
    ``session_id`` is cached and sent with subsequent requests. Instances can be
    used as an async context manager so :meth:`close` is always awaited.
    """

    def __init__(
        self,
        api_base: str,
        model: str,
        max_tokens: int = 2048,
        temperature: float = 1.0,
        timeout: float = 120.0,
    ):
        """
        Args:
            api_base: Backend base URL; trailing slashes are stripped.
            model: Model name used to pick tokenizer/renderer and sent to the API.
            max_tokens: ``max_tokens`` sent with every sampling request.
            temperature: Sampling temperature sent with every request.
            timeout: HTTP timeout in seconds for sampling requests.
        """
        self.api_base = api_base.rstrip("/")
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.auth_headers = get_auth_headers()
        self.session_id: str | None = None  # Cached session ID for reuse
        # Sum of backend-reported latencies and how many responses reported one.
        self.total_latency_ms = 0
        self.request_count = 0

        # Initialize tokenizer and renderer
        from tinker_cookbook import renderers
        from tinker_cookbook.model_info import get_recommended_renderer_name
        from tinker_cookbook.tokenizer_utils import get_tokenizer

        self.tokenizer = get_tokenizer(model)
        renderer_name = get_recommended_renderer_name(model)
        self.renderer = renderers.get_renderer(renderer_name, self.tokenizer)

        # Create the HTTP client last: if tokenizer/renderer setup above fails,
        # no AsyncClient is left open with nobody able to close it.
        self.client = httpx.AsyncClient(timeout=timeout)

    async def __call__(self, messages: list[dict], stop: list[str] | None = None) -> dict:
        """
        Tokenize messages, call Next.js API, decode response.

        Args:
            messages: Chat messages (``role``/``content`` dicts) to continue.
            stop: Stop sequences; defaults to the renderer's stop sequences.

        Returns:
            ``{"content": <parsed reply text>, "latency_ms": <reported latency or 0>}``.

        Raises:
            httpx.HTTPStatusError: the backend answered with an error status.
            RuntimeError: the response body contains an ``error`` key.
        """
        # Build model input using renderer
        model_input = self.renderer.build_generation_prompt(messages)
        tokens = model_input.to_ints()

        # Get stop sequences from renderer if not provided
        if stop is None:
            stop = self.renderer.get_stop_sequences()

        payload = {
            "model": self.model,
            "tokens": tokens,
            "stop": stop,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }
        # Include session_id if we have one cached
        if self.session_id:
            payload["session_id"] = self.session_id

        resp = await self.client.post(
            f"{self.api_base}/api/rnow/sample",
            json=payload,
            headers=self.auth_headers,
        )
        resp.raise_for_status()
        data = resp.json()

        if "error" in data:
            raise RuntimeError(f"API error: {data.get('detail', data.get('error'))}")

        # Cache the session_id for future requests
        if data.get("session_id"):
            self.session_id = data["session_id"]

        # Track latency
        if "latency_ms" in data:
            self.total_latency_ms += data["latency_ms"]
            self.request_count += 1

        # Decode tokens back to text
        output_tokens = data.get("tokens", [])
        parsed_message, _success = self.renderer.parse_response(output_tokens)

        return {
            "content": parsed_message.get("content", ""),
            "latency_ms": data.get("latency_ms", 0),
        }

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

    async def __aenter__(self) -> "ModelCompleter":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
async def flush_pending_charges(api_url: str) -> dict | None:
    """Ask the backend to settle ROLLOUT charges still pending after the test.

    This is best effort: any failure (auth headers, network, HTTP status,
    JSON decoding) is reported as a yellow warning instead of being raised.

    Args:
        api_url: Backend base URL; a trailing slash is tolerated.

    Returns:
        The decoded JSON response on success, otherwise ``None``.
    """
    try:
        request_headers = get_auth_headers()
        async with httpx.AsyncClient(timeout=30.0) as http:
            endpoint = f"{api_url.rstrip('/')}/api/billing/flush-rollout"
            response = await http.post(endpoint, headers=request_headers)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        warning = f"Warning: Failed to flush pending charges: {exc}"
        click.echo(click.style(warning, fg="yellow"))
        return None
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _exec_file(path: Path, module_name: str) -> None:
    """Run the Python file at *path* as a throwaway module named *module_name*.

    Only the side effects of executing the file matter (the caller relies on
    them to populate the reward/tool registries); the module object is not
    kept or registered in ``sys.modules``.

    Raises:
        ImportError: no module spec or loader can be derived from *path*.
    """
    from importlib import util as import_util

    module_spec = import_util.spec_from_file_location(module_name, path)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise ImportError(f"Could not load module from {path}")
    loader.exec_module(import_util.module_from_spec(module_spec))
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _build_tools_block(tool_registry: dict[str, Callable]) -> str:
    """Render the registered tools as a system-prompt section.

    Each tool is described by its registry name, its ``_description``
    attribute (placeholder text when missing) and its ``_schema`` attribute
    (an empty object schema when missing). The specs are embedded as indented
    JSON inside ``<tools>`` tags, followed by instructions for emitting
    ``<tool_call>`` blocks.

    Returns:
        The rendered section, or ``""`` when the registry is empty.
    """
    if not tool_registry:
        return ""

    tool_specs = [
        {
            "name": tool_name,
            "description": getattr(tool_fn, "_description", "No description available."),
            "parameters": getattr(tool_fn, "_schema", {"type": "object", "properties": {}}),
        }
        for tool_name, tool_fn in tool_registry.items()
    ]

    lines = [
        "# Tools",
        "",
        "You may call one or more functions to assist with the user query.",
        "",
        "You are provided with function signatures within <tools></tools> XML tags:",
        "<tools>",
        json.dumps(tool_specs, indent=2),
        "</tools>",
        "",
        "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:",
        "<tool_call>",
        '{"name": "<function-name>", "arguments": {"<arg-name>": "<value>"}}',
        "</tool_call>",
        "",  # keeps the trailing newline after </tool_call>
    ]
    return "\n".join(lines)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Captures the payload of each <tool_call>...</tool_call> span in a model reply.
# Whitespace just inside the tags is excluded from the captured group, the lazy
# ``.*?`` stops at the first closing tag (so several calls in one reply match
# separately), and DOTALL lets a payload span multiple lines.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _format_message(msg: dict, max_len: int | None = 300) -> str:
    """Format a chat message as a colored ``[role] content`` line for display.

    Args:
        msg: Message dict; ``role`` defaults to ``"unknown"`` and a missing or
            ``None`` ``content`` is shown as empty. Non-string content is
            converted with ``str()``.
        max_len: Content longer than this is cut to ``max_len`` characters and
            suffixed with ``"..."``. ``None`` disables truncation.

    Returns:
        The role tag styled via ``click.style`` followed by the content.
    """
    role = msg.get("role", "unknown")
    content = msg.get("content")
    # Tolerate explicit None or structured payloads instead of crashing in len().
    if content is None:
        content = ""
    elif not isinstance(content, str):
        content = str(content)

    # Truncate long content
    if max_len is not None and len(content) > max_len:
        content = content[:max_len] + "..."

    # Color based on role
    colors = {"system": "yellow", "user": "blue", "assistant": "green", "tool": "magenta"}
    color = colors.get(role, "white")
    return click.style(f"[{role}]", fg=color) + f" {content}"
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
async def _maybe_await(value):
    """Return *value*, awaiting it first when it is awaitable (e.g. a coroutine)."""
    import inspect

    if inspect.isawaitable(value):
        return await value
    return value


async def _run_single_rollout(
    completer: ModelCompleter,
    sample: dict,
    reward_registry: dict[str, Callable],
    tool_registry: dict[str, Callable],
    max_turns: int,
    termination_policy: str,
    verbose: bool = False,
) -> dict:
    """Run a single rollout for an RL sample.

    The sample's message templates are filled via ``string.Template`` with its
    ``metadata`` and ``variables`` (``variables`` win on key clashes). When tools
    are registered, the tools block is prepended to the first system message
    (or inserted as a new system message). The model is sampled for up to
    ``max_turns`` turns. Each ``<tool_call>`` in a reply is executed and its
    result or error is appended as a ``tool`` message. With the ``"last_tool"``
    policy the loop ends at the first reply without tool calls. Each reward
    named in the sample then scores the conversation. Tool and reward
    functions may be either sync or async.

    Args:
        completer: Awaitable sampler returning a dict with a ``content`` key.
        sample: Parsed train.jsonl row with ``messages`` and ``rewards`` keys and
            optional ``variables`` / ``metadata`` dicts.
        reward_registry: Reward name -> reward function.
        tool_registry: Tool name -> tool function; empty disables tool execution.
        max_turns: Maximum number of model turns.
        termination_policy: ``"last_tool"`` stops at a reply with no tool call.
        verbose: Echo the conversation while it unfolds.

    Returns:
        Dict with ``total_reward``, per-name ``rewards``, ``turns``,
        ``tools_used`` (tool calls found in replies) and ``conversation``.

    Raises:
        ValueError: a reward named in the sample is not registered.
        asyncio.CancelledError: a shutdown (Ctrl+C) was requested between turns.
    """
    messages_templates = sample["messages"]
    reward_names = sample["rewards"]
    variables = sample.get("variables", {})
    metadata = sample.get("metadata", {})

    # Resolve rewards up front so an unknown name fails before any sampling.
    reward_fns = []
    for name in reward_names:
        if name not in reward_registry:
            raise ValueError(f"Reward function '{name}' not found in registry")
        reward_fns.append(reward_registry[name])

    ctx = {**metadata, **variables}
    messages = [
        {"role": msg["role"], "content": Template(msg["content"]).safe_substitute(ctx)}
        for msg in messages_templates
    ]

    if tool_registry:
        tools_block = _build_tools_block(tool_registry)
        for msg in messages:
            if msg["role"] == "system":
                msg["content"] = tools_block + "\n\n" + msg["content"]
                break
        else:
            messages.insert(0, {"role": "system", "content": tools_block})

    conversation = messages.copy()
    turn_count = 0
    total_tool_calls = 0

    # Show initial messages in verbose mode
    if verbose:
        click.echo(" --- Initial Messages ---")
        for msg in messages:
            click.echo(f" {_format_message(msg)}")
        click.echo(" -------------------------")

    while turn_count < max_turns:
        # Honour Ctrl+C between turns instead of finishing every remaining turn;
        # the caller already treats CancelledError as an aborted rollout.
        if _shutdown_requested:
            raise asyncio.CancelledError("Shutdown requested")
        turn_count += 1

        result = await completer(conversation, stop=None)
        response_content = result.get("content", "")
        conversation.append({"role": "assistant", "content": response_content})

        if verbose:
            click.echo(
                f" [Turn {turn_count}] {_format_message({'role': 'assistant', 'content': response_content}, max_len=500)}"
            )

        tool_matches = TOOL_CALL_RE.findall(response_content)
        tool_call_count = len(tool_matches)
        total_tool_calls += tool_call_count

        # Without registered tools, calls are counted but not executed.
        if tool_registry:
            for raw_call in tool_matches:
                try:
                    tool_data = json.loads(raw_call)
                    tool_name = tool_data.get("name")
                    args = tool_data.get("arguments", {})

                    if tool_name not in tool_registry:
                        tool_response = f"<tool_error>Tool '{tool_name}' not found</tool_error>"
                        conversation.append({"role": "tool", "content": tool_response})
                        if verbose:
                            click.echo(
                                f" {_format_message({'role': 'tool', 'content': tool_response})}"
                            )
                        continue

                    tool_fn = tool_registry[tool_name]
                    # Handles sync tools, async tools, and sync wrappers that
                    # return a coroutine.
                    tool_result = await _maybe_await(tool_fn(**args))

                    tool_response = f"<tool_result>{json.dumps(tool_result)}</tool_result>"
                    conversation.append({"role": "tool", "content": tool_response})

                    if verbose:
                        click.echo(
                            f" Tool {click.style(tool_name, fg=TEAL_RGB)}: {str(tool_result)[:200]}"
                        )

                except json.JSONDecodeError as e:
                    tool_response = f"<tool_error>Invalid JSON: {str(e)}</tool_error>"
                    conversation.append({"role": "tool", "content": tool_response})
                    if verbose:
                        click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")
                except Exception as e:
                    tool_response = f"<tool_error>{str(e)}</tool_error>"
                    conversation.append({"role": "tool", "content": tool_response})
                    if verbose:
                        click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")

        if termination_policy == "last_tool" and tool_call_count == 0:
            break

    # Show final conversation summary in verbose mode
    if verbose:
        click.echo(f" --- Rollout Complete: {turn_count} turns, {total_tool_calls} tool calls ---")

    reward_args = RewardArgs(metadata=metadata, variables=variables)
    rewards = {}
    for fn, name in zip(reward_fns, reward_names, strict=False):
        # Reward functions may be sync or async.
        rewards[name] = await _maybe_await(fn(reward_args, conversation))

    total_reward = compute_total_reward(rewards) if rewards else 0.0

    return {
        "total_reward": total_reward,
        "rewards": rewards,
        "turns": turn_count,
        "tools_used": total_tool_calls,
        "conversation": conversation,
    }
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _check_test_dependencies():
    """Exit with install instructions if the optional ``test`` extra is missing.

    ``tinker_cookbook`` (used by ModelCompleter for tokenizing/rendering) must be
    importable. Otherwise an error and the pip command are printed and
    ``SystemExit(1)`` is raised.
    """
    try:
        import tinker_cookbook  # noqa: F401
    except ImportError:
        pip_cmd = "pip install 'rnow[test]'"
        headline = (
            click.style("Error: ", fg="red", bold=True)
            + "The 'rnow test' command requires additional dependencies."
        )
        click.echo()
        click.echo(headline)
        click.echo(f"Install them with: {click.style(pip_cmd, fg=TEAL_RGB)}")
        click.echo()
        raise SystemExit(1)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _resolve_api_url(api_url: str | None, ctx_obj: dict | None) -> str:
    """Choose the backend base URL for ``rnow test``.

    Precedence: ``--api-url`` / ``RNOW_API_URL``, then ``api_url`` from the click
    context object, then ``DEFAULT_API_URL``. A trailing ``/api`` is stripped
    from the context value because request paths such as ``/api/rnow/sample``
    are appended to the base later.
    """
    if api_url:
        return api_url
    context_url = (ctx_obj or {}).get("api_url") or ""
    # Strip only a trailing "/api"; str.replace would also corrupt hosts such
    # as "https://api.example.com".
    context_url = context_url.rstrip("/").removesuffix("/api")
    return context_url or DEFAULT_API_URL


@click.command(name="test")
@click.option(
    "--dir",
    "-d",
    "project_dir",
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    default=".",
    help="Project directory containing config.yml, rewards.py, env.py, train.jsonl",
)
@click.option(
    "--num-rollouts",
    "-n",
    type=click.IntRange(min=1),
    default=3,
    show_default=True,
    help="Number of rollouts to run",
)
@click.option(
    "--multi-turn/--single-turn",
    default=True,
    show_default=True,
    help="Allow multi-turn rollouts or force single-turn",
)
@click.option(
    "--with-tools/--no-tools",
    default=True,
    show_default=True,
    help="Enable or disable tool use during rollout",
)
@click.option(
    "--model",
    default=None,
    help="Override model name for sampling (otherwise uses config.model.path)",
)
@click.option(
    "--api-url",
    envvar="RNOW_API_URL",
    default=None,
    help="Base URL of the Next.js backend (default: https://www.reinforcenow.ai)",
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    help="Show detailed output for each rollout turn",
)
@click.option(
    "--truncate",
    "-t",
    default=None,
    type=int,
    help="Truncate message content to N characters (default: no truncation)",
)
@click.pass_context
def test(ctx, project_dir, num_rollouts, multi_turn, with_tools, model, api_url, verbose, truncate):
    """Test RL rollouts locally before submitting.

    This command runs local RL rollouts by calling the Next.js API
    for model sampling.

    Only works with RL projects (dataset_type: rl).
    """
    global _shutdown_requested
    _shutdown_requested = False

    # Run the checks that may exit early BEFORE replacing the SIGINT handler,
    # so an early exit cannot leave the custom handler installed.
    require_auth()
    _check_test_dependencies()

    def handle_sigint(signum, frame):
        # First Ctrl+C requests a graceful stop; a second one exits immediately.
        global _shutdown_requested
        if _shutdown_requested:
            # Second Ctrl+C, force exit
            sys.exit(1)
        _shutdown_requested = True
        click.echo("\n" + click.style("Interrupted. Shutting down gracefully...", fg="yellow"))

    original_handler = signal.signal(signal.SIGINT, handle_sigint)
    try:
        asyncio.run(
            _test_async(
                project_dir=project_dir,
                num_rollouts=num_rollouts,
                multi_turn=multi_turn,
                with_tools=with_tools,
                model_override=model,
                api_url=_resolve_api_url(api_url, ctx.obj),
                verbose=verbose,
                truncate=truncate,
            )
        )
    except KeyboardInterrupt:
        click.echo(click.style("Aborted.", fg="yellow"))
    finally:
        # Restore original signal handler
        signal.signal(signal.SIGINT, original_handler)
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _load_project_config(project_dir: Path) -> ProjectConfig:
    """Load and validate config.yml (preferred) or config.json from *project_dir*.

    Raises:
        click.ClickException: no config file exists or it is not a mapping.
    """
    config_path = project_dir / "config.yml"
    if not config_path.exists():
        config_path = project_dir / "config.json"

    if not config_path.exists():
        raise click.ClickException("No config.yml or config.json found in project directory")

    if config_path.suffix == ".yml":
        config_data = yaml.safe_load(config_path.read_text())
    else:
        config_data = json.loads(config_path.read_text())

    # An empty YAML file loads as None; report it instead of failing with a
    # TypeError on ProjectConfig(**None).
    if not isinstance(config_data, dict):
        raise click.ClickException(f"{config_path.name} must contain a mapping of config keys")

    return ProjectConfig(**config_data)


def _load_samples(train_path: Path) -> list[dict]:
    """Parse train.jsonl into a list of samples, skipping blank lines.

    Raises:
        click.ClickException: a line is not valid JSON (its line number is reported).
    """
    samples = []
    for line_no, line in enumerate(train_path.read_text().splitlines(), start=1):
        if not line.strip():
            continue
        try:
            samples.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise click.ClickException(
                f"train.jsonl line {line_no} is not valid JSON: {e.msg}"
            ) from e
    return samples


def _echo_rollout_result(
    idx: int, num_rollouts: int, result: dict | BaseException, truncate: int | None
) -> float | None:
    """Print one rollout's outcome and return its total reward (None if it failed)."""
    click.echo(f"Rollout {idx+1}/{num_rollouts}")

    # BaseException (not Exception) so CancelledError results are reported as
    # failures instead of being indexed like a result dict.
    if isinstance(result, BaseException):
        if isinstance(result, httpx.HTTPStatusError):
            click.echo(click.style(f" ✗ HTTP Error: {result.response.status_code}", fg="red"))
        else:
            click.echo(click.style(f" ✗ {result}", fg="red"))
        click.echo()
        return None

    total_reward = result["total_reward"]

    # Show all messages with red tags
    for msg in result["conversation"]:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        # Truncate if flag is set
        if truncate and len(content) > truncate:
            content = content[:truncate] + "..."
        tag = click.style(f"[{role}]", fg="red")
        click.echo(f" {tag} {content}")

    reward_str = ", ".join(f"{k}={v:.3f}" for k, v in result["rewards"].items())
    click.echo(
        f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
        f"| turns={result['turns']} "
        f"| tools_used={result['tools_used']} "
        f"| [{reward_str}]"
    )
    click.echo()
    return total_reward


async def _test_async(
    project_dir: Path,
    num_rollouts: int,
    multi_turn: bool,
    with_tools: bool,
    model_override: str | None,
    api_url: str,
    verbose: bool,
    truncate: int | None,
):
    """Run ``num_rollouts`` concurrent local rollouts for the project and report them.

    Loads config, rewards.py, optional env.py and train.jsonl from *project_dir*,
    picks one random sample per rollout and runs all rollouts concurrently
    against the backend at *api_url*. Each conversation and its rewards are
    printed, then pending billing charges are flushed and a summary is shown.
    Exits early (without flushing) if a Ctrl+C shutdown was requested.

    NOTE(review): *verbose* is currently unused; rollouts run with
    ``verbose=False``, presumably to avoid interleaving output from concurrent
    rollouts with the spinner. Confirm whether ``-v`` should have an effect.

    Raises:
        click.ClickException: invalid/missing project files or context overflow.
    """
    project_dir = Path(project_dir)
    config = _load_project_config(project_dir)

    if config.dataset_type.value != "rl":
        raise click.ClickException(
            f"rnow test only supports RL projects (dataset_type: rl). "
            f"Found: {config.dataset_type.value}"
        )

    rewards_path = project_dir / "rewards.py"
    env_path = project_dir / "env.py"
    train_path = project_dir / "train.jsonl"

    if not rewards_path.exists():
        raise click.ClickException("rewards.py not found in project directory")
    if not train_path.exists():
        raise click.ClickException("train.jsonl not found in project directory")

    # Validate max_tokens vs prompt size
    from rnow.cli.commands import get_max_prompt_tokens, validate_max_tokens_for_context
    from rnow.models import MAX_CONTEXT_WINDOW

    if config.rollout:
        max_prompt_tokens = get_max_prompt_tokens(train_path)
        if max_prompt_tokens > 0:
            context_error, recommended = validate_max_tokens_for_context(
                config.rollout.max_tokens, max_prompt_tokens
            )
            if context_error:
                click.echo()
                click.echo(click.style("✗ Context window exceeded", fg="red", bold=True))
                click.echo()
                click.echo(
                    f" Your longest prompt in train.jsonl is ~{max_prompt_tokens:,} tokens."
                )
                click.echo(f" With max_tokens={config.rollout.max_tokens:,}, the total exceeds")
                click.echo(f" the {MAX_CONTEXT_WINDOW:,} token context window.")
                click.echo()
                click.echo(
                    click.style(" Fix:", bold=True)
                    + f" Set rollout.max_tokens to {recommended:,} or less"
                )
                click.echo()
                raise click.ClickException("max_tokens + prompt length exceeds context window")

    # Fresh registries, then let the project files register rewards/tools.
    clear_reward_registry()
    clear_tool_registry()

    _exec_file(rewards_path, "rewards")

    if with_tools and env_path.exists():
        _exec_file(env_path, "env")

    samples = _load_samples(train_path)

    if not samples:
        raise click.ClickException("train.jsonl is empty")

    model_name = model_override or config.model.path
    max_tokens = config.rollout.max_tokens if config.rollout else 2048
    max_turns_config = config.rollout.max_turns if config.rollout else 1
    termination_policy = config.rollout.termination_policy if config.rollout else "last_tool"

    max_turns = 1 if not multi_turn else max_turns_config

    # Check for gpt-oss with tools - not supported in rnow test
    is_gpt_oss = "gpt-oss" in model_name.lower() or "gptoss" in model_name.lower()
    has_tools = with_tools and (env_path.exists() or (config.rollout and config.rollout.mcp_url))

    if is_gpt_oss and has_tools:
        click.echo(
            click.style("Warning: ", fg="yellow")
            + "Tool calling with gpt-oss models is not supported in 'rnow test'. Running without tools."
        )
        with_tools = False

    rewards = []
    tool_registry_to_use = TOOL_REGISTRY if with_tools else {}

    # Display model info with reasoning mode (same format as rnow run)
    thinking_display = get_thinking_mode_display(config)
    click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
    click.echo()

    completers: list[ModelCompleter] = []
    spinner: Spinner | None = None
    try:
        # Create one completer per concurrent rollout to avoid session conflicts
        for _ in range(num_rollouts):
            completers.append(
                ModelCompleter(
                    api_base=api_url,
                    model=model_name,
                    max_tokens=max_tokens,
                )
            )

        # Select samples for each rollout upfront
        selected_samples = [random.choice(samples) for _ in range(num_rollouts)]

        spinner = Spinner(f"Running {num_rollouts} rollouts concurrently...")
        spinner.start()

        async def run_rollout_with_index(idx: int) -> tuple[int, dict | BaseException]:
            """Run a single rollout and return (index, result or exception)."""
            if _shutdown_requested:
                return (idx, asyncio.CancelledError("Shutdown requested"))
            try:
                result = await _run_single_rollout(
                    completer=completers[idx],
                    sample=selected_samples[idx],
                    reward_registry=REWARD_REGISTRY,
                    tool_registry=tool_registry_to_use,
                    max_turns=max_turns,
                    termination_policy=termination_policy,
                    verbose=False,
                )
                return (idx, result)
            except asyncio.CancelledError:
                return (idx, asyncio.CancelledError("Cancelled"))
            except Exception as e:
                return (idx, e)

        # Run all rollouts concurrently
        start_time = time.perf_counter()
        tasks = [asyncio.create_task(run_rollout_with_index(i)) for i in range(num_rollouts)]

        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            # Cancel all tasks if we get interrupted
            for task in tasks:
                task.cancel()
            results = []

        total_time = time.perf_counter() - start_time
        spinner.stop()

        if _shutdown_requested:
            # Completers are closed by the finally block below.
            return

        # Display results in order
        for idx, result in sorted(results, key=lambda x: x[0]):
            total_reward = _echo_rollout_result(idx, num_rollouts, result, truncate)
            if total_reward is not None:
                rewards.append(total_reward)
    finally:
        # Always stop the spinner and close every HTTP client, even on errors.
        if spinner is not None:
            spinner.stop()
        for c in completers:
            await c.close()

    # Flush any pending billing charges
    flush_result = await flush_pending_charges(api_url)

    if rewards:
        mean_reward = sum(rewards) / len(rewards)
        click.echo()
        click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
        click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
    else:
        click.echo(click.style("\nNo successful rollouts completed.", fg="yellow"))

    # Show billing summary if charges were flushed
    if flush_result and flush_result.get("flushed"):
        amount_cents = flush_result.get("amountCents", 0)
        total_tokens = flush_result.get("totalTokens", 0)
        click.echo(
            f"Billing: {click.style(f'${amount_cents/100:.2f}', fg=TEAL_RGB)} ({total_tokens:,} tokens)"
        )
|