rnow 0.2.4__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnow/cli/commands.py +226 -84
- rnow/cli/test.py +536 -441
- rnow/core/__init__.py +4 -1
- rnow/core/reward.py +34 -3
- rnow/core/tool.py +29 -7
- rnow/models.py +88 -6
- rnow/templates/deepseek-aha/config.yml +1 -1
- rnow/templates/mcp-tavily/config.yml +1 -1
- rnow/templates/rl-single/config.yml +7 -7
- rnow/templates/rl-single/train.jsonl +0 -908
- rnow/templates/rl-tools/config.yml +1 -1
- rnow/templates/tutorial-reward/config.yml +7 -7
- rnow/templates/tutorial-reward/train.jsonl +0 -908
- rnow/templates/tutorial-tool/config.yml +1 -1
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/METADATA +23 -9
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/RECORD +22 -22
- /rnow/templates/rl-tools/{env.py → tools.py} +0 -0
- /rnow/templates/tutorial-tool/{env.py → tools.py} +0 -0
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/WHEEL +0 -0
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/entry_points.txt +0 -0
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/licenses/LICENSE +0 -0
- {rnow-0.2.4.dist-info → rnow-0.3.9.dist-info}/top_level.txt +0 -0
rnow/cli/test.py
CHANGED
|
@@ -1,22 +1,26 @@
|
|
|
1
1
|
# rnow/cli/test.py
|
|
2
2
|
"""
|
|
3
|
-
Test command for running RL rollouts
|
|
3
|
+
Test command for running RL rollouts via API.
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
Uses the /api/rnow/rollout endpoint which runs rollouts on Cloud Run.
|
|
6
|
+
|
|
7
|
+
Modes:
|
|
8
|
+
- Default: Uses tinker models (requires auth)
|
|
9
|
+
- --smoke-test: Uses OpenAI gpt-5-nano (requires OPENAI_API_KEY)
|
|
6
10
|
"""
|
|
7
11
|
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
8
14
|
import asyncio
|
|
9
15
|
import itertools
|
|
10
16
|
import json
|
|
17
|
+
import os
|
|
11
18
|
import random
|
|
12
|
-
import re
|
|
13
19
|
import signal
|
|
14
20
|
import sys
|
|
15
21
|
import threading
|
|
16
22
|
import time
|
|
17
|
-
from collections.abc import Callable
|
|
18
23
|
from pathlib import Path
|
|
19
|
-
from string import Template
|
|
20
24
|
|
|
21
25
|
import click
|
|
22
26
|
import httpx
|
|
@@ -74,66 +78,86 @@ class Spinner:
|
|
|
74
78
|
|
|
75
79
|
|
|
76
80
|
from rnow.cli.common import require_auth
|
|
77
|
-
from rnow.
|
|
78
|
-
from rnow.core.tool import TOOL_REGISTRY, clear_tool_registry
|
|
79
|
-
from rnow.models import ProjectConfig, RewardArgs
|
|
81
|
+
from rnow.models import ProjectConfig
|
|
80
82
|
|
|
81
83
|
DEFAULT_API_URL = "https://www.reinforcenow.ai"
|
|
82
84
|
|
|
83
85
|
|
|
84
|
-
class
|
|
86
|
+
class RolloutClient:
|
|
85
87
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
+
Client for running rollouts via the /api/rnow/rollout endpoint.
|
|
89
|
+
|
|
90
|
+
Uses async polling: POST starts job, GET polls for results.
|
|
88
91
|
"""
|
|
89
92
|
|
|
90
|
-
def __init__(
|
|
93
|
+
def __init__(
|
|
94
|
+
self,
|
|
95
|
+
api_base: str,
|
|
96
|
+
model: str,
|
|
97
|
+
max_tokens: int = 2048,
|
|
98
|
+
temperature: float = 1.0,
|
|
99
|
+
max_turns: int = 1,
|
|
100
|
+
termination_policy: str = "last_tool",
|
|
101
|
+
debug: bool = False,
|
|
102
|
+
smoke_test: bool = False,
|
|
103
|
+
openai_api_key: str | None = None,
|
|
104
|
+
mcp_url: str | list[str] | None = None,
|
|
105
|
+
):
|
|
91
106
|
self.api_base = api_base.rstrip("/")
|
|
92
107
|
self.model = model
|
|
93
108
|
self.max_tokens = max_tokens
|
|
94
109
|
self.temperature = temperature
|
|
110
|
+
self.max_turns = max_turns
|
|
111
|
+
self.termination_policy = termination_policy
|
|
112
|
+
self.debug = debug
|
|
113
|
+
self.smoke_test = smoke_test
|
|
114
|
+
self.openai_api_key = openai_api_key
|
|
115
|
+
self.mcp_url = mcp_url
|
|
95
116
|
self.auth_headers = get_auth_headers()
|
|
96
|
-
self.client = httpx.AsyncClient(timeout=
|
|
97
|
-
self.
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
renderer_name = get_recommended_renderer_name(model)
|
|
108
|
-
self.renderer = renderers.get_renderer(renderer_name, self.tokenizer)
|
|
109
|
-
|
|
110
|
-
async def __call__(self, messages: list[dict], stop: list[str] | None = None) -> dict:
|
|
117
|
+
self.client = httpx.AsyncClient(timeout=60.0)
|
|
118
|
+
self.total_charged_dollars = 0.0
|
|
119
|
+
|
|
120
|
+
async def start_rollout(
|
|
121
|
+
self,
|
|
122
|
+
samples: list[dict],
|
|
123
|
+
tools_py_code: str | None = None,
|
|
124
|
+
rewards_py_code: str | None = None,
|
|
125
|
+
dockerfiles: dict[str, str] | None = None,
|
|
126
|
+
secrets: dict[str, str] | None = None,
|
|
127
|
+
) -> str:
|
|
111
128
|
"""
|
|
112
|
-
|
|
129
|
+
Start rollouts and return rollout ID immediately.
|
|
130
|
+
Use poll_rollout() to check for results.
|
|
113
131
|
"""
|
|
114
|
-
# Build model input using renderer
|
|
115
|
-
model_input = self.renderer.build_generation_prompt(messages)
|
|
116
|
-
tokens = model_input.to_ints()
|
|
117
|
-
|
|
118
|
-
# Get stop sequences from renderer if not provided
|
|
119
|
-
if stop is None:
|
|
120
|
-
stop = self.renderer.get_stop_sequences()
|
|
121
|
-
|
|
122
|
-
# Build request payload
|
|
123
132
|
payload = {
|
|
133
|
+
"samples": samples,
|
|
124
134
|
"model": self.model,
|
|
125
|
-
"tokens": tokens,
|
|
126
|
-
"stop": stop,
|
|
127
135
|
"max_tokens": self.max_tokens,
|
|
128
136
|
"temperature": self.temperature,
|
|
137
|
+
"max_turns": self.max_turns,
|
|
138
|
+
"termination_policy": self.termination_policy,
|
|
139
|
+
"tools_py_code": tools_py_code,
|
|
140
|
+
"rewards_py_code": rewards_py_code,
|
|
141
|
+
"debug": self.debug,
|
|
129
142
|
}
|
|
130
|
-
# Include session_id if we have one cached
|
|
131
|
-
if self.session_id:
|
|
132
|
-
payload["session_id"] = self.session_id
|
|
133
143
|
|
|
134
|
-
|
|
144
|
+
if self.mcp_url:
|
|
145
|
+
payload["mcp_url"] = self.mcp_url
|
|
146
|
+
|
|
147
|
+
# Send Dockerfiles for local/ images
|
|
148
|
+
if dockerfiles:
|
|
149
|
+
payload["dockerfiles"] = dockerfiles
|
|
150
|
+
|
|
151
|
+
# Send project secrets (from .env file)
|
|
152
|
+
if secrets:
|
|
153
|
+
payload["secrets"] = secrets
|
|
154
|
+
|
|
155
|
+
if self.smoke_test:
|
|
156
|
+
payload["smoke_test"] = True
|
|
157
|
+
payload["openai_api_key"] = self.openai_api_key
|
|
158
|
+
|
|
135
159
|
resp = await self.client.post(
|
|
136
|
-
f"{self.api_base}/api/rnow/
|
|
160
|
+
f"{self.api_base}/api/rnow/rollout",
|
|
137
161
|
json=payload,
|
|
138
162
|
headers=self.auth_headers,
|
|
139
163
|
)
|
|
@@ -143,94 +167,79 @@ class ModelCompleter:
|
|
|
143
167
|
if "error" in data:
|
|
144
168
|
raise Exception(f"API error: {data.get('detail', data.get('error'))}")
|
|
145
169
|
|
|
146
|
-
|
|
147
|
-
if "session_id" in data and data["session_id"]:
|
|
148
|
-
self.session_id = data["session_id"]
|
|
149
|
-
|
|
150
|
-
# Track latency
|
|
151
|
-
if "latency_ms" in data:
|
|
152
|
-
self.total_latency_ms += data["latency_ms"]
|
|
153
|
-
self.request_count += 1
|
|
170
|
+
return data["rollout_id"]
|
|
154
171
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
172
|
+
async def poll_rollout(self, rollout_id: str) -> dict:
|
|
173
|
+
"""Poll for rollout status. Returns dict with 'status' field."""
|
|
174
|
+
resp = await self.client.get(
|
|
175
|
+
f"{self.api_base}/api/rnow/rollout",
|
|
176
|
+
params={"id": rollout_id},
|
|
177
|
+
headers=self.auth_headers,
|
|
178
|
+
)
|
|
179
|
+
resp.raise_for_status()
|
|
180
|
+
return resp.json()
|
|
181
|
+
|
|
182
|
+
async def run_batch_rollouts(
|
|
183
|
+
self,
|
|
184
|
+
samples: list[dict],
|
|
185
|
+
tools_py_code: str | None = None,
|
|
186
|
+
rewards_py_code: str | None = None,
|
|
187
|
+
dockerfiles: dict[str, str] | None = None,
|
|
188
|
+
secrets: dict[str, str] | None = None,
|
|
189
|
+
spinner: Spinner | None = None,
|
|
190
|
+
timeout_minutes: int = 30,
|
|
191
|
+
) -> tuple[str, list[dict]]:
|
|
192
|
+
"""
|
|
193
|
+
Run rollouts with exponential backoff polling.
|
|
194
|
+
Returns (rollout_id, results).
|
|
195
|
+
"""
|
|
196
|
+
# Start the rollout
|
|
197
|
+
rollout_id = await self.start_rollout(
|
|
198
|
+
samples, tools_py_code, rewards_py_code, dockerfiles, secrets
|
|
199
|
+
)
|
|
158
200
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
"latency_ms": data.get("latency_ms", 0),
|
|
162
|
-
}
|
|
201
|
+
if spinner:
|
|
202
|
+
spinner.update(f"Running rollouts... (ID: {rollout_id[:8]})")
|
|
163
203
|
|
|
164
|
-
|
|
165
|
-
|
|
204
|
+
# Poll with exponential backoff
|
|
205
|
+
poll_interval = 2.0 # Start at 2 seconds
|
|
206
|
+
max_interval = 10.0 # Cap at 10 seconds
|
|
207
|
+
timeout = timeout_minutes * 60
|
|
208
|
+
start_time = time.time()
|
|
166
209
|
|
|
210
|
+
while time.time() - start_time < timeout:
|
|
211
|
+
if _shutdown_requested:
|
|
212
|
+
raise asyncio.CancelledError()
|
|
167
213
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
Returns the flush result or None if it failed.
|
|
172
|
-
"""
|
|
173
|
-
try:
|
|
174
|
-
headers = get_auth_headers()
|
|
175
|
-
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
176
|
-
resp = await client.post(
|
|
177
|
-
f"{api_url.rstrip('/')}/api/billing/flush-rollout",
|
|
178
|
-
headers=headers,
|
|
179
|
-
)
|
|
180
|
-
resp.raise_for_status()
|
|
181
|
-
return resp.json()
|
|
182
|
-
except Exception as e:
|
|
183
|
-
click.echo(click.style(f"Warning: Failed to flush pending charges: {e}", fg="yellow"))
|
|
184
|
-
return None
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
def _exec_file(path: Path, module_name: str) -> None:
|
|
188
|
-
"""Execute a Python file to populate registries."""
|
|
189
|
-
import importlib.util
|
|
190
|
-
|
|
191
|
-
spec = importlib.util.spec_from_file_location(module_name, path)
|
|
192
|
-
if spec is None or spec.loader is None:
|
|
193
|
-
raise ImportError(f"Could not load module from {path}")
|
|
194
|
-
module = importlib.util.module_from_spec(spec)
|
|
195
|
-
spec.loader.exec_module(module)
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
def _build_tools_block(tool_registry: dict[str, Callable]) -> str:
|
|
199
|
-
"""Build the tools description block from registered tool functions."""
|
|
200
|
-
if not tool_registry:
|
|
201
|
-
return ""
|
|
202
|
-
|
|
203
|
-
tools_json = []
|
|
204
|
-
for name, fn in tool_registry.items():
|
|
205
|
-
schema = getattr(fn, "_schema", {"type": "object", "properties": {}})
|
|
206
|
-
description = getattr(fn, "_description", "No description available.")
|
|
207
|
-
tools_json.append(
|
|
208
|
-
{
|
|
209
|
-
"name": name,
|
|
210
|
-
"description": description,
|
|
211
|
-
"parameters": schema,
|
|
212
|
-
}
|
|
213
|
-
)
|
|
214
|
+
# Add jitter (±20%)
|
|
215
|
+
jitter = poll_interval * 0.2 * (random.random() * 2 - 1)
|
|
216
|
+
await asyncio.sleep(poll_interval + jitter)
|
|
214
217
|
|
|
215
|
-
|
|
218
|
+
result = await self.poll_rollout(rollout_id)
|
|
219
|
+
status = result.get("status")
|
|
216
220
|
|
|
217
|
-
|
|
221
|
+
if status == "completed":
|
|
222
|
+
# Track billing
|
|
223
|
+
if "billing" in result:
|
|
224
|
+
billing = result["billing"]
|
|
225
|
+
tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
|
|
226
|
+
self.total_charged_dollars += tokens * 0.000001
|
|
227
|
+
return rollout_id, result.get("results", [])
|
|
218
228
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
{json.dumps(tools_json, indent=2)}
|
|
222
|
-
</tools>
|
|
229
|
+
if status == "failed":
|
|
230
|
+
raise Exception(f"Rollout failed: {result.get('error', 'Unknown error')}")
|
|
223
231
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
{{"name": "<function-name>", "arguments": {{"<arg-name>": "<value>"}}}}
|
|
227
|
-
</tool_call>
|
|
228
|
-
"""
|
|
232
|
+
# Exponential backoff
|
|
233
|
+
poll_interval = min(poll_interval * 1.5, max_interval)
|
|
229
234
|
|
|
230
|
-
|
|
235
|
+
if spinner:
|
|
236
|
+
elapsed = int(time.time() - start_time)
|
|
237
|
+
spinner.update(f"Running rollouts... ({elapsed}s, ID: {rollout_id[:8]})")
|
|
231
238
|
|
|
239
|
+
raise TimeoutError(f"Rollout timed out after {timeout_minutes} minutes")
|
|
232
240
|
|
|
233
|
-
|
|
241
|
+
async def close(self):
|
|
242
|
+
await self.client.aclose()
|
|
234
243
|
|
|
235
244
|
|
|
236
245
|
def _format_message(msg: dict, max_len: int = 300) -> str:
|
|
@@ -247,154 +256,27 @@ def _format_message(msg: dict, max_len: int = 300) -> str:
|
|
|
247
256
|
|
|
248
257
|
|
|
249
258
|
async def _run_single_rollout(
|
|
250
|
-
|
|
259
|
+
client: RolloutClient,
|
|
251
260
|
sample: dict,
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
max_turns: int,
|
|
255
|
-
termination_policy: str,
|
|
261
|
+
tools_py_code: str | None,
|
|
262
|
+
rewards_py_code: str | None,
|
|
256
263
|
verbose: bool = False,
|
|
257
264
|
) -> dict:
|
|
258
|
-
"""Run a single rollout
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
reward_fns = []
|
|
267
|
-
for name in reward_names:
|
|
268
|
-
if name not in reward_registry:
|
|
269
|
-
raise ValueError(f"Reward function '{name}' not found in registry")
|
|
270
|
-
reward_fns.append(reward_registry[name])
|
|
271
|
-
|
|
272
|
-
ctx = {**metadata, **variables}
|
|
273
|
-
messages = [
|
|
274
|
-
{"role": msg["role"], "content": Template(msg["content"]).safe_substitute(ctx)}
|
|
275
|
-
for msg in messages_templates
|
|
276
|
-
]
|
|
277
|
-
|
|
278
|
-
if tool_registry:
|
|
279
|
-
tools_block = _build_tools_block(tool_registry)
|
|
280
|
-
system_found = False
|
|
281
|
-
for msg in messages:
|
|
282
|
-
if msg["role"] == "system":
|
|
283
|
-
msg["content"] = tools_block + "\n\n" + msg["content"]
|
|
284
|
-
system_found = True
|
|
285
|
-
break
|
|
286
|
-
if not system_found:
|
|
287
|
-
messages.insert(0, {"role": "system", "content": tools_block})
|
|
288
|
-
|
|
289
|
-
conversation = messages.copy()
|
|
290
|
-
turn_count = 0
|
|
291
|
-
total_tool_calls = 0
|
|
292
|
-
|
|
293
|
-
# Show initial messages in verbose mode
|
|
265
|
+
"""Run a single rollout via the API."""
|
|
266
|
+
result = await client.run_rollout(
|
|
267
|
+
sample=sample,
|
|
268
|
+
tools_py_code=tools_py_code,
|
|
269
|
+
rewards_py_code=rewards_py_code,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
# Show conversation in verbose mode
|
|
294
273
|
if verbose:
|
|
295
|
-
click.echo(" ---
|
|
296
|
-
for msg in
|
|
274
|
+
click.echo(" --- Conversation ---")
|
|
275
|
+
for msg in result.get("conversation", []):
|
|
297
276
|
click.echo(f" {_format_message(msg)}")
|
|
298
|
-
click.echo("
|
|
299
|
-
|
|
300
|
-
while turn_count < max_turns:
|
|
301
|
-
turn_count += 1
|
|
277
|
+
click.echo(" ---------------------")
|
|
302
278
|
|
|
303
|
-
|
|
304
|
-
response_content = result.get("content", "")
|
|
305
|
-
|
|
306
|
-
conversation.append({"role": "assistant", "content": response_content})
|
|
307
|
-
|
|
308
|
-
if verbose:
|
|
309
|
-
click.echo(
|
|
310
|
-
f" [Turn {turn_count}] {_format_message({'role': 'assistant', 'content': response_content}, max_len=500)}"
|
|
311
|
-
)
|
|
312
|
-
|
|
313
|
-
tool_matches = TOOL_CALL_RE.findall(response_content)
|
|
314
|
-
tool_call_count = len(tool_matches)
|
|
315
|
-
total_tool_calls += tool_call_count
|
|
316
|
-
|
|
317
|
-
for raw_call in tool_matches:
|
|
318
|
-
if not tool_registry:
|
|
319
|
-
break
|
|
320
|
-
try:
|
|
321
|
-
tool_data = json.loads(raw_call)
|
|
322
|
-
tool_name = tool_data.get("name")
|
|
323
|
-
args = tool_data.get("arguments", {})
|
|
324
|
-
|
|
325
|
-
if tool_name not in tool_registry:
|
|
326
|
-
tool_response = f"<tool_error>Tool '{tool_name}' not found</tool_error>"
|
|
327
|
-
conversation.append({"role": "tool", "content": tool_response})
|
|
328
|
-
if verbose:
|
|
329
|
-
click.echo(
|
|
330
|
-
f" {_format_message({'role': 'tool', 'content': tool_response})}"
|
|
331
|
-
)
|
|
332
|
-
continue
|
|
333
|
-
|
|
334
|
-
tool_fn = tool_registry[tool_name]
|
|
335
|
-
tool_result = (
|
|
336
|
-
await tool_fn(**args)
|
|
337
|
-
if inspect.iscoroutinefunction(tool_fn)
|
|
338
|
-
else tool_fn(**args)
|
|
339
|
-
)
|
|
340
|
-
|
|
341
|
-
tool_response = f"<tool_result>{json.dumps(tool_result)}</tool_result>"
|
|
342
|
-
conversation.append({"role": "tool", "content": tool_response})
|
|
343
|
-
|
|
344
|
-
if verbose:
|
|
345
|
-
click.echo(
|
|
346
|
-
f" Tool {click.style(tool_name, fg=TEAL_RGB)}: {str(tool_result)[:200]}"
|
|
347
|
-
)
|
|
348
|
-
|
|
349
|
-
except json.JSONDecodeError as e:
|
|
350
|
-
tool_response = f"<tool_error>Invalid JSON: {str(e)}</tool_error>"
|
|
351
|
-
conversation.append({"role": "tool", "content": tool_response})
|
|
352
|
-
if verbose:
|
|
353
|
-
click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")
|
|
354
|
-
except Exception as e:
|
|
355
|
-
tool_response = f"<tool_error>{str(e)}</tool_error>"
|
|
356
|
-
conversation.append({"role": "tool", "content": tool_response})
|
|
357
|
-
if verbose:
|
|
358
|
-
click.echo(f" {_format_message({'role': 'tool', 'content': tool_response})}")
|
|
359
|
-
|
|
360
|
-
if termination_policy == "last_tool" and tool_call_count == 0:
|
|
361
|
-
break
|
|
362
|
-
|
|
363
|
-
# Show final conversation summary in verbose mode
|
|
364
|
-
if verbose:
|
|
365
|
-
click.echo(f" --- Rollout Complete: {turn_count} turns, {total_tool_calls} tool calls ---")
|
|
366
|
-
|
|
367
|
-
reward_args = RewardArgs(metadata=metadata, variables=variables)
|
|
368
|
-
rewards = {}
|
|
369
|
-
for fn, name in zip(reward_fns, reward_names, strict=False):
|
|
370
|
-
value = await fn(reward_args, conversation)
|
|
371
|
-
rewards[name] = value
|
|
372
|
-
|
|
373
|
-
total_reward = compute_total_reward(rewards) if rewards else 0.0
|
|
374
|
-
|
|
375
|
-
return {
|
|
376
|
-
"total_reward": total_reward,
|
|
377
|
-
"rewards": rewards,
|
|
378
|
-
"turns": turn_count,
|
|
379
|
-
"tools_used": total_tool_calls,
|
|
380
|
-
"conversation": conversation,
|
|
381
|
-
}
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
def _check_test_dependencies():
|
|
385
|
-
"""Check if optional test dependencies are installed."""
|
|
386
|
-
try:
|
|
387
|
-
import tinker_cookbook # noqa: F401
|
|
388
|
-
except ImportError:
|
|
389
|
-
click.echo()
|
|
390
|
-
click.echo(
|
|
391
|
-
click.style("Error: ", fg="red", bold=True)
|
|
392
|
-
+ "The 'rnow test' command requires additional dependencies."
|
|
393
|
-
)
|
|
394
|
-
pip_cmd = "pip install 'rnow[test]'"
|
|
395
|
-
click.echo(f"Install them with: {click.style(pip_cmd, fg=TEAL_RGB)}")
|
|
396
|
-
click.echo()
|
|
397
|
-
raise SystemExit(1)
|
|
279
|
+
return result
|
|
398
280
|
|
|
399
281
|
|
|
400
282
|
@click.command(name="test")
|
|
@@ -404,12 +286,12 @@ def _check_test_dependencies():
|
|
|
404
286
|
"project_dir",
|
|
405
287
|
type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
|
|
406
288
|
default=".",
|
|
407
|
-
help="Project directory containing config.yml, rewards.py,
|
|
289
|
+
help="Project directory containing config.yml, rewards.py, tools.py, train.jsonl",
|
|
408
290
|
)
|
|
409
291
|
@click.option(
|
|
410
292
|
"--num-rollouts",
|
|
411
293
|
"-n",
|
|
412
|
-
default=
|
|
294
|
+
default=1,
|
|
413
295
|
show_default=True,
|
|
414
296
|
help="Number of rollouts to run",
|
|
415
297
|
)
|
|
@@ -449,51 +331,290 @@ def _check_test_dependencies():
|
|
|
449
331
|
type=int,
|
|
450
332
|
help="Truncate message content to N characters (default: no truncation)",
|
|
451
333
|
)
|
|
334
|
+
@click.option(
|
|
335
|
+
"--debug",
|
|
336
|
+
is_flag=True,
|
|
337
|
+
help="Use debug trainer image from Docker Hub (for testing trainer changes)",
|
|
338
|
+
)
|
|
339
|
+
@click.option(
|
|
340
|
+
"--output-dir",
|
|
341
|
+
"-o",
|
|
342
|
+
"output_dir",
|
|
343
|
+
type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
|
|
344
|
+
default=None,
|
|
345
|
+
help="Save rollout results as JSON files in this directory",
|
|
346
|
+
)
|
|
347
|
+
@click.option(
|
|
348
|
+
"--smoke-test",
|
|
349
|
+
is_flag=True,
|
|
350
|
+
help="Use OpenAI gpt-5-nano instead of tinker (requires OPENAI_API_KEY env var)",
|
|
351
|
+
)
|
|
352
|
+
@click.option(
|
|
353
|
+
"--id",
|
|
354
|
+
"rollout_id",
|
|
355
|
+
default=None,
|
|
356
|
+
help="Fetch results for an existing rollout ID (skip running new rollout)",
|
|
357
|
+
)
|
|
358
|
+
@click.option(
|
|
359
|
+
"--store",
|
|
360
|
+
is_flag=True,
|
|
361
|
+
help="Store rollout ID in ./rollouts/<id>.txt for later retrieval",
|
|
362
|
+
)
|
|
363
|
+
@click.option(
|
|
364
|
+
"--timeout",
|
|
365
|
+
default=60,
|
|
366
|
+
show_default=True,
|
|
367
|
+
help="Timeout in minutes for polling results",
|
|
368
|
+
)
|
|
369
|
+
@click.option(
|
|
370
|
+
"--entry",
|
|
371
|
+
"-e",
|
|
372
|
+
"entries",
|
|
373
|
+
default=None,
|
|
374
|
+
help="Entry indices from train.jsonl (0-indexed). Examples: -e 5, -e 0,2,5, -e 0 -e 2 -e 5",
|
|
375
|
+
multiple=True,
|
|
376
|
+
)
|
|
452
377
|
@click.pass_context
|
|
453
|
-
def test(
|
|
454
|
-
|
|
378
|
+
def test(
|
|
379
|
+
ctx,
|
|
380
|
+
project_dir,
|
|
381
|
+
num_rollouts,
|
|
382
|
+
multi_turn,
|
|
383
|
+
with_tools,
|
|
384
|
+
model,
|
|
385
|
+
api_url,
|
|
386
|
+
verbose,
|
|
387
|
+
truncate,
|
|
388
|
+
debug,
|
|
389
|
+
output_dir,
|
|
390
|
+
smoke_test,
|
|
391
|
+
rollout_id,
|
|
392
|
+
store,
|
|
393
|
+
timeout,
|
|
394
|
+
entries,
|
|
395
|
+
):
|
|
396
|
+
"""Test RL rollouts before submitting.
|
|
397
|
+
|
|
398
|
+
Runs rollouts via the /api/rnow/rollout endpoint on Cloud Run.
|
|
399
|
+
|
|
400
|
+
Use --smoke-test to use OpenAI gpt-5-nano instead of tinker models
|
|
401
|
+
(requires OPENAI_API_KEY environment variable).
|
|
455
402
|
|
|
456
|
-
|
|
457
|
-
for model sampling.
|
|
403
|
+
Use --id to fetch results for an existing rollout.
|
|
458
404
|
|
|
459
405
|
Only works with RL projects (dataset_type: rl).
|
|
460
406
|
"""
|
|
461
407
|
global _shutdown_requested
|
|
462
408
|
_shutdown_requested = False
|
|
463
409
|
|
|
464
|
-
|
|
465
|
-
global _shutdown_requested
|
|
466
|
-
if _shutdown_requested:
|
|
467
|
-
# Second Ctrl+C, force exit
|
|
468
|
-
sys.exit(1)
|
|
469
|
-
_shutdown_requested = True
|
|
470
|
-
click.echo("\n" + click.style("Interrupted. Shutting down gracefully...", fg="yellow"))
|
|
471
|
-
|
|
472
|
-
# Set up signal handler
|
|
473
|
-
original_handler = signal.signal(signal.SIGINT, handle_sigint)
|
|
410
|
+
resolved_api_url = api_url or ctx.obj.get("api_url", "").replace("/api", "") or DEFAULT_API_URL
|
|
474
411
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
try:
|
|
412
|
+
# Handle --id flag: just fetch existing rollout results
|
|
413
|
+
if rollout_id:
|
|
478
414
|
asyncio.run(
|
|
479
|
-
|
|
415
|
+
_fetch_rollout_results(
|
|
416
|
+
rollout_id=rollout_id,
|
|
417
|
+
api_url=resolved_api_url,
|
|
418
|
+
store=store,
|
|
419
|
+
truncate=truncate,
|
|
420
|
+
output_dir=output_dir,
|
|
421
|
+
)
|
|
422
|
+
)
|
|
423
|
+
return
|
|
424
|
+
|
|
425
|
+
# Check for OpenAI API key in smoke test mode
|
|
426
|
+
openai_api_key = None
|
|
427
|
+
if smoke_test:
|
|
428
|
+
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
|
429
|
+
if not openai_api_key:
|
|
430
|
+
raise click.ClickException(
|
|
431
|
+
"OPENAI_API_KEY environment variable is required for smoke test mode.\n"
|
|
432
|
+
"Set it with: export OPENAI_API_KEY=sk-..."
|
|
433
|
+
)
|
|
434
|
+
else:
|
|
435
|
+
require_auth()
|
|
436
|
+
|
|
437
|
+
async def run_with_cancellation():
|
|
438
|
+
"""Run test with proper cancellation support."""
|
|
439
|
+
loop = asyncio.get_running_loop()
|
|
440
|
+
task = asyncio.current_task()
|
|
441
|
+
|
|
442
|
+
def handle_sigint():
|
|
443
|
+
global _shutdown_requested
|
|
444
|
+
if _shutdown_requested:
|
|
445
|
+
sys.exit(1)
|
|
446
|
+
_shutdown_requested = True
|
|
447
|
+
click.echo("\n" + click.style("Interrupted. Cancelling...", fg="yellow"))
|
|
448
|
+
task.cancel()
|
|
449
|
+
|
|
450
|
+
loop.add_signal_handler(signal.SIGINT, handle_sigint)
|
|
451
|
+
|
|
452
|
+
try:
|
|
453
|
+
await _test_async(
|
|
480
454
|
project_dir=project_dir,
|
|
481
455
|
num_rollouts=num_rollouts,
|
|
482
456
|
multi_turn=multi_turn,
|
|
483
457
|
with_tools=with_tools,
|
|
484
458
|
model_override=model,
|
|
485
|
-
api_url=
|
|
486
|
-
or ctx.obj.get("api_url", "").replace("/api", "")
|
|
487
|
-
or DEFAULT_API_URL,
|
|
459
|
+
api_url=resolved_api_url,
|
|
488
460
|
verbose=verbose,
|
|
489
461
|
truncate=truncate,
|
|
462
|
+
debug=debug,
|
|
463
|
+
output_dir=output_dir,
|
|
464
|
+
smoke_test=smoke_test,
|
|
465
|
+
openai_api_key=openai_api_key,
|
|
466
|
+
store=store,
|
|
467
|
+
timeout_minutes=timeout,
|
|
468
|
+
entries=entries,
|
|
490
469
|
)
|
|
491
|
-
|
|
470
|
+
except asyncio.CancelledError:
|
|
471
|
+
click.echo(click.style("Aborted.", fg="yellow"))
|
|
472
|
+
finally:
|
|
473
|
+
loop.remove_signal_handler(signal.SIGINT)
|
|
474
|
+
|
|
475
|
+
try:
|
|
476
|
+
asyncio.run(run_with_cancellation())
|
|
492
477
|
except KeyboardInterrupt:
|
|
493
478
|
click.echo(click.style("Aborted.", fg="yellow"))
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
async def _fetch_rollout_results(
|
|
482
|
+
rollout_id: str,
|
|
483
|
+
api_url: str,
|
|
484
|
+
store: bool = False,
|
|
485
|
+
truncate: int | None = None,
|
|
486
|
+
output_dir: Path | None = None,
|
|
487
|
+
):
|
|
488
|
+
"""Fetch results for an existing rollout ID."""
|
|
489
|
+
click.echo(f"Fetching results for rollout: {click.style(rollout_id, fg=TEAL_RGB)}")
|
|
490
|
+
|
|
491
|
+
client = httpx.AsyncClient(timeout=30.0)
|
|
492
|
+
auth_headers = get_auth_headers()
|
|
493
|
+
|
|
494
|
+
try:
|
|
495
|
+
resp = await client.get(
|
|
496
|
+
f"{api_url}/api/rnow/rollout",
|
|
497
|
+
params={"id": rollout_id},
|
|
498
|
+
headers=auth_headers,
|
|
499
|
+
)
|
|
500
|
+
resp.raise_for_status()
|
|
501
|
+
data = resp.json()
|
|
494
502
|
finally:
|
|
495
|
-
|
|
496
|
-
|
|
503
|
+
await client.aclose()
|
|
504
|
+
|
|
505
|
+
status = data.get("status")
|
|
506
|
+
if status == "pending":
|
|
507
|
+
click.echo(click.style("Rollout still running...", fg="yellow"))
|
|
508
|
+
click.echo(f"Poll again with: rnow test --id {rollout_id}")
|
|
509
|
+
return
|
|
510
|
+
|
|
511
|
+
if status == "failed":
|
|
512
|
+
click.echo(click.style(f"Rollout failed: {data.get('error', 'Unknown')}", fg="red"))
|
|
513
|
+
return
|
|
514
|
+
|
|
515
|
+
# Store rollout ID if requested
|
|
516
|
+
if store:
|
|
517
|
+
_store_rollout_id(rollout_id, data)
|
|
518
|
+
|
|
519
|
+
# Display results
|
|
520
|
+
results = data.get("results", [])
|
|
521
|
+
_display_results(results, truncate, output_dir, rollout_id)
|
|
522
|
+
|
|
523
|
+
# Show billing
|
|
524
|
+
billing = data.get("billing", {})
|
|
525
|
+
tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
|
|
526
|
+
if tokens > 0:
|
|
527
|
+
click.echo(f"Tokens: {tokens}")
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _store_rollout_id(rollout_id: str, data: dict):
|
|
531
|
+
"""Store rollout ID and results in ./rollouts/<id>.txt"""
|
|
532
|
+
rollouts_dir = Path("rollouts")
|
|
533
|
+
rollouts_dir.mkdir(exist_ok=True)
|
|
534
|
+
|
|
535
|
+
filepath = rollouts_dir / f"{rollout_id}.txt"
|
|
536
|
+
with open(filepath, "w") as f:
|
|
537
|
+
f.write(f"Rollout ID: {rollout_id}\n")
|
|
538
|
+
f.write(f"Status: {data.get('status', 'unknown')}\n")
|
|
539
|
+
f.write(f"S3 Path: rollouts/{rollout_id}/result.json\n")
|
|
540
|
+
f.write("\n")
|
|
541
|
+
|
|
542
|
+
# Write summary
|
|
543
|
+
results = data.get("results", [])
|
|
544
|
+
successful = [r for r in results if r.get("success")]
|
|
545
|
+
if successful:
|
|
546
|
+
rewards = [r.get("total_reward", 0) for r in successful]
|
|
547
|
+
f.write(f"Successful: {len(successful)}/{len(results)}\n")
|
|
548
|
+
f.write(f"Mean Reward: {sum(rewards) / len(rewards):.3f}\n")
|
|
549
|
+
|
|
550
|
+
# Write billing
|
|
551
|
+
billing = data.get("billing", {})
|
|
552
|
+
tokens = billing.get("prompt_tokens", 0) + billing.get("completion_tokens", 0)
|
|
553
|
+
if tokens > 0:
|
|
554
|
+
f.write(f"Tokens: {tokens}\n")
|
|
555
|
+
|
|
556
|
+
f.write("\n--- Full Results ---\n")
|
|
557
|
+
f.write(json.dumps(data, indent=2))
|
|
558
|
+
|
|
559
|
+
click.echo(f"Stored: {click.style(str(filepath), fg=TEAL_RGB)}")
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def _display_results(
|
|
563
|
+
results: list[dict],
|
|
564
|
+
truncate: int | None,
|
|
565
|
+
output_dir: Path | None,
|
|
566
|
+
rollout_id: str | None = None,
|
|
567
|
+
):
|
|
568
|
+
"""Display rollout results."""
|
|
569
|
+
rewards = []
|
|
570
|
+
|
|
571
|
+
for idx, result in enumerate(results):
|
|
572
|
+
click.echo(f"Rollout {idx + 1}/{len(results)}")
|
|
573
|
+
|
|
574
|
+
if not result.get("success"):
|
|
575
|
+
click.echo(click.style(f" ✗ {result.get('error', 'Unknown error')}", fg="red"))
|
|
576
|
+
click.echo()
|
|
577
|
+
continue
|
|
578
|
+
|
|
579
|
+
total_reward = result.get("total_reward", 0.0)
|
|
580
|
+
rewards.append(total_reward)
|
|
581
|
+
|
|
582
|
+
# Show conversation
|
|
583
|
+
for msg in result.get("conversation", []):
|
|
584
|
+
role = msg.get("role", "unknown")
|
|
585
|
+
content = msg.get("content", "")
|
|
586
|
+
if truncate and len(content) > truncate:
|
|
587
|
+
content = content[:truncate] + "..."
|
|
588
|
+
tag = click.style(f"[{role}]", fg="red")
|
|
589
|
+
click.echo(f" {tag} {content}")
|
|
590
|
+
|
|
591
|
+
reward_breakdown = result.get("rewards", {})
|
|
592
|
+
reward_str = ", ".join(f"{k}={v:.3f}" for k, v in reward_breakdown.items())
|
|
593
|
+
turns = result.get("turns", 0)
|
|
594
|
+
click.echo(
|
|
595
|
+
f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
|
|
596
|
+
f"| turns={turns} "
|
|
597
|
+
f"| [{reward_str}]"
|
|
598
|
+
)
|
|
599
|
+
click.echo()
|
|
600
|
+
|
|
601
|
+
# Save to files if requested
|
|
602
|
+
if output_dir and results:
|
|
603
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
604
|
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
|
605
|
+
for idx, result in enumerate(results):
|
|
606
|
+
if result.get("success"):
|
|
607
|
+
filename = output_dir / f"rollout_{timestamp}_{idx + 1}.json"
|
|
608
|
+
filename.write_text(json.dumps(result, indent=2))
|
|
609
|
+
click.echo(f"Results saved to {click.style(str(output_dir), fg=TEAL_RGB)}")
|
|
610
|
+
|
|
611
|
+
# Summary
|
|
612
|
+
if rewards:
|
|
613
|
+
mean_reward = sum(rewards) / len(rewards)
|
|
614
|
+
click.echo()
|
|
615
|
+
click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
|
|
616
|
+
if rollout_id:
|
|
617
|
+
click.echo(f"Rollout ID: {click.style(rollout_id, fg=TEAL_RGB)}")
|
|
497
618
|
|
|
498
619
|
|
|
499
620
|
async def _test_async(
|
|
@@ -505,6 +626,13 @@ async def _test_async(
|
|
|
505
626
|
api_url: str,
|
|
506
627
|
verbose: bool,
|
|
507
628
|
truncate: int | None,
|
|
629
|
+
debug: bool = False,
|
|
630
|
+
output_dir: Path | None = None,
|
|
631
|
+
smoke_test: bool = False,
|
|
632
|
+
openai_api_key: str | None = None,
|
|
633
|
+
store: bool = False,
|
|
634
|
+
timeout_minutes: int = 60,
|
|
635
|
+
entries: tuple[int, ...] = (),
|
|
508
636
|
):
|
|
509
637
|
project_dir = Path(project_dir)
|
|
510
638
|
|
|
@@ -529,7 +657,7 @@ async def _test_async(
|
|
|
529
657
|
)
|
|
530
658
|
|
|
531
659
|
rewards_path = project_dir / "rewards.py"
|
|
532
|
-
|
|
660
|
+
tools_path = project_dir / "tools.py"
|
|
533
661
|
train_path = project_dir / "train.jsonl"
|
|
534
662
|
|
|
535
663
|
if not rewards_path.exists():
|
|
@@ -537,192 +665,159 @@ async def _test_async(
|
|
|
537
665
|
if not train_path.exists():
|
|
538
666
|
raise click.ClickException("train.jsonl not found in project directory")
|
|
539
667
|
|
|
540
|
-
#
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
if config.rollout:
|
|
545
|
-
max_prompt_tokens = get_max_prompt_tokens(train_path)
|
|
546
|
-
if max_prompt_tokens > 0:
|
|
547
|
-
context_error, recommended = validate_max_tokens_for_context(
|
|
548
|
-
config.rollout.max_tokens, max_prompt_tokens
|
|
549
|
-
)
|
|
550
|
-
if context_error:
|
|
551
|
-
click.echo()
|
|
552
|
-
click.echo(click.style("✗ Context window exceeded", fg="red", bold=True))
|
|
553
|
-
click.echo()
|
|
554
|
-
click.echo(
|
|
555
|
-
f" Your longest prompt in train.jsonl is ~{max_prompt_tokens:,} tokens."
|
|
556
|
-
)
|
|
557
|
-
click.echo(f" With max_tokens={config.rollout.max_tokens:,}, the total exceeds")
|
|
558
|
-
click.echo(f" the {MAX_CONTEXT_WINDOW:,} token context window.")
|
|
559
|
-
click.echo()
|
|
560
|
-
click.echo(
|
|
561
|
-
click.style(" Fix:", bold=True)
|
|
562
|
-
+ f" Set rollout.max_tokens to {recommended:,} or less"
|
|
563
|
-
)
|
|
564
|
-
click.echo()
|
|
565
|
-
raise click.ClickException("max_tokens + prompt length exceeds context window")
|
|
566
|
-
|
|
567
|
-
clear_reward_registry()
|
|
568
|
-
clear_tool_registry()
|
|
569
|
-
|
|
570
|
-
_exec_file(rewards_path, "rewards")
|
|
571
|
-
|
|
572
|
-
if with_tools and env_path.exists():
|
|
573
|
-
_exec_file(env_path, "env")
|
|
668
|
+
# Read user code files to send to the API
|
|
669
|
+
rewards_py_code = rewards_path.read_text()
|
|
670
|
+
tools_py_code = tools_path.read_text() if with_tools and tools_path.exists() else None
|
|
574
671
|
|
|
672
|
+
# Load samples
|
|
575
673
|
samples = [json.loads(line) for line in train_path.read_text().splitlines() if line.strip()]
|
|
576
674
|
|
|
675
|
+
# Read Dockerfile.* files for local/ docker images
|
|
676
|
+
dockerfiles: dict[str, str] = {}
|
|
677
|
+
for dockerfile_path in project_dir.glob("Dockerfile.*"):
|
|
678
|
+
dockerfiles[dockerfile_path.name] = dockerfile_path.read_text()
|
|
679
|
+
click.echo(f" Found {dockerfile_path.name}")
|
|
680
|
+
|
|
681
|
+
# Read .env file for project secrets
|
|
682
|
+
project_secrets: dict[str, str] = {}
|
|
683
|
+
env_path = project_dir / ".env"
|
|
684
|
+
if env_path.exists():
|
|
685
|
+
for line in env_path.read_text().splitlines():
|
|
686
|
+
line = line.strip()
|
|
687
|
+
if line and not line.startswith("#") and "=" in line:
|
|
688
|
+
key, _, value = line.partition("=")
|
|
689
|
+
# Remove quotes if present
|
|
690
|
+
value = value.strip().strip("'\"")
|
|
691
|
+
project_secrets[key.strip()] = value
|
|
692
|
+
if project_secrets:
|
|
693
|
+
click.echo(f" Loaded secrets: {list(project_secrets.keys())}")
|
|
694
|
+
|
|
577
695
|
if not samples:
|
|
578
696
|
raise click.ClickException("train.jsonl is empty")
|
|
579
697
|
|
|
580
|
-
|
|
698
|
+
# For smoke test, always use gpt-5-nano
|
|
699
|
+
model_name = "gpt-5-nano" if smoke_test else model_override or config.model.path
|
|
700
|
+
|
|
581
701
|
max_tokens = config.rollout.max_tokens if config.rollout else 2048
|
|
582
702
|
max_turns_config = config.rollout.max_turns if config.rollout else 1
|
|
583
703
|
termination_policy = config.rollout.termination_policy if config.rollout else "last_tool"
|
|
704
|
+
mcp_url = config.rollout.mcp_url if config.rollout else None
|
|
584
705
|
|
|
585
706
|
max_turns = 1 if not multi_turn else max_turns_config
|
|
586
707
|
|
|
587
|
-
#
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
click.echo(
|
|
593
|
-
click.style("Warning: ", fg="yellow")
|
|
594
|
-
+ "Tool calling with gpt-oss models is not supported in 'rnow test'. Running without tools."
|
|
595
|
-
)
|
|
596
|
-
with_tools = False
|
|
597
|
-
|
|
598
|
-
rewards = []
|
|
599
|
-
tool_registry_to_use = TOOL_REGISTRY if with_tools else {}
|
|
708
|
+
# Display mode and model info
|
|
709
|
+
if smoke_test:
|
|
710
|
+
click.echo(f"Mode: {click.style('SMOKE TEST', fg=TEAL_RGB)} (OpenAI gpt-5-nano)")
|
|
711
|
+
else:
|
|
712
|
+
thinking_display = get_thinking_mode_display(config)
|
|
713
|
+
click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
|
|
600
714
|
|
|
601
|
-
# Display model info with reasoning mode (same format as rnow run)
|
|
602
|
-
thinking_display = get_thinking_mode_display(config)
|
|
603
|
-
click.echo(f"Model: {model_name} ({click.style(thinking_display, fg=TEAL_RGB)})")
|
|
604
715
|
click.echo()
|
|
605
716
|
|
|
606
717
|
try:
|
|
607
|
-
# Create one
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
718
|
+
# Create one RolloutClient for all rollouts
|
|
719
|
+
client = RolloutClient(
|
|
720
|
+
api_base=api_url,
|
|
721
|
+
model=model_name,
|
|
722
|
+
max_tokens=max_tokens,
|
|
723
|
+
temperature=1.0,
|
|
724
|
+
max_turns=max_turns,
|
|
725
|
+
termination_policy=termination_policy,
|
|
726
|
+
debug=debug,
|
|
727
|
+
smoke_test=smoke_test,
|
|
728
|
+
openai_api_key=openai_api_key,
|
|
729
|
+
mcp_url=mcp_url,
|
|
730
|
+
)
|
|
619
731
|
|
|
620
|
-
#
|
|
621
|
-
|
|
732
|
+
# Select samples for batch rollout
|
|
733
|
+
if entries:
|
|
734
|
+
# Parse entries - support both "-e 0 -e 2" and "-e 0,2,5"
|
|
735
|
+
entry_indices = []
|
|
736
|
+
for entry in entries:
|
|
737
|
+
# Handle comma-separated values
|
|
738
|
+
for part in str(entry).split(","):
|
|
739
|
+
part = part.strip()
|
|
740
|
+
if part:
|
|
741
|
+
try:
|
|
742
|
+
idx = int(part)
|
|
743
|
+
except ValueError:
|
|
744
|
+
raise click.ClickException(f"Invalid entry index: {part}")
|
|
745
|
+
if idx < 0 or idx >= len(samples):
|
|
746
|
+
raise click.ClickException(
|
|
747
|
+
f"Entry index {idx} out of range. train.jsonl has {len(samples)} entries (0-{len(samples) - 1})"
|
|
748
|
+
)
|
|
749
|
+
entry_indices.append(idx)
|
|
750
|
+
|
|
751
|
+
if not entry_indices:
|
|
752
|
+
raise click.ClickException("No valid entry indices provided")
|
|
753
|
+
|
|
754
|
+
selected_samples = [samples[idx] for idx in entry_indices]
|
|
755
|
+
click.echo(f"Testing entries: {entry_indices}")
|
|
756
|
+
else:
|
|
757
|
+
# Random selection
|
|
758
|
+
selected_samples = [random.choice(samples) for _ in range(num_rollouts)]
|
|
759
|
+
|
|
760
|
+
# Start spinner for batch rollout
|
|
761
|
+
spinner = Spinner(f"Starting {len(selected_samples)} rollouts...")
|
|
622
762
|
spinner.start()
|
|
623
763
|
|
|
624
|
-
async def run_rollout_with_index(idx: int) -> tuple[int, dict | Exception]:
|
|
625
|
-
"""Run a single rollout and return (index, result or exception)."""
|
|
626
|
-
if _shutdown_requested:
|
|
627
|
-
return (idx, asyncio.CancelledError("Shutdown requested"))
|
|
628
|
-
try:
|
|
629
|
-
result = await _run_single_rollout(
|
|
630
|
-
completer=completers[idx],
|
|
631
|
-
sample=selected_samples[idx],
|
|
632
|
-
reward_registry=REWARD_REGISTRY,
|
|
633
|
-
tool_registry=tool_registry_to_use,
|
|
634
|
-
max_turns=max_turns,
|
|
635
|
-
termination_policy=termination_policy,
|
|
636
|
-
verbose=False,
|
|
637
|
-
)
|
|
638
|
-
return (idx, result)
|
|
639
|
-
except asyncio.CancelledError:
|
|
640
|
-
return (idx, asyncio.CancelledError("Cancelled"))
|
|
641
|
-
except Exception as e:
|
|
642
|
-
return (idx, e)
|
|
643
|
-
|
|
644
|
-
# Run all rollouts concurrently
|
|
645
764
|
start_time = time.time()
|
|
646
|
-
|
|
765
|
+
rollout_id = None
|
|
647
766
|
|
|
648
767
|
try:
|
|
649
|
-
results
|
|
768
|
+
# Start rollout and poll for results with exponential backoff
|
|
769
|
+
rollout_id, batch_results = await client.run_batch_rollouts(
|
|
770
|
+
samples=selected_samples,
|
|
771
|
+
tools_py_code=tools_py_code,
|
|
772
|
+
rewards_py_code=rewards_py_code,
|
|
773
|
+
dockerfiles=dockerfiles if dockerfiles else None,
|
|
774
|
+
secrets=project_secrets if project_secrets else None,
|
|
775
|
+
spinner=spinner,
|
|
776
|
+
timeout_minutes=timeout_minutes,
|
|
777
|
+
)
|
|
650
778
|
except asyncio.CancelledError:
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
779
|
+
batch_results = []
|
|
780
|
+
except Exception as e:
|
|
781
|
+
spinner.stop()
|
|
782
|
+
raise e
|
|
655
783
|
|
|
656
784
|
total_time = time.time() - start_time
|
|
657
785
|
spinner.stop()
|
|
658
786
|
|
|
787
|
+
# Show rollout ID
|
|
788
|
+
if rollout_id:
|
|
789
|
+
click.echo(f"Rollout ID: {click.style(rollout_id, fg=TEAL_RGB)}")
|
|
790
|
+
click.echo()
|
|
791
|
+
|
|
659
792
|
# Check if shutdown was requested
|
|
660
793
|
if _shutdown_requested:
|
|
661
|
-
|
|
662
|
-
for c in completers:
|
|
663
|
-
await c.close()
|
|
794
|
+
await client.close()
|
|
664
795
|
return
|
|
665
796
|
|
|
666
|
-
#
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
else:
|
|
676
|
-
click.echo(click.style(f" ✗ {result}", fg="red"))
|
|
677
|
-
click.echo()
|
|
678
|
-
continue
|
|
679
|
-
|
|
680
|
-
total_reward = result["total_reward"]
|
|
681
|
-
rewards.append(total_reward)
|
|
682
|
-
|
|
683
|
-
# Get conversation
|
|
684
|
-
conversation = result["conversation"]
|
|
685
|
-
|
|
686
|
-
# Show all messages with red tags
|
|
687
|
-
for msg in conversation:
|
|
688
|
-
role = msg.get("role", "unknown")
|
|
689
|
-
content = msg.get("content", "")
|
|
690
|
-
# Truncate if flag is set
|
|
691
|
-
if truncate and len(content) > truncate:
|
|
692
|
-
content = content[:truncate] + "..."
|
|
693
|
-
tag = click.style(f"[{role}]", fg="red")
|
|
694
|
-
click.echo(f" {tag} {content}")
|
|
695
|
-
reward_str = ", ".join(f"{k}={v:.3f}" for k, v in result["rewards"].items())
|
|
696
|
-
click.echo(
|
|
697
|
-
f" {click.style('reward', fg=TEAL_RGB)}={total_reward:.3f} "
|
|
698
|
-
f"| turns={result['turns']} "
|
|
699
|
-
f"| tools_used={result['tools_used']} "
|
|
700
|
-
f"| [{reward_str}]"
|
|
797
|
+
# Store results if requested
|
|
798
|
+
if store and rollout_id:
|
|
799
|
+
_store_rollout_id(
|
|
800
|
+
rollout_id,
|
|
801
|
+
{
|
|
802
|
+
"status": "completed",
|
|
803
|
+
"results": batch_results,
|
|
804
|
+
"billing": {"prompt_tokens": 0, "completion_tokens": 0},
|
|
805
|
+
},
|
|
701
806
|
)
|
|
702
|
-
click.echo()
|
|
703
807
|
|
|
704
|
-
#
|
|
705
|
-
|
|
706
|
-
await c.close()
|
|
808
|
+
# Display results using shared function
|
|
809
|
+
_display_results(batch_results, truncate, output_dir, rollout_id)
|
|
707
810
|
|
|
708
|
-
|
|
709
|
-
|
|
811
|
+
# Get total billing
|
|
812
|
+
total_charged = client.total_charged_dollars
|
|
710
813
|
|
|
711
|
-
|
|
712
|
-
|
|
814
|
+
# Close client
|
|
815
|
+
await client.close()
|
|
713
816
|
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
click.echo()
|
|
717
|
-
click.echo(f"Mean reward: {click.style(f'{mean_reward:.3f}', fg=TEAL_RGB)}")
|
|
718
|
-
click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
|
|
719
|
-
else:
|
|
720
|
-
click.echo(click.style("\nNo successful rollouts completed.", fg="yellow"))
|
|
817
|
+
except Exception:
|
|
818
|
+
raise
|
|
721
819
|
|
|
722
|
-
# Show
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
click.echo(
|
|
727
|
-
f"Billing: {click.style(f'${amount_cents/100:.2f}', fg=TEAL_RGB)} ({total_tokens:,} tokens)"
|
|
728
|
-
)
|
|
820
|
+
# Show timing and cost
|
|
821
|
+
click.echo(f"Latency: {click.style(f'{total_time:.1f}s', fg=TEAL_RGB)}")
|
|
822
|
+
if total_charged > 0:
|
|
823
|
+
click.echo(f"Cost: {click.style(f'${total_charged:.4f}', fg=TEAL_RGB)}")
|