rnow 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. rnow/__init__.py +5 -0
  2. rnow/__main__.py +7 -0
  3. rnow/cli/__init__.py +6 -0
  4. rnow/cli/auth.py +67 -0
  5. rnow/cli/blob.py +98 -0
  6. rnow/cli/commands.py +2311 -0
  7. rnow/cli/common.py +28 -0
  8. rnow/cli/cube.py +255 -0
  9. rnow/cli/main.py +49 -0
  10. rnow/cli/test.py +728 -0
  11. rnow/cli/token_count.py +295 -0
  12. rnow/core/__init__.py +33 -0
  13. rnow/core/reward.py +333 -0
  14. rnow/core/tool.py +494 -0
  15. rnow/models.py +295 -0
  16. rnow/templates/deepseek-aha/config.yml +26 -0
  17. rnow/templates/deepseek-aha/rewards.py +36 -0
  18. rnow/templates/deepseek-aha/train.jsonl +1000 -0
  19. rnow/templates/mcp-tavily/config.yml +29 -0
  20. rnow/templates/mcp-tavily/requirements.txt +1 -0
  21. rnow/templates/mcp-tavily/rewards.py +25 -0
  22. rnow/templates/mcp-tavily/train.jsonl +500 -0
  23. rnow/templates/new/config.yml +26 -0
  24. rnow/templates/new/requirements.txt +1 -0
  25. rnow/templates/new/rewards.py +0 -0
  26. rnow/templates/new/train.jsonl +0 -0
  27. rnow/templates/rl-nextjs/config.yml +27 -0
  28. rnow/templates/rl-nextjs/requirements.txt +2 -0
  29. rnow/templates/rl-nextjs/rewards.py +446 -0
  30. rnow/templates/rl-nextjs/train.jsonl +1000 -0
  31. rnow/templates/rl-single/config.yml +27 -0
  32. rnow/templates/rl-single/requirements.txt +1 -0
  33. rnow/templates/rl-single/rewards.py +14 -0
  34. rnow/templates/rl-single/train.jsonl +1000 -0
  35. rnow/templates/rl-tools/config.yml +27 -0
  36. rnow/templates/rl-tools/env.py +38 -0
  37. rnow/templates/rl-tools/requirements.txt +3 -0
  38. rnow/templates/rl-tools/rewards.py +25 -0
  39. rnow/templates/rl-tools/train.jsonl +500 -0
  40. rnow/templates/sft/config.yml +20 -0
  41. rnow/templates/sft/train.jsonl +100 -0
  42. rnow/templates/tutorial-reward/config.yml +27 -0
  43. rnow/templates/tutorial-reward/requirements.txt +1 -0
  44. rnow/templates/tutorial-reward/rewards.py +15 -0
  45. rnow/templates/tutorial-reward/train.jsonl +1000 -0
  46. rnow/templates/tutorial-tool/config.yml +27 -0
  47. rnow/templates/tutorial-tool/env.py +7 -0
  48. rnow/templates/tutorial-tool/requirements.txt +3 -0
  49. rnow/templates/tutorial-tool/rewards.py +7 -0
  50. rnow/templates/tutorial-tool/train.jsonl +1266 -0
  51. rnow-0.2.4.dist-info/METADATA +135 -0
  52. rnow-0.2.4.dist-info/RECORD +56 -0
  53. rnow-0.2.4.dist-info/WHEEL +5 -0
  54. rnow-0.2.4.dist-info/entry_points.txt +2 -0
  55. rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
  56. rnow-0.2.4.dist-info/top_level.txt +1 -0
rnow/cli/test.py ADDED
@@ -0,0 +1,728 @@
1
+ # rnow/cli/test.py
2
+ """
3
+ Test command for running RL rollouts locally.
4
+
5
+ Requires authentication for billing.
6
+ """
7
+
8
+ import asyncio
9
+ import itertools
10
+ import json
11
+ import random
12
+ import re
13
+ import signal
14
+ import sys
15
+ import threading
16
+ import time
17
+ from collections.abc import Callable
18
+ from pathlib import Path
19
+ from string import Template
20
+
21
+ import click
22
+ import httpx
23
+ import yaml
24
+
25
+ # Global flag for graceful shutdown
26
+ _shutdown_requested = False
27
+
28
+ from rnow.cli.auth import get_auth_headers
29
+ from rnow.cli.commands import get_thinking_mode_display
30
+
31
+ # ReinforceNow teal: #14B8A6 as RGB tuple for click.style()
32
+ TEAL_RGB = (20, 184, 166)
33
+
34
+
35
class Spinner:
    """Animated single-line status indicator for the terminal.

    A daemon thread repeatedly redraws one frame plus the current status
    text on stdout. The text can be replaced while the spinner is running
    via :meth:`update`.
    """

    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    def __init__(self, message: str = ""):
        self.message = message
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        # Guards reads/writes of ``self.message`` between the caller and the
        # drawing thread.
        self._lock = threading.Lock()

    def update(self, message: str):
        """Replace the status text shown next to the spinner frame."""
        with self._lock:
            self.message = message

    def _spin(self):
        """Thread body: redraw until stopped or a global shutdown is flagged."""
        frames = itertools.cycle(self.FRAMES)
        while not (self._stop_event.is_set() or _shutdown_requested):
            glyph = next(frames)
            with self._lock:
                status = self.message
            # "\r\033[K" returns to column 0 and erases the line before redrawing.
            sys.stdout.write(f"\r\033[K{glyph} {status}")
            sys.stdout.flush()
            time.sleep(0.08)
        # Leave a clean line behind once the animation ends.
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()

    def start(self):
        """Launch the background drawing thread."""
        self._stop_event.clear()
        worker = threading.Thread(target=self._spin, daemon=True)
        self._thread = worker
        worker.start()

    def stop(self):
        """Signal the drawing thread to finish and wait briefly for it."""
        self._stop_event.set()
        worker = self._thread
        if worker:
            # Bounded join so a stuck terminal write cannot hang the CLI.
            worker.join(timeout=0.5)
74
+
75
+
76
+ from rnow.cli.common import require_auth
77
+ from rnow.core.reward import REWARD_REGISTRY, clear_reward_registry, compute_total_reward
78
+ from rnow.core.tool import TOOL_REGISTRY, clear_tool_registry
79
+ from rnow.models import ProjectConfig, RewardArgs
80
+
81
+ DEFAULT_API_URL = "https://www.reinforcenow.ai"
82
+
83
+
84
+ class ModelCompleter:
85
+ """
86
+ Completer that handles tokenization and calls Next.js API.
87
+ Requires authentication for billing.
88
+ """
89
+
90
+ def __init__(self, api_base: str, model: str, max_tokens: int = 2048, temperature: float = 1.0):
91
+ self.api_base = api_base.rstrip("/")
92
+ self.model = model
93
+ self.max_tokens = max_tokens
94
+ self.temperature = temperature
95
+ self.auth_headers = get_auth_headers()
96
+ self.client = httpx.AsyncClient(timeout=120.0)
97
+ self.session_id: str | None = None # Cached session ID for reuse
98
+ self.total_latency_ms = 0
99
+ self.request_count = 0
100
+
101
+ # Initialize tokenizer and renderer
102
+ from tinker_cookbook import renderers
103
+ from tinker_cookbook.model_info import get_recommended_renderer_name
104
+ from tinker_cookbook.tokenizer_utils import get_tokenizer
105
+
106
+ self.tokenizer = get_tokenizer(model)
107
+ renderer_name = get_recommended_renderer_name(model)
108
+ self.renderer = renderers.get_renderer(renderer_name, self.tokenizer)
109
+
110
+ async def __call__(self, messages: list[dict], stop: list[str] | None = None) -> dict:
111
+ """
112
+ Tokenize messages, call Next.js API, decode response.
113
+ """
114
+ # Build model input using renderer
115
+ model_input = self.renderer.build_generation_prompt(messages)
116
+ tokens = model_input.to_ints()
117
+
118
+ # Get stop sequences from renderer if not provided
119
+ if stop is None:
120
+ stop = self.renderer.get_stop_sequences()
121
+
122
+ # Build request payload
123
+ payload = {
124
+ "model": self.model,
125
+ "tokens": tokens,
126
+ "stop": stop,
127
+ "max_tokens": self.max_tokens,
128
+ "temperature": self.temperature,
129
+ }
130
+ # Include session_id if we have one cached
131
+ if self.session_id:
132
+ payload["session_id"] = self.session_id
133
+
134
+ # Call Next.js API with tokens
135
+ resp = await self.client.post(
136
+ f"{self.api_base}/api/rnow/sample",
137
+ json=payload,
138
+ headers=self.auth_headers,
139
+ )
140
+ resp.raise_for_status()
141
+ data = resp.json()
142
+
143
+ if "error" in data:
144
+ raise Exception(f"API error: {data.get('detail', data.get('error'))}")
145
+
146
+ # Cache the session_id for future requests
147
+ if "session_id" in data and data["session_id"]:
148
+ self.session_id = data["session_id"]
149
+
150
+ # Track latency
151
+ if "latency_ms" in data:
152
+ self.total_latency_ms += data["latency_ms"]
153
+ self.request_count += 1
154
+
155
+ # Decode tokens back to text
156
+ output_tokens = data.get("tokens", [])
157
+ parsed_message, _success = self.renderer.parse_response(output_tokens)
158
+
159
+ return {
160
+ "content": parsed_message.get("content", ""),
161
+ "latency_ms": data.get("latency_ms", 0),
162
+ }
163
+
164
+ async def close(self):
165
+ await self.client.aclose()
166
+
167
+
168
async def flush_pending_charges(api_url: str) -> dict | None:
    """
    Ask the backend to flush pending ROLLOUT charges.

    Posts to ``{api_url}/api/billing/flush-rollout`` with the auth headers.
    Best-effort: any failure is reported as a yellow warning and ``None`` is
    returned instead of raising.

    Returns:
        The decoded JSON response, or ``None`` when the request failed.
    """
    try:
        auth = get_auth_headers()
        endpoint = f"{api_url.rstrip('/')}/api/billing/flush-rollout"
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, headers=auth)
            response.raise_for_status()
            payload = response.json()
    except Exception as exc:
        warning = f"Warning: Failed to flush pending charges: {exc}"
        click.echo(click.style(warning, fg="yellow"))
        return None
    return payload
185
+
186
+
187
def _exec_file(path: Path, module_name: str) -> None:
    """Execute a Python file to populate registries.

    The file is loaded as a module named *module_name*. The module is
    registered in ``sys.modules`` before it is executed, as the importlib
    "importing a source file directly" recipe requires; without that, code in
    the file that looks its own module up through ``sys.modules`` can fail.

    If execution raises, the ``sys.modules`` entry is rolled back to whatever
    was there before so a half-initialised module is not left behind.

    Raises:
        ImportError: No import spec/loader could be created for *path*.
        Exception: Whatever the executed file itself raises.
    """
    import importlib.util
    import sys

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module from {path}")
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
196
+
197
+
198
def _build_tools_block(tool_registry: dict[str, Callable]) -> str:
    """Render the "# Tools" prompt section for the registered tools.

    Each tool contributes its registry name, its ``_description`` attribute
    (default ``"No description available."``) and its ``_schema`` attribute
    (default: an empty object schema). Returns an empty string when the
    registry is empty.
    """
    if not tool_registry:
        return ""

    specs = [
        {
            "name": tool_name,
            "description": getattr(tool, "_description", "No description available."),
            "parameters": getattr(tool, "_schema", {"type": "object", "properties": {}}),
        }
        for tool_name, tool in tool_registry.items()
    ]

    header = (
        "# Tools\n"
        "\n"
        "You may call one or more functions to assist with the user query.\n"
        "\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n"
        "<tools>\n"
    )
    footer = (
        "\n"
        "</tools>\n"
        "\n"
        "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
        "<tool_call>\n"
        '{"name": "<function-name>", "arguments": {"<arg-name>": "<value>"}}\n'
        "</tool_call>\n"
    )
    return header + json.dumps(specs, indent=2) + footer
231
+
232
+
233
+ TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
234
+
235
+
236
def _format_message(msg: dict, max_len: int = 300) -> str:
    """Format a message for display.

    Produces ``[role] content`` with the role tag coloured per role and the
    content cut to *max_len* characters (followed by ``...``) when longer.

    ``content`` that is ``None`` is shown as an empty string and any other
    non-string content is shown via ``str()``, so messages without plain-text
    content cannot crash the display path.
    """
    role = msg.get("role", "unknown")
    content = msg.get("content", "")
    if content is None:
        content = ""
    elif not isinstance(content, str):
        content = str(content)
    # Truncate long content
    if len(content) > max_len:
        content = content[:max_len] + "..."
    # Color based on role
    colors = {"system": "yellow", "user": "blue", "assistant": "green", "tool": "magenta"}
    color = colors.get(role, "white")
    return click.style(f"[{role}]", fg=color) + f" {content}"
247
+
248
+
249
async def _execute_tool_call(
    raw_call: str, tool_registry: dict[str, Callable]
) -> tuple[str, bool, object, object]:
    """Run one ``<tool_call>`` payload and build the tool message content.

    Returns ``(tool_response, ok, tool_name, tool_result)``. On any failure
    (invalid JSON, unknown tool, tool raising, unserialisable result) ``ok``
    is False and ``tool_response`` is a ``<tool_error>`` string.
    """
    import inspect

    try:
        tool_data = json.loads(raw_call)
    except json.JSONDecodeError as e:
        return f"<tool_error>Invalid JSON: {str(e)}</tool_error>", False, None, None

    try:
        tool_name = tool_data.get("name")
        args = tool_data.get("arguments", {})

        if tool_name not in tool_registry:
            return f"<tool_error>Tool '{tool_name}' not found</tool_error>", False, tool_name, None

        tool_result = tool_registry[tool_name](**args)
        # Await based on what the call returned, not on how the function
        # was declared, so wrapped/partial async tools are awaited too.
        if inspect.isawaitable(tool_result):
            tool_result = await tool_result

        return f"<tool_result>{json.dumps(tool_result)}</tool_result>", True, tool_name, tool_result
    except Exception as e:
        return f"<tool_error>{str(e)}</tool_error>", False, None, None


async def _run_single_rollout(
    completer: ModelCompleter,
    sample: dict,
    reward_registry: dict[str, Callable],
    tool_registry: dict[str, Callable],
    max_turns: int,
    termination_policy: str,
    verbose: bool = False,
) -> dict:
    """Run a single rollout for an RL sample.

    The sample's message templates are filled from its ``metadata`` and
    ``variables`` (variables win on key clashes). When tools are registered
    the tools block is prepended to the first system message, or inserted as
    a new one. The model is then sampled for up to *max_turns* turns; every
    ``<tool_call>`` in a reply is executed and its result appended as a
    ``tool`` message. With ``termination_policy == "last_tool"`` the loop
    ends at the first reply without tool calls.

    Returns:
        Dict with ``total_reward``, ``rewards`` (per name), ``turns``,
        ``tools_used`` and the full ``conversation``.

    Raises:
        KeyError: The sample lacks ``messages`` or ``rewards``.
        ValueError: A reward name in the sample is not registered.
    """
    import inspect

    messages_templates = sample["messages"]
    reward_names = sample["rewards"]
    variables = sample.get("variables", {})
    metadata = sample.get("metadata", {})

    reward_fns = []
    for name in reward_names:
        if name not in reward_registry:
            raise ValueError(f"Reward function '{name}' not found in registry")
        reward_fns.append(reward_registry[name])

    ctx = {**metadata, **variables}
    messages = [
        {"role": msg["role"], "content": Template(msg["content"]).safe_substitute(ctx)}
        for msg in messages_templates
    ]

    if tool_registry:
        tools_block = _build_tools_block(tool_registry)
        for msg in messages:
            if msg["role"] == "system":
                msg["content"] = tools_block + "\n\n" + msg["content"]
                break
        else:
            # No system message present: the tools block becomes one.
            messages.insert(0, {"role": "system", "content": tools_block})

    conversation = messages.copy()
    turn_count = 0
    total_tool_calls = 0

    # Show initial messages in verbose mode
    if verbose:
        click.echo(" --- Initial Messages ---")
        for msg in messages:
            click.echo(f" {_format_message(msg)}")
        click.echo(" -------------------------")

    while turn_count < max_turns:
        turn_count += 1

        result = await completer(conversation, stop=None)
        response_content = result.get("content", "")

        conversation.append({"role": "assistant", "content": response_content})

        if verbose:
            click.echo(
                f" [Turn {turn_count}] {_format_message({'role': 'assistant', 'content': response_content}, max_len=500)}"
            )

        tool_matches = TOOL_CALL_RE.findall(response_content)
        tool_call_count = len(tool_matches)
        total_tool_calls += tool_call_count

        # Tool calls are counted even when tools are disabled, but only
        # executed when a registry is available.
        if tool_registry:
            for raw_call in tool_matches:
                tool_response, ok, tool_name, tool_result = await _execute_tool_call(
                    raw_call, tool_registry
                )
                conversation.append({"role": "tool", "content": tool_response})
                if not verbose:
                    continue
                if ok:
                    click.echo(
                        f" Tool {click.style(tool_name, fg=TEAL_RGB)}: {str(tool_result)[:200]}"
                    )
                else:
                    click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")

        if termination_policy == "last_tool" and tool_call_count == 0:
            break

    # Show final conversation summary in verbose mode
    if verbose:
        click.echo(f" --- Rollout Complete: {turn_count} turns, {total_tool_calls} tool calls ---")

    reward_args = RewardArgs(metadata=metadata, variables=variables)
    rewards = {}
    for fn, name in zip(reward_fns, reward_names, strict=False):
        value = fn(reward_args, conversation)
        # Accept both plain and async reward functions.
        if inspect.isawaitable(value):
            value = await value
        rewards[name] = value

    total_reward = compute_total_reward(rewards) if rewards else 0.0

    return {
        "total_reward": total_reward,
        "rewards": rewards,
        "turns": turn_count,
        "tools_used": total_tool_calls,
        "conversation": conversation,
    }
382
+
383
+
384
def _check_test_dependencies():
    """Exit with status 1 and an install hint if ``tinker_cookbook`` is missing."""
    try:
        import tinker_cookbook  # noqa: F401
    except ImportError:
        install_hint = "pip install 'rnow[test]'"
        headline = click.style("Error: ", fg="red", bold=True)
        click.echo()
        click.echo(headline + "The 'rnow test' command requires additional dependencies.")
        click.echo(f"Install them with: {click.style(install_hint, fg=TEAL_RGB)}")
        click.echo()
        raise SystemExit(1)
398
+
399
+
400
def _resolve_api_url(cli_value: str | None, ctx_obj: dict | None) -> str:
    """Pick the backend base URL: CLI/env value, then context, then default.

    The context value is expected to point at the API root; only a trailing
    ``/api`` path segment is removed, so hosts whose name contains ``api``
    (e.g. ``https://api.example.com/api``) are left intact.
    """
    if cli_value:
        return cli_value
    configured = ((ctx_obj or {}).get("api_url") or "").rstrip("/")
    if configured.endswith("/api"):
        configured = configured[: -len("/api")]
    return configured or DEFAULT_API_URL


@click.command(name="test")
@click.option(
    "--dir",
    "-d",
    "project_dir",
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    default=".",
    help="Project directory containing config.yml, rewards.py, env.py, train.jsonl",
)
@click.option(
    "--num-rollouts",
    "-n",
    default=3,
    show_default=True,
    help="Number of rollouts to run",
)
@click.option(
    "--multi-turn/--single-turn",
    default=True,
    show_default=True,
    help="Allow multi-turn rollouts or force single-turn",
)
@click.option(
    "--with-tools/--no-tools",
    default=True,
    show_default=True,
    help="Enable or disable tool use during rollout",
)
@click.option(
    "--model",
    default=None,
    help="Override model name for sampling (otherwise uses config.model.path)",
)
@click.option(
    "--api-url",
    envvar="RNOW_API_URL",
    default=None,
    help="Base URL of the Next.js backend (default: https://www.reinforcenow.ai)",
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    help="Show detailed output for each rollout turn",
)
@click.option(
    "--truncate",
    "-t",
    default=None,
    type=int,
    help="Truncate message content to N characters (default: no truncation)",
)
@click.pass_context
def test(ctx, project_dir, num_rollouts, multi_turn, with_tools, model, api_url, verbose, truncate):
    """Test RL rollouts locally before submitting.

    This command runs local RL rollouts by calling the Next.js API
    for model sampling.

    Only works with RL projects (dataset_type: rl).
    """
    global _shutdown_requested
    _shutdown_requested = False

    def handle_sigint(signum, frame):
        # First Ctrl+C requests a graceful stop; a second one exits at once.
        global _shutdown_requested
        if _shutdown_requested:
            sys.exit(1)
        _shutdown_requested = True
        click.echo("\n" + click.style("Interrupted. Shutting down gracefully...", fg="yellow"))

    original_handler = signal.signal(signal.SIGINT, handle_sigint)

    # Everything after installing the handler runs inside try/finally so the
    # previous handler is restored even if auth or the dependency check exits.
    try:
        require_auth()
        _check_test_dependencies()
        asyncio.run(
            _test_async(
                project_dir=project_dir,
                num_rollouts=num_rollouts,
                multi_turn=multi_turn,
                with_tools=with_tools,
                model_override=model,
                api_url=_resolve_api_url(api_url, ctx.obj),
                verbose=verbose,
                truncate=truncate,
            )
        )
    except KeyboardInterrupt:
        click.echo(click.style("Aborted.", fg="yellow"))
    finally:
        # signal.signal() returns None when the previous handler was not
        # installed from Python; None cannot be passed back in.
        if original_handler is not None:
            signal.signal(signal.SIGINT, original_handler)
497
+
498
+
499
def _load_project_config(project_dir: Path) -> "ProjectConfig":
    """Parse ``config.yml`` (preferred) or ``config.json`` from *project_dir*."""
    config_path = project_dir / "config.yml"
    if not config_path.exists():
        config_path = project_dir / "config.json"

    if not config_path.exists():
        raise click.ClickException("No config.yml or config.json found in project directory")

    raw_text = config_path.read_text(encoding="utf-8")
    if config_path.suffix == ".yml":
        config_data = yaml.safe_load(raw_text)
    else:
        config_data = json.loads(raw_text)

    # An empty YAML file parses to None; fail with a readable message.
    if not isinstance(config_data, dict):
        raise click.ClickException(f"{config_path.name} must contain a mapping at the top level")

    return ProjectConfig(**config_data)


def _validate_context_window(config: "ProjectConfig", train_path: Path) -> None:
    """Abort when the longest prompt plus ``rollout.max_tokens`` cannot fit."""
    from rnow.cli.commands import get_max_prompt_tokens, validate_max_tokens_for_context
    from rnow.models import MAX_CONTEXT_WINDOW

    if not config.rollout:
        return
    max_prompt_tokens = get_max_prompt_tokens(train_path)
    if max_prompt_tokens <= 0:
        return
    context_error, recommended = validate_max_tokens_for_context(
        config.rollout.max_tokens, max_prompt_tokens
    )
    if not context_error:
        return

    click.echo()
    click.echo(click.style("✗ Context window exceeded", fg="red", bold=True))
    click.echo()
    click.echo(f" Your longest prompt in train.jsonl is ~{max_prompt_tokens:,} tokens.")
    click.echo(f" With max_tokens={config.rollout.max_tokens:,}, the total exceeds")
    click.echo(f" the {MAX_CONTEXT_WINDOW:,} token context window.")
    click.echo()
    click.echo(click.style(" Fix:", bold=True) + f" Set rollout.max_tokens to {recommended:,} or less")
    click.echo()
    raise click.ClickException("max_tokens + prompt length exceeds context window")


def _load_samples(train_path: Path) -> list[dict]:
    """Read the non-blank lines of ``train.jsonl`` as JSON objects.

    The file is iterated line by line, not split with ``str.splitlines()``,
    because the latter also breaks on characters such as U+2028 that may
    legally appear unescaped inside a JSON string.
    """
    samples = []
    with train_path.open(encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise click.ClickException(
                    f"train.jsonl line {line_no} is not valid JSON: {e}"
                ) from e
    if not samples:
        raise click.ClickException("train.jsonl is empty")
    return samples


async def _wait_for_rollouts(tasks: list) -> list:
    """Wait for rollout tasks, cancelling the rest once shutdown is flagged.

    Returns one ``(index, result_or_exception)`` tuple per task.
    """
    pending = set(tasks)
    while pending:
        # Short timeout so the shutdown flag set by the SIGINT handler is
        # noticed while requests are still in flight.
        _done, pending = await asyncio.wait(pending, timeout=0.2)
        if _shutdown_requested:
            for task in pending:
                task.cancel()
            break
    raw_results = await asyncio.gather(*tasks, return_exceptions=True)
    # Task wrappers normally return tuples; normalise stray exceptions.
    return [
        item if isinstance(item, tuple) else (idx, item) for idx, item in enumerate(raw_results)
    ]


def _print_rollout_result(idx: int, num_rollouts: int, result, truncate: int | None) -> float | None:
    """Print one rollout; return its total reward, or None if it failed."""
    click.echo(f"Rollout {idx+1}/{num_rollouts}")

    # BaseException, not Exception: asyncio.CancelledError is a BaseException.
    if isinstance(result, BaseException):
        if isinstance(result, httpx.HTTPStatusError):
            click.echo(click.style(f" ✗ HTTP Error: {result.response.status_code}", fg="red"))
        else:
            click.echo(click.style(f" ✗ {result}", fg="red"))
        click.echo()
        return None

    total_reward = result["total_reward"]

    # Show all messages with red tags
    for msg in result["conversation"]:
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        # Truncate if flag is set
        if truncate and len(content) > truncate:
            content = content[:truncate] + "..."
        tag = click.style(f"[{role}]", fg="red")
        click.echo(f" {tag} {content}")
    reward_str = ", ".join(f"{k}={v:.3f}" for k, v in result["rewards"].items())
    click.echo(
        f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
        f"| turns={result['turns']} "
        f"| tools_used={result['tools_used']} "
        f"| [{reward_str}]"
    )
    click.echo()
    return total_reward


async def _test_async(
    project_dir: Path,
    num_rollouts: int,
    multi_turn: bool,
    with_tools: bool,
    model_override: str | None,
    api_url: str,
    verbose: bool,
    truncate: int | None,
):
    """Load the project, run *num_rollouts* concurrent rollouts and report.

    Steps: parse and validate the config (RL projects only), check the
    context-window budget, execute ``rewards.py`` (and ``env.py`` when tools
    are enabled) to populate the registries, sample random rows from
    ``train.jsonl``, run the rollouts concurrently, print each result, then
    flush pending billing charges and print a summary.

    Raises:
        click.ClickException: Missing/invalid project files, a non-RL
            project, or a context-window overflow.
    """
    # NOTE(review): ``verbose`` is accepted but rollouts below always run
    # with verbose=False, so --verbose currently has no effect. Left as-is
    # since per-turn output from concurrent rollouts would interleave with
    # the spinner — confirm the intended behaviour.
    project_dir = Path(project_dir)

    config = _load_project_config(project_dir)

    if config.dataset_type.value != "rl":
        raise click.ClickException(
            f"rnow test only supports RL projects (dataset_type: rl). "
            f"Found: {config.dataset_type.value}"
        )

    rewards_path = project_dir / "rewards.py"
    env_path = project_dir / "env.py"
    train_path = project_dir / "train.jsonl"

    if not rewards_path.exists():
        raise click.ClickException("rewards.py not found in project directory")
    if not train_path.exists():
        raise click.ClickException("train.jsonl not found in project directory")

    # Validate max_tokens vs prompt size
    _validate_context_window(config, train_path)

    clear_reward_registry()
    clear_tool_registry()

    _exec_file(rewards_path, "rewards")

    if with_tools and env_path.exists():
        _exec_file(env_path, "env")

    samples = _load_samples(train_path)

    model_name = model_override or config.model.path
    max_tokens = config.rollout.max_tokens if config.rollout else 2048
    max_turns_config = config.rollout.max_turns if config.rollout else 1
    termination_policy = config.rollout.termination_policy if config.rollout else "last_tool"

    max_turns = 1 if not multi_turn else max_turns_config

    # Check for gpt-oss with tools - not supported in rnow test
    is_gpt_oss = "gpt-oss" in model_name.lower() or "gptoss" in model_name.lower()
    has_tools = with_tools and (env_path.exists() or (config.rollout and config.rollout.mcp_url))

    if is_gpt_oss and has_tools:
        click.echo(
            click.style("Warning: ", fg="yellow")
            + "Tool calling with gpt-oss models is not supported in 'rnow test'. Running without tools."
        )
        with_tools = False

    rewards = []
    tool_registry_to_use = TOOL_REGISTRY if with_tools else {}

    # Display model info with reasoning mode (same format as rnow run)
    thinking_display = get_thinking_mode_display(config)
    click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
    click.echo()

    completers = []
    results = []
    total_time = 0.0
    spinner = Spinner(f"Running {num_rollouts} rollouts concurrently...")

    try:
        # Create one completer per concurrent rollout to avoid session
        # conflicts. Built one at a time so those already created are still
        # closed if a later constructor raises.
        for _ in range(num_rollouts):
            completers.append(
                ModelCompleter(
                    api_base=api_url,
                    model=model_name,
                    max_tokens=max_tokens,
                )
            )

        # Select samples for each rollout upfront
        selected_samples = [random.choice(samples) for _ in range(num_rollouts)]

        spinner.start()

        async def run_rollout_with_index(idx: int) -> tuple[int, dict | BaseException]:
            """Run a single rollout and return (index, result or exception)."""
            if _shutdown_requested:
                return (idx, asyncio.CancelledError("Shutdown requested"))
            try:
                result = await _run_single_rollout(
                    completer=completers[idx],
                    sample=selected_samples[idx],
                    reward_registry=REWARD_REGISTRY,
                    tool_registry=tool_registry_to_use,
                    max_turns=max_turns,
                    termination_policy=termination_policy,
                    verbose=False,
                )
                return (idx, result)
            except asyncio.CancelledError:
                return (idx, asyncio.CancelledError("Cancelled"))
            except Exception as e:
                return (idx, e)

        # Run all rollouts concurrently
        start_time = time.time()
        tasks = [asyncio.create_task(run_rollout_with_index(i)) for i in range(num_rollouts)]

        try:
            results = await _wait_for_rollouts(tasks)
        except asyncio.CancelledError:
            # Cancel all tasks if we get interrupted
            for task in tasks:
                task.cancel()
            results = []

        total_time = time.time() - start_time
    finally:
        # Runs on success, shutdown and error alike, so the spinner thread
        # stops writing to stdout and no HTTP client is left open.
        spinner.stop()
        for c in completers:
            await c.close()

    # Check if shutdown was requested
    if _shutdown_requested:
        return

    # Display results in order
    for idx, result in sorted(results, key=lambda item: item[0]):
        reward = _print_rollout_result(idx, num_rollouts, result, truncate)
        if reward is not None:
            rewards.append(reward)

    # Flush any pending billing charges
    flush_result = await flush_pending_charges(api_url)

    if rewards:
        mean_reward = sum(rewards) / len(rewards)
        click.echo()
        click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
        click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
    else:
        click.echo(click.style("\nNo successful rollouts completed.", fg="yellow"))

    # Show billing summary if charges were flushed
    if flush_result and flush_result.get("flushed"):
        amount_cents = flush_result.get("amountCents", 0)
        total_tokens = flush_result.get("totalTokens", 0)
        click.echo(
            f"Billing: {click.style(f'${amount_cents/100:.2f}', fg=TEAL_RGB)} ({total_tokens:,} tokens)"
        )