rnow 0.2.4__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rnow/cli/test.py CHANGED
@@ -1,22 +1,26 @@
1
1
  # rnow/cli/test.py
2
2
  """
3
- Test command for running RL rollouts locally.
3
+ Test command for running RL rollouts via API.
4
4
 
5
- Requires authentication for billing.
5
+ Uses the /api/rnow/rollout endpoint which runs rollouts on Cloud Run.
6
+
7
+ Modes:
8
+ - Default: Uses tinker models (requires auth)
9
+ - --smoke-test: Uses OpenAI gpt-5-nano (requires OPENAI_API_KEY)
6
10
  """
7
11
 
12
+ from __future__ import annotations
13
+
8
14
  import asyncio
9
15
  import itertools
10
16
  import json
17
+ import os
11
18
  import random
12
- import re
13
19
  import signal
14
20
  import sys
15
21
  import threading
16
22
  import time
17
- from collections.abc import Callable
18
23
  from pathlib import Path
19
- from string import Template
20
24
 
21
25
  import click
22
26
  import httpx
@@ -74,66 +78,86 @@ class Spinner:
74
78
 
75
79
 
76
80
  from rnow.cli.common import require_auth
77
- from rnow.core.reward import REWARD_REGISTRY, clear_reward_registry, compute_total_reward
78
- from rnow.core.tool import TOOL_REGISTRY, clear_tool_registry
79
- from rnow.models import ProjectConfig, RewardArgs
81
+ from rnow.models import ProjectConfig
80
82
 
81
83
  DEFAULT_API_URL = "https://www.reinforcenow.ai"
82
84
 
83
85
 
84
- class ModelCompleter:
86
+ class RolloutClient:
85
87
  """
86
- Completer that handles tokenization and calls Next.js API.
87
- Requires authentication for billing.
88
+ Client for running rollouts via the /api/rnow/rollout endpoint.
89
+
90
+ Uses async polling: POST starts job, GET polls for results.
88
91
  """
89
92
 
90
- def __init__(self, api_base: str, model: str, max_tokens: int = 2048, temperature: float = 1.0):
93
+ def __init__(
94
+ self,
95
+ api_base: str,
96
+ model: str,
97
+ max_tokens: int = 2048,
98
+ temperature: float = 1.0,
99
+ max_turns: int = 1,
100
+ termination_policy: str = "last_tool",
101
+ debug: bool = False,
102
+ smoke_test: bool = False,
103
+ openai_api_key: str | None = None,
104
+ mcp_url: str | list[str] | None = None,
105
+ ):
91
106
  self.api_base = api_base.rstrip("/")
92
107
  self.model = model
93
108
  self.max_tokens = max_tokens
94
109
  self.temperature = temperature
110
+ self.max_turns = max_turns
111
+ self.termination_policy = termination_policy
112
+ self.debug = debug
113
+ self.smoke_test = smoke_test
114
+ self.openai_api_key = openai_api_key
115
+ self.mcp_url = mcp_url
95
116
  self.auth_headers = get_auth_headers()
96
- self.client = httpx.AsyncClient(timeout=120.0)
97
- self.session_id: str | None = None # Cached session ID for reuse
98
- self.total_latency_ms = 0
99
- self.request_count = 0
100
-
101
- # Initialize tokenizer and renderer
102
- from tinker_cookbook import renderers
103
- from tinker_cookbook.model_info import get_recommended_renderer_name
104
- from tinker_cookbook.tokenizer_utils import get_tokenizer
105
-
106
- self.tokenizer = get_tokenizer(model)
107
- renderer_name = get_recommended_renderer_name(model)
108
- self.renderer = renderers.get_renderer(renderer_name, self.tokenizer)
109
-
110
- async def __call__(self, messages: list[dict], stop: list[str] | None = None) -> dict:
117
+ self.client = httpx.AsyncClient(timeout=60.0)
118
+ self.total_charged_dollars = 0.0
119
+
120
+ async def start_rollout(
121
+ self,
122
+ samples: list[dict],
123
+ tools_py_code: str | None = None,
124
+ rewards_py_code: str | None = None,
125
+ dockerfiles: dict[str, str] | None = None,
126
+ secrets: dict[str, str] | None = None,
127
+ ) -> str:
111
128
  """
112
- Tokenize messages, call Next.js API, decode response.
129
+ Start rollouts and return rollout ID immediately.
130
+ Use poll_rollout() to check for results.
113
131
  """
114
- # Build model input using renderer
115
- model_input = self.renderer.build_generation_prompt(messages)
116
- tokens = model_input.to_ints()
117
-
118
- # Get stop sequences from renderer if not provided
119
- if stop is None:
120
- stop = self.renderer.get_stop_sequences()
121
-
122
- # Build request payload
123
132
  payload = {
133
+ "samples": samples,
124
134
  "model": self.model,
125
- "tokens": tokens,
126
- "stop": stop,
127
135
  "max_tokens": self.max_tokens,
128
136
  "temperature": self.temperature,
137
+ "max_turns": self.max_turns,
138
+ "termination_policy": self.termination_policy,
139
+ "tools_py_code": tools_py_code,
140
+ "rewards_py_code": rewards_py_code,
141
+ "debug": self.debug,
129
142
  }
130
- # Include session_id if we have one cached
131
- if self.session_id:
132
- payload["session_id"] = self.session_id
133
143
 
134
- # Call Next.js API with tokens
144
+ if self.mcp_url:
145
+ payload["mcp_url"] = self.mcp_url
146
+
147
+ # Send Dockerfiles for local/ images
148
+ if dockerfiles:
149
+ payload["dockerfiles"] = dockerfiles
150
+
151
+ # Send project secrets (from .env file)
152
+ if secrets:
153
+ payload["secrets"] = secrets
154
+
155
+ if self.smoke_test:
156
+ payload["smoke_test"] = True
157
+ payload["openai_api_key"] = self.openai_api_key
158
+
135
159
  resp = await self.client.post(
136
- f"{self.api_base}/api/rnow/sample",
160
+ f"{self.api_base}/api/rnow/rollout",
137
161
  json=payload,
138
162
  headers=self.auth_headers,
139
163
  )
@@ -143,94 +167,79 @@ class ModelCompleter:
143
167
  if "error" in data:
144
168
  raise Exception(f"API error: {data.get('detail', data.get('error'))}")
145
169
 
146
- # Cache the session_id for future requests
147
- if "session_id" in data and data["session_id"]:
148
- self.session_id = data["session_id"]
149
-
150
- # Track latency
151
- if "latency_ms" in data:
152
- self.total_latency_ms += data["latency_ms"]
153
- self.request_count += 1
170
+ return data["rollout_id"]
154
171
 
155
- # Decode tokens back to text
156
- output_tokens = data.get("tokens", [])
157
- parsed_message, _success = self.renderer.parse_response(output_tokens)
172
+ async def poll_rollout(self, rollout_id: str) -> dict:
173
+ """Poll for rollout status. Returns dict with 'status' field."""
174
+ resp = await self.client.get(
175
+ f"{self.api_base}/api/rnow/rollout",
176
+ params={"id": rollout_id},
177
+ headers=self.auth_headers,
178
+ )
179
+ resp.raise_for_status()
180
+ return resp.json()
181
+
182
+ async def run_batch_rollouts(
183
+ self,
184
+ samples: list[dict],
185
+ tools_py_code: str | None = None,
186
+ rewards_py_code: str | None = None,
187
+ dockerfiles: dict[str, str] | None = None,
188
+ secrets: dict[str, str] | None = None,
189
+ spinner: Spinner | None = None,
190
+ timeout_minutes: int = 30,
191
+ ) -> tuple[str, list[dict]]:
192
+ """
193
+ Run rollouts with exponential backoff polling.
194
+ Returns (rollout_id, results).
195
+ """
196
+ # Start the rollout
197
+ rollout_id = await self.start_rollout(
198
+ samples, tools_py_code, rewards_py_code, dockerfiles, secrets
199
+ )
158
200
 
159
- return {
160
- "content": parsed_message.get("content", ""),
161
- "latency_ms": data.get("latency_ms", 0),
162
- }
201
+ if spinner:
202
+ spinner.update(f"Running rollouts... (ID: {rollout_id[:8]})")
163
203
 
164
- async def close(self):
165
- await self.client.aclose()
204
+ # Poll with exponential backoff
205
+ poll_interval = 2.0 # Start at 2 seconds
206
+ max_interval = 10.0 # Cap at 10 seconds
207
+ timeout = timeout_minutes * 60
208
+ start_time = time.time()
166
209
 
210
+ while time.time() - start_time < timeout:
211
+ if _shutdown_requested:
212
+ raise asyncio.CancelledError()
167
213
 
168
- async def flush_pending_charges(api_url: str) -> dict | None:
169
- """
170
- Flush any pending ROLLOUT charges at the end of the test.
171
- Returns the flush result or None if it failed.
172
- """
173
- try:
174
- headers = get_auth_headers()
175
- async with httpx.AsyncClient(timeout=30.0) as client:
176
- resp = await client.post(
177
- f"{api_url.rstrip('/')}/api/billing/flush-rollout",
178
- headers=headers,
179
- )
180
- resp.raise_for_status()
181
- return resp.json()
182
- except Exception as e:
183
- click.echo(click.style(f"Warning: Failed to flush pending charges: {e}", fg="yellow"))
184
- return None
185
-
186
-
187
- def _exec_file(path: Path, module_name: str) -> None:
188
- """Execute a Python file to populate registries."""
189
- import importlib.util
190
-
191
- spec = importlib.util.spec_from_file_location(module_name, path)
192
- if spec is None or spec.loader is None:
193
- raise ImportError(f"Could not load module from {path}")
194
- module = importlib.util.module_from_spec(spec)
195
- spec.loader.exec_module(module)
196
-
197
-
198
- def _build_tools_block(tool_registry: dict[str, Callable]) -> str:
199
- """Build the tools description block from registered tool functions."""
200
- if not tool_registry:
201
- return ""
202
-
203
- tools_json = []
204
- for name, fn in tool_registry.items():
205
- schema = getattr(fn, "_schema", {"type": "object", "properties": {}})
206
- description = getattr(fn, "_description", "No description available.")
207
- tools_json.append(
208
- {
209
- "name": name,
210
- "description": description,
211
- "parameters": schema,
212
- }
213
- )
214
+ # Add jitter (±20%)
215
+ jitter = poll_interval * 0.2 * (random.random() * 2 - 1)
216
+ await asyncio.sleep(poll_interval + jitter)
214
217
 
215
- tools_block = f"""# Tools
218
+ result = await self.poll_rollout(rollout_id)
219
+ status = result.get("status")
216
220
 
217
- You may call one or more functions to assist with the user query.
221
+ if status == "completed":
222
+ # Track billing
223
+ if "billing" in result:
224
+ billing = result["billing"]
225
+ tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
226
+ self.total_charged_dollars += tokens * 0.000001
227
+ return rollout_id, result.get("results", [])
218
228
 
219
- You are provided with function signatures within <tools></tools> XML tags:
220
- <tools>
221
- {json.dumps(tools_json, indent=2)}
222
- </tools>
229
+ if status == "failed":
230
+ raise Exception(f"Rollout failed: {result.get('error', 'Unknown error')}")
223
231
 
224
- For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
225
- <tool_call>
226
- {{"name": "<function-name>", "arguments": {{"<arg-name>": "<value>"}}}}
227
- </tool_call>
228
- """
232
+ # Exponential backoff
233
+ poll_interval = min(poll_interval * 1.5, max_interval)
229
234
 
230
- return tools_block
235
+ if spinner:
236
+ elapsed = int(time.time() - start_time)
237
+ spinner.update(f"Running rollouts... ({elapsed}s, ID: {rollout_id[:8]})")
231
238
 
239
+ raise TimeoutError(f"Rollout timed out after {timeout_minutes} minutes")
232
240
 
233
- TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
241
+ async def close(self):
242
+ await self.client.aclose()
234
243
 
235
244
 
236
245
  def _format_message(msg: dict, max_len: int = 300) -> str:
@@ -247,154 +256,27 @@ def _format_message(msg: dict, max_len: int = 300) -> str:
247
256
 
248
257
 
249
258
  async def _run_single_rollout(
250
- completer: ModelCompleter,
259
+ client: RolloutClient,
251
260
  sample: dict,
252
- reward_registry: dict[str, Callable],
253
- tool_registry: dict[str, Callable],
254
- max_turns: int,
255
- termination_policy: str,
261
+ tools_py_code: str | None,
262
+ rewards_py_code: str | None,
256
263
  verbose: bool = False,
257
264
  ) -> dict:
258
- """Run a single rollout for an RL sample."""
259
- import inspect
260
-
261
- messages_templates = sample["messages"]
262
- reward_names = sample["rewards"]
263
- variables = sample.get("variables", {})
264
- metadata = sample.get("metadata", {})
265
-
266
- reward_fns = []
267
- for name in reward_names:
268
- if name not in reward_registry:
269
- raise ValueError(f"Reward function '{name}' not found in registry")
270
- reward_fns.append(reward_registry[name])
271
-
272
- ctx = {**metadata, **variables}
273
- messages = [
274
- {"role": msg["role"], "content": Template(msg["content"]).safe_substitute(ctx)}
275
- for msg in messages_templates
276
- ]
277
-
278
- if tool_registry:
279
- tools_block = _build_tools_block(tool_registry)
280
- system_found = False
281
- for msg in messages:
282
- if msg["role"] == "system":
283
- msg["content"] = tools_block + "\n\n" + msg["content"]
284
- system_found = True
285
- break
286
- if not system_found:
287
- messages.insert(0, {"role": "system", "content": tools_block})
288
-
289
- conversation = messages.copy()
290
- turn_count = 0
291
- total_tool_calls = 0
292
-
293
- # Show initial messages in verbose mode
265
+ """Run a single rollout via the API."""
266
+ result = await client.run_rollout(
267
+ sample=sample,
268
+ tools_py_code=tools_py_code,
269
+ rewards_py_code=rewards_py_code,
270
+ )
271
+
272
+ # Show conversation in verbose mode
294
273
  if verbose:
295
- click.echo(" --- Initial Messages ---")
296
- for msg in messages:
274
+ click.echo(" --- Conversation ---")
275
+ for msg in result.get("conversation", []):
297
276
  click.echo(f" {_format_message(msg)}")
298
- click.echo(" -------------------------")
299
-
300
- while turn_count < max_turns:
301
- turn_count += 1
277
+ click.echo(" ---------------------")
302
278
 
303
- result = await completer(conversation, stop=None)
304
- response_content = result.get("content", "")
305
-
306
- conversation.append({"role": "assistant", "content": response_content})
307
-
308
- if verbose:
309
- click.echo(
310
- f" [Turn {turn_count}] {_format_message({'role': 'assistant', 'content': response_content}, max_len=500)}"
311
- )
312
-
313
- tool_matches = TOOL_CALL_RE.findall(response_content)
314
- tool_call_count = len(tool_matches)
315
- total_tool_calls += tool_call_count
316
-
317
- for raw_call in tool_matches:
318
- if not tool_registry:
319
- break
320
- try:
321
- tool_data = json.loads(raw_call)
322
- tool_name = tool_data.get("name")
323
- args = tool_data.get("arguments", {})
324
-
325
- if tool_name not in tool_registry:
326
- tool_response = f"<tool_error>Tool '{tool_name}' not found</tool_error>"
327
- conversation.append({"role": "tool", "content": tool_response})
328
- if verbose:
329
- click.echo(
330
- f" {_format_message({'role': 'tool', 'content': tool_response})}"
331
- )
332
- continue
333
-
334
- tool_fn = tool_registry[tool_name]
335
- tool_result = (
336
- await tool_fn(**args)
337
- if inspect.iscoroutinefunction(tool_fn)
338
- else tool_fn(**args)
339
- )
340
-
341
- tool_response = f"<tool_result>{json.dumps(tool_result)}</tool_result>"
342
- conversation.append({"role": "tool", "content": tool_response})
343
-
344
- if verbose:
345
- click.echo(
346
- f" Tool {click.style(tool_name, fg=TEAL_RGB)}: {str(tool_result)[:200]}"
347
- )
348
-
349
- except json.JSONDecodeError as e:
350
- tool_response = f"<tool_error>Invalid JSON: {str(e)}</tool_error>"
351
- conversation.append({"role": "tool", "content": tool_response})
352
- if verbose:
353
- click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")
354
- except Exception as e:
355
- tool_response = f"<tool_error>{str(e)}</tool_error>"
356
- conversation.append({"role": "tool", "content": tool_response})
357
- if verbose:
358
- click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")
359
-
360
- if termination_policy == "last_tool" and tool_call_count == 0:
361
- break
362
-
363
- # Show final conversation summary in verbose mode
364
- if verbose:
365
- click.echo(f" --- Rollout Complete: {turn_count} turns, {total_tool_calls} tool calls ---")
366
-
367
- reward_args = RewardArgs(metadata=metadata, variables=variables)
368
- rewards = {}
369
- for fn, name in zip(reward_fns, reward_names, strict=False):
370
- value = await fn(reward_args, conversation)
371
- rewards[name] = value
372
-
373
- total_reward = compute_total_reward(rewards) if rewards else 0.0
374
-
375
- return {
376
- "total_reward": total_reward,
377
- "rewards": rewards,
378
- "turns": turn_count,
379
- "tools_used": total_tool_calls,
380
- "conversation": conversation,
381
- }
382
-
383
-
384
- def _check_test_dependencies():
385
- """Check if optional test dependencies are installed."""
386
- try:
387
- import tinker_cookbook # noqa: F401
388
- except ImportError:
389
- click.echo()
390
- click.echo(
391
- click.style("Error: ", fg="red", bold=True)
392
- + "The 'rnow test' command requires additional dependencies."
393
- )
394
- pip_cmd = "pip install 'rnow[test]'"
395
- click.echo(f"Install them with: {click.style(pip_cmd, fg=TEAL_RGB)}")
396
- click.echo()
397
- raise SystemExit(1)
279
+ return result
398
280
 
399
281
 
400
282
  @click.command(name="test")
@@ -404,12 +286,12 @@ def _check_test_dependencies():
404
286
  "project_dir",
405
287
  type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
406
288
  default=".",
407
- help="Project directory containing config.yml, rewards.py, env.py, train.jsonl",
289
+ help="Project directory containing config.yml, rewards.py, tools.py, train.jsonl",
408
290
  )
409
291
  @click.option(
410
292
  "--num-rollouts",
411
293
  "-n",
412
- default=3,
294
+ default=1,
413
295
  show_default=True,
414
296
  help="Number of rollouts to run",
415
297
  )
@@ -449,51 +331,290 @@ def _check_test_dependencies():
449
331
  type=int,
450
332
  help="Truncate message content to N characters (default: no truncation)",
451
333
  )
334
+ @click.option(
335
+ "--debug",
336
+ is_flag=True,
337
+ help="Use debug trainer image from Docker Hub (for testing trainer changes)",
338
+ )
339
+ @click.option(
340
+ "--output-dir",
341
+ "-o",
342
+ "output_dir",
343
+ type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
344
+ default=None,
345
+ help="Save rollout results as JSON files in this directory",
346
+ )
347
+ @click.option(
348
+ "--smoke-test",
349
+ is_flag=True,
350
+ help="Use OpenAI gpt-5-nano instead of tinker (requires OPENAI_API_KEY env var)",
351
+ )
352
+ @click.option(
353
+ "--id",
354
+ "rollout_id",
355
+ default=None,
356
+ help="Fetch results for an existing rollout ID (skip running new rollout)",
357
+ )
358
+ @click.option(
359
+ "--store",
360
+ is_flag=True,
361
+ help="Store rollout ID in ./rollouts/<id>.txt for later retrieval",
362
+ )
363
+ @click.option(
364
+ "--timeout",
365
+ default=60,
366
+ show_default=True,
367
+ help="Timeout in minutes for polling results",
368
+ )
369
+ @click.option(
370
+ "--entry",
371
+ "-e",
372
+ "entries",
373
+ default=None,
374
+ help="Entry indices from train.jsonl (0-indexed). Examples: -e 5, -e 0,2,5, -e 0 -e 2 -e 5",
375
+ multiple=True,
376
+ )
452
377
  @click.pass_context
453
- def test(ctx, project_dir, num_rollouts, multi_turn, with_tools, model, api_url, verbose, truncate):
454
- """Test RL rollouts locally before submitting.
378
+ def test(
379
+ ctx,
380
+ project_dir,
381
+ num_rollouts,
382
+ multi_turn,
383
+ with_tools,
384
+ model,
385
+ api_url,
386
+ verbose,
387
+ truncate,
388
+ debug,
389
+ output_dir,
390
+ smoke_test,
391
+ rollout_id,
392
+ store,
393
+ timeout,
394
+ entries,
395
+ ):
396
+ """Test RL rollouts before submitting.
397
+
398
+ Runs rollouts via the /api/rnow/rollout endpoint on Cloud Run.
399
+
400
+ Use --smoke-test to use OpenAI gpt-5-nano instead of tinker models
401
+ (requires OPENAI_API_KEY environment variable).
455
402
 
456
- This command runs local RL rollouts by calling the Next.js API
457
- for model sampling.
403
+ Use --id to fetch results for an existing rollout.
458
404
 
459
405
  Only works with RL projects (dataset_type: rl).
460
406
  """
461
407
  global _shutdown_requested
462
408
  _shutdown_requested = False
463
409
 
464
- def handle_sigint(signum, frame):
465
- global _shutdown_requested
466
- if _shutdown_requested:
467
- # Second Ctrl+C, force exit
468
- sys.exit(1)
469
- _shutdown_requested = True
470
- click.echo("\n" + click.style("Interrupted. Shutting down gracefully...", fg="yellow"))
471
-
472
- # Set up signal handler
473
- original_handler = signal.signal(signal.SIGINT, handle_sigint)
410
+ resolved_api_url = api_url or ctx.obj.get("api_url", "").replace("/api", "") or DEFAULT_API_URL
474
411
 
475
- require_auth()
476
- _check_test_dependencies()
477
- try:
412
+ # Handle --id flag: just fetch existing rollout results
413
+ if rollout_id:
478
414
  asyncio.run(
479
- _test_async(
415
+ _fetch_rollout_results(
416
+ rollout_id=rollout_id,
417
+ api_url=resolved_api_url,
418
+ store=store,
419
+ truncate=truncate,
420
+ output_dir=output_dir,
421
+ )
422
+ )
423
+ return
424
+
425
+ # Check for OpenAI API key in smoke test mode
426
+ openai_api_key = None
427
+ if smoke_test:
428
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
429
+ if not openai_api_key:
430
+ raise click.ClickException(
431
+ "OPENAI_API_KEY environment variable is required for smoke test mode.\n"
432
+ "Set it with: export OPENAI_API_KEY=sk-..."
433
+ )
434
+ else:
435
+ require_auth()
436
+
437
+ async def run_with_cancellation():
438
+ """Run test with proper cancellation support."""
439
+ loop = asyncio.get_running_loop()
440
+ task = asyncio.current_task()
441
+
442
+ def handle_sigint():
443
+ global _shutdown_requested
444
+ if _shutdown_requested:
445
+ sys.exit(1)
446
+ _shutdown_requested = True
447
+ click.echo("\n" + click.style("Interrupted. Cancelling...", fg="yellow"))
448
+ task.cancel()
449
+
450
+ loop.add_signal_handler(signal.SIGINT, handle_sigint)
451
+
452
+ try:
453
+ await _test_async(
480
454
  project_dir=project_dir,
481
455
  num_rollouts=num_rollouts,
482
456
  multi_turn=multi_turn,
483
457
  with_tools=with_tools,
484
458
  model_override=model,
485
- api_url=api_url
486
- or ctx.obj.get("api_url", "").replace("/api", "")
487
- or DEFAULT_API_URL,
459
+ api_url=resolved_api_url,
488
460
  verbose=verbose,
489
461
  truncate=truncate,
462
+ debug=debug,
463
+ output_dir=output_dir,
464
+ smoke_test=smoke_test,
465
+ openai_api_key=openai_api_key,
466
+ store=store,
467
+ timeout_minutes=timeout,
468
+ entries=entries,
490
469
  )
491
- )
470
+ except asyncio.CancelledError:
471
+ click.echo(click.style("Aborted.", fg="yellow"))
472
+ finally:
473
+ loop.remove_signal_handler(signal.SIGINT)
474
+
475
+ try:
476
+ asyncio.run(run_with_cancellation())
492
477
  except KeyboardInterrupt:
493
478
  click.echo(click.style("Aborted.", fg="yellow"))
479
+
480
+
481
+ async def _fetch_rollout_results(
482
+ rollout_id: str,
483
+ api_url: str,
484
+ store: bool = False,
485
+ truncate: int | None = None,
486
+ output_dir: Path | None = None,
487
+ ):
488
+ """Fetch results for an existing rollout ID."""
489
+ click.echo(f"Fetching results for rollout: {click.style(rollout_id, fg=TEAL_RGB)}")
490
+
491
+ client = httpx.AsyncClient(timeout=30.0)
492
+ auth_headers = get_auth_headers()
493
+
494
+ try:
495
+ resp = await client.get(
496
+ f"{api_url}/api/rnow/rollout",
497
+ params={"id": rollout_id},
498
+ headers=auth_headers,
499
+ )
500
+ resp.raise_for_status()
501
+ data = resp.json()
494
502
  finally:
495
- # Restore original signal handler
496
- signal.signal(signal.SIGINT, original_handler)
503
+ await client.aclose()
504
+
505
+ status = data.get("status")
506
+ if status == "pending":
507
+ click.echo(click.style("Rollout still running...", fg="yellow"))
508
+ click.echo(f"Poll again with: rnow test --id {rollout_id}")
509
+ return
510
+
511
+ if status == "failed":
512
+ click.echo(click.style(f"Rollout failed: {data.get('error', 'Unknown')}", fg="red"))
513
+ return
514
+
515
+ # Store rollout ID if requested
516
+ if store:
517
+ _store_rollout_id(rollout_id, data)
518
+
519
+ # Display results
520
+ results = data.get("results", [])
521
+ _display_results(results, truncate, output_dir, rollout_id)
522
+
523
+ # Show billing
524
+ billing = data.get("billing", {})
525
+ tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
526
+ if tokens > 0:
527
+ click.echo(f"Tokens: {tokens}")
528
+
529
+
530
+ def _store_rollout_id(rollout_id: str, data: dict):
531
+ """Store rollout ID and results in ./rollouts/<id>.txt"""
532
+ rollouts_dir = Path("rollouts")
533
+ rollouts_dir.mkdir(exist_ok=True)
534
+
535
+ filepath = rollouts_dir / f"{rollout_id}.txt"
536
+ with open(filepath, "w") as f:
537
+ f.write(f"Rollout ID: {rollout_id}\n")
538
+ f.write(f"Status: {data.get('status', 'unknown')}\n")
539
+ f.write(f"S3 Path: rollouts/{rollout_id}/result.json\n")
540
+ f.write("\n")
541
+
542
+ # Write summary
543
+ results = data.get("results", [])
544
+ successful = [r for r in results if r.get("success")]
545
+ if successful:
546
+ rewards = [r.get("total_reward", 0) for r in successful]
547
+ f.write(f"Successful: {len(successful)}/{len(results)}\n")
548
+ f.write(f"Mean Reward: {sum(rewards) / len(rewards):.3f}\n")
549
+
550
+ # Write billing
551
+ billing = data.get("billing", {})
552
+ tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
553
+ if tokens > 0:
554
+ f.write(f"Tokens: {tokens}\n")
555
+
556
+ f.write("\n--- Full Results ---\n")
557
+ f.write(json.dumps(data, indent=2))
558
+
559
+ click.echo(f"Stored: {click.style(str(filepath), fg=TEAL_RGB)}")
560
+
561
+
562
+ def _display_results(
563
+ results: list[dict],
564
+ truncate: int | None,
565
+ output_dir: Path | None,
566
+ rollout_id: str | None = None,
567
+ ):
568
+ """Display rollout results."""
569
+ rewards = []
570
+
571
+ for idx, result in enumerate(results):
572
+ click.echo(f"Rollout {idx + 1}/{len(results)}")
573
+
574
+ if not result.get("success"):
575
+ click.echo(click.style(f" ✗ {result.get('error', 'Unknown error')}", fg="red"))
576
+ click.echo()
577
+ continue
578
+
579
+ total_reward = result.get("total_reward", 0.0)
580
+ rewards.append(total_reward)
581
+
582
+ # Show conversation
583
+ for msg in result.get("conversation", []):
584
+ role = msg.get("role", "unknown")
585
+ content = msg.get("content", "")
586
+ if truncate and len(content) > truncate:
587
+ content = content[:truncate] + "..."
588
+ tag = click.style(f"[{role}]", fg="red")
589
+ click.echo(f" {tag} {content}")
590
+
591
+ reward_breakdown = result.get("rewards", {})
592
+ reward_str = ", ".join(f"{k}={v:.3f}" for k, v in reward_breakdown.items())
593
+ turns = result.get("turns", 0)
594
+ click.echo(
595
+ f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
596
+ f"| turns={turns} "
597
+ f"| [{reward_str}]"
598
+ )
599
+ click.echo()
600
+
601
+ # Save to files if requested
602
+ if output_dir and results:
603
+ output_dir.mkdir(parents=True, exist_ok=True)
604
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
605
+ for idx, result in enumerate(results):
606
+ if result.get("success"):
607
+ filename = output_dir / f"rollout_{timestamp}_{idx + 1}.json"
608
+ filename.write_text(json.dumps(result, indent=2))
609
+ click.echo(f"Results saved to {click.style(str(output_dir), fg=TEAL_RGB)}")
610
+
611
+ # Summary
612
+ if rewards:
613
+ mean_reward = sum(rewards) / len(rewards)
614
+ click.echo()
615
+ click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
616
+ if rollout_id:
617
+ click.echo(f"Rollout ID: {click.style(rollout_id, fg=TEAL_RGB)}")
497
618
 
498
619
 
499
620
  async def _test_async(
@@ -505,6 +626,13 @@ async def _test_async(
505
626
  api_url: str,
506
627
  verbose: bool,
507
628
  truncate: int | None,
629
+ debug: bool = False,
630
+ output_dir: Path | None = None,
631
+ smoke_test: bool = False,
632
+ openai_api_key: str | None = None,
633
+ store: bool = False,
634
+ timeout_minutes: int = 60,
635
+ entries: tuple[int, ...] = (),
508
636
  ):
509
637
  project_dir = Path(project_dir)
510
638
 
@@ -529,7 +657,7 @@ async def _test_async(
529
657
  )
530
658
 
531
659
  rewards_path = project_dir / "rewards.py"
532
- env_path = project_dir / "env.py"
660
+ tools_path = project_dir / "tools.py"
533
661
  train_path = project_dir / "train.jsonl"
534
662
 
535
663
  if not rewards_path.exists():
@@ -537,192 +665,159 @@ async def _test_async(
537
665
  if not train_path.exists():
538
666
  raise click.ClickException("train.jsonl not found in project directory")
539
667
 
540
- # Validate max_tokens vs prompt size
541
- from rnow.cli.commands import get_max_prompt_tokens, validate_max_tokens_for_context
542
- from rnow.models import MAX_CONTEXT_WINDOW
543
-
544
- if config.rollout:
545
- max_prompt_tokens = get_max_prompt_tokens(train_path)
546
- if max_prompt_tokens > 0:
547
- context_error, recommended = validate_max_tokens_for_context(
548
- config.rollout.max_tokens, max_prompt_tokens
549
- )
550
- if context_error:
551
- click.echo()
552
- click.echo(click.style("✗ Context window exceeded", fg="red", bold=True))
553
- click.echo()
554
- click.echo(
555
- f" Your longest prompt in train.jsonl is ~{max_prompt_tokens:,} tokens."
556
- )
557
- click.echo(f" With max_tokens={config.rollout.max_tokens:,}, the total exceeds")
558
- click.echo(f" the {MAX_CONTEXT_WINDOW:,} token context window.")
559
- click.echo()
560
- click.echo(
561
- click.style(" Fix:", bold=True)
562
- + f" Set rollout.max_tokens to {recommended:,} or less"
563
- )
564
- click.echo()
565
- raise click.ClickException("max_tokens + prompt length exceeds context window")
566
-
567
- clear_reward_registry()
568
- clear_tool_registry()
569
-
570
- _exec_file(rewards_path, "rewards")
571
-
572
- if with_tools and env_path.exists():
573
- _exec_file(env_path, "env")
668
+ # Read user code files to send to the API
669
+ rewards_py_code = rewards_path.read_text()
670
+ tools_py_code = tools_path.read_text() if with_tools and tools_path.exists() else None
574
671
 
672
+ # Load samples
575
673
  samples = [json.loads(line) for line in train_path.read_text().splitlines() if line.strip()]
576
674
 
675
+ # Read Dockerfile.* files for local/ docker images
676
+ dockerfiles: dict[str, str] = {}
677
+ for dockerfile_path in project_dir.glob("Dockerfile.*"):
678
+ dockerfiles[dockerfile_path.name] = dockerfile_path.read_text()
679
+ click.echo(f" Found {dockerfile_path.name}")
680
+
681
+ # Read .env file for project secrets
682
+ project_secrets: dict[str, str] = {}
683
+ env_path = project_dir / ".env"
684
+ if env_path.exists():
685
+ for line in env_path.read_text().splitlines():
686
+ line = line.strip()
687
+ if line and not line.startswith("#") and "=" in line:
688
+ key, _, value = line.partition("=")
689
+ # Remove quotes if present
690
+ value = value.strip().strip("'\"")
691
+ project_secrets[key.strip()] = value
692
+ if project_secrets:
693
+ click.echo(f" Loaded secrets: {list(project_secrets.keys())}")
694
+
577
695
  if not samples:
578
696
  raise click.ClickException("train.jsonl is empty")
579
697
 
580
- model_name = model_override or config.model.path
698
+ # For smoke test, always use gpt-5-nano
699
+ model_name = "gpt-5-nano" if smoke_test else model_override or config.model.path
700
+
581
701
  max_tokens = config.rollout.max_tokens if config.rollout else 2048
582
702
  max_turns_config = config.rollout.max_turns if config.rollout else 1
583
703
  termination_policy = config.rollout.termination_policy if config.rollout else "last_tool"
704
+ mcp_url = config.rollout.mcp_url if config.rollout else None
584
705
 
585
706
  max_turns = 1 if not multi_turn else max_turns_config
586
707
 
587
- # Check for gpt-oss with tools - not supported in rnow test
588
- is_gpt_oss = "gpt-oss" in model_name.lower() or "gptoss" in model_name.lower()
589
- has_tools = with_tools and (env_path.exists() or (config.rollout and config.rollout.mcp_url))
590
-
591
- if is_gpt_oss and has_tools:
592
- click.echo(
593
- click.style("Warning: ", fg="yellow")
594
- + "Tool calling with gpt-oss models is not supported in 'rnow test'. Running without tools."
595
- )
596
- with_tools = False
597
-
598
- rewards = []
599
- tool_registry_to_use = TOOL_REGISTRY if with_tools else {}
708
+ # Display mode and model info
709
+ if smoke_test:
710
+ click.echo(f"Mode: {click.style('SMOKE TEST', fg=TEAL_RGB)} (OpenAI gpt-5-nano)")
711
+ else:
712
+ thinking_display = get_thinking_mode_display(config)
713
+ click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
600
714
 
601
- # Display model info with reasoning mode (same format as rnow run)
602
- thinking_display = get_thinking_mode_display(config)
603
- click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
604
715
  click.echo()
605
716
 
606
717
  try:
607
- # Create one completer per concurrent rollout to avoid session conflicts
608
- completers = [
609
- ModelCompleter(
610
- api_base=api_url,
611
- model=model_name,
612
- max_tokens=max_tokens,
613
- )
614
- for _ in range(num_rollouts)
615
- ]
616
-
617
- # Select samples for each rollout upfront
618
- selected_samples = [random.choice(samples) for _ in range(num_rollouts)]
718
+ # Create one RolloutClient for all rollouts
719
+ client = RolloutClient(
720
+ api_base=api_url,
721
+ model=model_name,
722
+ max_tokens=max_tokens,
723
+ temperature=1.0,
724
+ max_turns=max_turns,
725
+ termination_policy=termination_policy,
726
+ debug=debug,
727
+ smoke_test=smoke_test,
728
+ openai_api_key=openai_api_key,
729
+ mcp_url=mcp_url,
730
+ )
619
731
 
620
- # Start spinner for concurrent rollouts
621
- spinner = Spinner(f"Running {num_rollouts} rollouts concurrently...")
732
+ # Select samples for batch rollout
733
+ if entries:
734
+ # Parse entries - support both "-e 0 -e 2" and "-e 0,2,5"
735
+ entry_indices = []
736
+ for entry in entries:
737
+ # Handle comma-separated values
738
+ for part in str(entry).split(","):
739
+ part = part.strip()
740
+ if part:
741
+ try:
742
+ idx = int(part)
743
+ except ValueError:
744
+ raise click.ClickException(f"Invalid entry index: {part}")
745
+ if idx < 0 or idx >= len(samples):
746
+ raise click.ClickException(
747
+ f"Entry index {idx} out of range. train.jsonl has {len(samples)} entries (0-{len(samples) - 1})"
748
+ )
749
+ entry_indices.append(idx)
750
+
751
+ if not entry_indices:
752
+ raise click.ClickException("No valid entry indices provided")
753
+
754
+ selected_samples = [samples[idx] for idx in entry_indices]
755
+ click.echo(f"Testing entries: {entry_indices}")
756
+ else:
757
+ # Random selection
758
+ selected_samples = [random.choice(samples) for _ in range(num_rollouts)]
759
+
760
+ # Start spinner for batch rollout
761
+ spinner = Spinner(f"Starting {len(selected_samples)} rollouts...")
622
762
  spinner.start()
623
763
 
624
- async def run_rollout_with_index(idx: int) -> tuple[int, dict | Exception]:
625
- """Run a single rollout and return (index, result or exception)."""
626
- if _shutdown_requested:
627
- return (idx, asyncio.CancelledError("Shutdown requested"))
628
- try:
629
- result = await _run_single_rollout(
630
- completer=completers[idx],
631
- sample=selected_samples[idx],
632
- reward_registry=REWARD_REGISTRY,
633
- tool_registry=tool_registry_to_use,
634
- max_turns=max_turns,
635
- termination_policy=termination_policy,
636
- verbose=False,
637
- )
638
- return (idx, result)
639
- except asyncio.CancelledError:
640
- return (idx, asyncio.CancelledError("Cancelled"))
641
- except Exception as e:
642
- return (idx, e)
643
-
644
- # Run all rollouts concurrently
645
764
  start_time = time.time()
646
- tasks = [asyncio.create_task(run_rollout_with_index(i)) for i in range(num_rollouts)]
765
+ rollout_id = None
647
766
 
648
767
  try:
649
- results = await asyncio.gather(*tasks, return_exceptions=True)
768
+ # Start rollout and poll for results with exponential backoff
769
+ rollout_id, batch_results = await client.run_batch_rollouts(
770
+ samples=selected_samples,
771
+ tools_py_code=tools_py_code,
772
+ rewards_py_code=rewards_py_code,
773
+ dockerfiles=dockerfiles if dockerfiles else None,
774
+ secrets=project_secrets if project_secrets else None,
775
+ spinner=spinner,
776
+ timeout_minutes=timeout_minutes,
777
+ )
650
778
  except asyncio.CancelledError:
651
- # Cancel all tasks if we get interrupted
652
- for task in tasks:
653
- task.cancel()
654
- results = []
779
+ batch_results = []
780
+ except Exception as e:
781
+ spinner.stop()
782
+ raise e
655
783
 
656
784
  total_time = time.time() - start_time
657
785
  spinner.stop()
658
786
 
787
+ # Show rollout ID
788
+ if rollout_id:
789
+ click.echo(f"Rollout ID: {click.style(rollout_id, fg=TEAL_RGB)}")
790
+ click.echo()
791
+
659
792
  # Check if shutdown was requested
660
793
  if _shutdown_requested:
661
- # Close completers and exit early
662
- for c in completers:
663
- await c.close()
794
+ await client.close()
664
795
  return
665
796
 
666
- # Display results in order
667
- for idx, result in sorted(results, key=lambda x: x[0]):
668
- click.echo(f"Rollout {idx+1}/{num_rollouts}")
669
-
670
- if isinstance(result, Exception):
671
- if isinstance(result, httpx.HTTPStatusError):
672
- click.echo(
673
- click.style(f" ✗ HTTP Error: {result.response.status_code}", fg="red")
674
- )
675
- else:
676
- click.echo(click.style(f" ✗ {result}", fg="red"))
677
- click.echo()
678
- continue
679
-
680
- total_reward = result["total_reward"]
681
- rewards.append(total_reward)
682
-
683
- # Get conversation
684
- conversation = result["conversation"]
685
-
686
- # Show all messages with red tags
687
- for msg in conversation:
688
- role = msg.get("role", "unknown")
689
- content = msg.get("content", "")
690
- # Truncate if flag is set
691
- if truncate and len(content) > truncate:
692
- content = content[:truncate] + "..."
693
- tag = click.style(f"[{role}]", fg="red")
694
- click.echo(f" {tag} {content}")
695
- reward_str = ", ".join(f"{k}={v:.3f}" for k, v in result["rewards"].items())
696
- click.echo(
697
- f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
698
- f"| turns={result['turns']} "
699
- f"| tools_used={result['tools_used']} "
700
- f"| [{reward_str}]"
797
+ # Store results if requested
798
+ if store and rollout_id:
799
+ _store_rollout_id(
800
+ rollout_id,
801
+ {
802
+ "status": "completed",
803
+ "results": batch_results,
804
+ "billing": {"prompt_tokens": 0, "completion_tokens": 0},
805
+ },
701
806
  )
702
- click.echo()
703
807
 
704
- # Close all completers
705
- for c in completers:
706
- await c.close()
808
+ # Display results using shared function
809
+ _display_results(batch_results, truncate, output_dir, rollout_id)
707
810
 
708
- except Exception:
709
- raise
811
+ # Get total billing
812
+ total_charged = client.total_charged_dollars
710
813
 
711
- # Flush any pending billing charges
712
- flush_result = await flush_pending_charges(api_url)
814
+ # Close client
815
+ await client.close()
713
816
 
714
- if rewards:
715
- mean_reward = sum(rewards) / len(rewards)
716
- click.echo()
717
- click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
718
- click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
719
- else:
720
- click.echo(click.style("\nNo successful rollouts completed.", fg="yellow"))
817
+ except Exception:
818
+ raise
721
819
 
722
- # Show billing summary if charges were flushed
723
- if flush_result and flush_result.get("flushed"):
724
- amount_cents = flush_result.get("amountCents", 0)
725
- total_tokens = flush_result.get("totalTokens", 0)
726
- click.echo(
727
- f"Billing: {click.style(f'${amount_cents/100:.2f}', fg=TEAL_RGB)} ({total_tokens:,} tokens)"
728
- )
820
+ # Show timing and cost
821
+ click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
822
+ if total_charged > 0:
823
+ click.echo(f"Cost: {click.style(f'${total_charged:.4f}', fg=TEAL_RGB)}")