rnow-0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnow/__init__.py +5 -0
- rnow/__main__.py +7 -0
- rnow/cli/__init__.py +6 -0
- rnow/cli/auth.py +67 -0
- rnow/cli/blob.py +98 -0
- rnow/cli/commands.py +2311 -0
- rnow/cli/common.py +28 -0
- rnow/cli/cube.py +255 -0
- rnow/cli/main.py +49 -0
- rnow/cli/test.py +728 -0
- rnow/cli/token_count.py +295 -0
- rnow/core/__init__.py +33 -0
- rnow/core/reward.py +333 -0
- rnow/core/tool.py +494 -0
- rnow/models.py +295 -0
- rnow/templates/deepseek-aha/config.yml +26 -0
- rnow/templates/deepseek-aha/rewards.py +36 -0
- rnow/templates/deepseek-aha/train.jsonl +1000 -0
- rnow/templates/mcp-tavily/config.yml +29 -0
- rnow/templates/mcp-tavily/requirements.txt +1 -0
- rnow/templates/mcp-tavily/rewards.py +25 -0
- rnow/templates/mcp-tavily/train.jsonl +500 -0
- rnow/templates/new/config.yml +26 -0
- rnow/templates/new/requirements.txt +1 -0
- rnow/templates/new/rewards.py +0 -0
- rnow/templates/new/train.jsonl +0 -0
- rnow/templates/rl-nextjs/config.yml +27 -0
- rnow/templates/rl-nextjs/requirements.txt +2 -0
- rnow/templates/rl-nextjs/rewards.py +446 -0
- rnow/templates/rl-nextjs/train.jsonl +1000 -0
- rnow/templates/rl-single/config.yml +27 -0
- rnow/templates/rl-single/requirements.txt +1 -0
- rnow/templates/rl-single/rewards.py +14 -0
- rnow/templates/rl-single/train.jsonl +1000 -0
- rnow/templates/rl-tools/config.yml +27 -0
- rnow/templates/rl-tools/env.py +38 -0
- rnow/templates/rl-tools/requirements.txt +3 -0
- rnow/templates/rl-tools/rewards.py +25 -0
- rnow/templates/rl-tools/train.jsonl +500 -0
- rnow/templates/sft/config.yml +20 -0
- rnow/templates/sft/train.jsonl +100 -0
- rnow/templates/tutorial-reward/config.yml +27 -0
- rnow/templates/tutorial-reward/requirements.txt +1 -0
- rnow/templates/tutorial-reward/rewards.py +15 -0
- rnow/templates/tutorial-reward/train.jsonl +1000 -0
- rnow/templates/tutorial-tool/config.yml +27 -0
- rnow/templates/tutorial-tool/env.py +7 -0
- rnow/templates/tutorial-tool/requirements.txt +3 -0
- rnow/templates/tutorial-tool/rewards.py +7 -0
- rnow/templates/tutorial-tool/train.jsonl +1266 -0
- rnow-0.2.4.dist-info/METADATA +135 -0
- rnow-0.2.4.dist-info/RECORD +56 -0
- rnow-0.2.4.dist-info/WHEEL +5 -0
- rnow-0.2.4.dist-info/entry_points.txt +2 -0
- rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
- rnow-0.2.4.dist-info/top_level.txt +1 -0
rnow/cli/token_count.py
ADDED
@@ -0,0 +1,295 @@
"""
Token counting utilities for accurate context window validation.

Supports:
- gpt-oss: uses openai-harmony's render_conversation_for_completion (canonical API)
- Qwen/other HF models: uses HuggingFace tokenizers with a chat-template approximation
- Fallback: conservative char-based estimate
"""

from __future__ import annotations

import heapq
import json
from pathlib import Path
from typing import Any

# Tokenizer cache to avoid reloading
_tokenizer_cache: dict[str, Any] = {}


def get_tokenizer_for_model(model_path: str) -> tuple[str, Any] | None:
    """
    Get the appropriate tokenizer for a model.

    Returns:
        Tuple of (tokenizer_type, tokenizer) or None if unavailable.
        tokenizer_type is "harmony" for gpt-oss, "hf" for HuggingFace models.
    """
    if model_path in _tokenizer_cache:
        return _tokenizer_cache[model_path]

    # gpt-oss models use openai-harmony
    if "gpt-oss" in model_path.lower():
        try:
            from openai_harmony import HarmonyEncodingName, load_harmony_encoding

            tokenizer = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
            _tokenizer_cache[model_path] = ("harmony", tokenizer)
            return _tokenizer_cache[model_path]
        except ImportError:
            pass
    else:
        # Try HuggingFace tokenizers (requires tokenizer.json in the repo)
        try:
            from tokenizers import Tokenizer

            tokenizer = Tokenizer.from_pretrained(model_path)
            _tokenizer_cache[model_path] = ("hf", tokenizer)
            return _tokenizer_cache[model_path]
        except Exception:
            # tokenizer.json not available or the repo requires auth
            pass

    _tokenizer_cache[model_path] = None
    return None


def count_tokens_for_gpt_oss(
    messages: list[dict], tools: list[dict], instructions: str = ""
) -> int:
    """
    Count tokens for gpt-oss using the canonical Harmony API.

    This uses render_conversation_for_completion, which is the correct way
    to count tokens exactly as the gpt-oss runtime sees them.

    System messages from the dataset are merged into DeveloperContent.instructions,
    since Harmony's SystemContent is for model metadata, not user system prompts.
    """
    try:
        from openai_harmony import (
            Conversation,
            DeveloperContent,
            HarmonyEncodingName,
            Message,
            Role,
            SystemContent,
            ToolDescription,
            load_harmony_encoding,
        )
    except ImportError:
        # Conservative fallback: count JSON chars as tokens
        return sum(len(json.dumps(m, ensure_ascii=False)) for m in messages) + sum(
            len(json.dumps(t, ensure_ascii=False)) for t in tools
        )

    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

    # Extract system messages from the dataset - these become developer instructions
    # (Harmony's SystemContent is model metadata, not user system prompts)
    sys_texts = [
        m.get("content", "")
        for m in messages
        if m.get("role") == "system" and isinstance(m.get("content"), str)
    ]
    sys_text = "\n\n".join(t.strip() for t in sys_texts if t.strip())

    # Combine with any additional instructions
    combined_instructions = "\n\n".join(s for s in [instructions.strip(), sys_text] if s).strip()

    # Build tool descriptions using the canonical API
    tool_descs = []
    for t in tools:
        if not isinstance(t, dict):
            continue
        try:
            td = ToolDescription.new(
                t.get("name", "unknown"),
                t.get("description", "") or "",
                parameters=t.get("schema") or {},
            )
            tool_descs.append(td)
        except Exception:
            # Skip malformed tools
            pass

    # Build DeveloperContent with instructions and tools
    dev = DeveloperContent.new()
    if combined_instructions:
        dev = dev.with_instructions(combined_instructions)
    if tool_descs:
        dev = dev.with_function_tools(tool_descs)

    # Build the conversation using the canonical from_role_and_content
    convo_msgs = [
        Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
        Message.from_role_and_content(Role.DEVELOPER, dev),
    ]

    # Add non-system messages
    for m in messages:
        role = (m.get("role") or "").lower()
        if role == "system":
            continue  # Already handled in DeveloperContent

        content = m.get("content", "") or ""

        if role == "user":
            convo_msgs.append(Message.from_role_and_content(Role.USER, content))
        elif role == "assistant":
            convo_msgs.append(Message.from_role_and_content(Role.ASSISTANT, content))

    convo = Conversation.from_messages(convo_msgs)
    tokens = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
    return len(tokens)


def count_tokens_for_hf_model(messages: list[dict], tools: list[dict], model_path: str) -> int:
    """
    Count tokens for HuggingFace models using the tokenizers library.

    Uses a ChatML-style template as an approximation, since we don't have
    the exact chat template. This may slightly over-count, which is safer
    than under-counting.
    """
    tokenizer_info = get_tokenizer_for_model(model_path)

    # Build the prompt with a ChatML-style template (common for Qwen, etc.)
    # This is an approximation but errs on the side of over-counting
    prompt_parts = []

    # Add tools as system content
    if tools:
        tool_lines = ["You have access to the following tools:\n"]
        for tool in tools:
            name = tool.get("name", "unknown")
            desc = tool.get("description", "")
            schema = tool.get("schema", {})
            tool_lines.append(f"### {name}")
            if desc:
                tool_lines.append(f"{desc}")
            if schema:
                tool_lines.append(f"Parameters: {json.dumps(schema)}")
            tool_lines.append("")
        # Join with newlines so each tool renders on its own lines
        # (the trailing "" entries become blank separator lines)
        tools_block = "\n".join(tool_lines)
        prompt_parts.append(f"<|im_start|>system\n{tools_block}<|im_end|>")

    # Add messages
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")

    full_prompt = "\n".join(prompt_parts)

    if tokenizer_info is None:
        # Conservative estimate: 1 char = 1 token (over-estimates)
        return len(full_prompt)

    tokenizer_type, tokenizer = tokenizer_info
    try:
        if tokenizer_type == "hf":
            encoded = tokenizer.encode(full_prompt)
            return len(encoded.ids)
        else:
            # Shouldn't happen (gpt-oss is routed elsewhere), but handle gracefully
            return len(full_prompt)
    except Exception:
        return len(full_prompt)


def count_tokens_for_prompt(messages: list[dict], tools: list[dict], model_path: str) -> int:
    """
    Count tokens for a prompt in the correct format for the model.

    For gpt-oss: uses Harmony's render_conversation_for_completion (exact).
    For HF models: uses the tokenizer with a chat-template approximation (conservative).
    Fallback: char-based estimate (very conservative).
    """
    is_gpt_oss = "gpt-oss" in model_path.lower()

    if is_gpt_oss:
        return count_tokens_for_gpt_oss(messages, tools)

    return count_tokens_for_hf_model(messages, tools, model_path)


def get_max_prompt_tokens(
    path: Path,
    tools: list[dict],
    model_path: str = "",
    sample_size: int = 0,
    top_k: int = 20,
) -> int:
    """
    Scan train.jsonl and return the maximum prompt token count.

    Uses a top-K strategy: keeps the top K candidates by char count, then
    token-counts all of them to find the true maximum. This handles cases
    where char count doesn't correlate well with token count
    (e.g., CJK text, code, JSON).

    Args:
        path: Path to train.jsonl
        tools: List of tool definitions to include in the token count
        model_path: Model path for tokenizer selection
        sample_size: Max lines to scan (0 = all)
        top_k: Number of candidates to keep for token counting
    """
    # Min-heap of (char_count, index, messages) - keeps the smallest at the top.
    # The index makes tuples comparable when char counts are equal.
    heap: list[tuple[int, int, list[dict]]] = []

    try:
        with open(path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if sample_size and i >= sample_size:
                    break

                line = line.strip()
                if not line:
                    continue

                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue

                msgs = rec.get("messages")
                if not isinstance(msgs, list):
                    continue

                # Calculate the char count for this example
                char_count = sum(
                    len(m.get("content", ""))
                    for m in msgs
                    if isinstance(m, dict) and isinstance(m.get("content", ""), str)
                )

                # Keep the top_k largest by char count
                if len(heap) < top_k:
                    heapq.heappush(heap, (char_count, i, msgs))
                elif char_count > heap[0][0]:
                    heapq.heappushpop(heap, (char_count, i, msgs))

    except Exception:
        return 0

    if not heap:
        return 0

    # Token-count all candidates and return the maximum
    max_tokens = 0
    for _chars, _idx, msgs in heap:
        tokens = count_tokens_for_prompt(msgs, tools, model_path)
        max_tokens = max(max_tokens, tokens)

    return max_tokens


def estimate_tokens_from_chars(char_count: int) -> int:
    """
    Conservative token estimate from character count.
    Uses 2 chars per token, which typically over-estimates.
    """
    return char_count // 2
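
A minimal usage sketch of the helpers above (not part of the package diff): the model paths, dataset path, and tool schema here are hypothetical examples; only the imported functions come from rnow/cli/token_count.py.

from pathlib import Path

from rnow.cli.token_count import (
    count_tokens_for_prompt,
    estimate_tokens_from_chars,
    get_max_prompt_tokens,
)

# Hypothetical tool definition and conversation for illustration
tools = [
    {
        "name": "search",
        "description": "Search the web",
        "schema": {"type": "object", "properties": {"query": {"type": "string"}}},
    }
]
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# Exact Harmony count for gpt-oss; ChatML approximation for other HF models
n_tokens = count_tokens_for_prompt(messages, tools, "openai/gpt-oss-20b")

# Largest prompt in a dataset, using the top-K-by-chars pre-filter
max_tokens = get_max_prompt_tokens(
    Path("train.jsonl"), tools, model_path="Qwen/Qwen2.5-7B-Instruct"
)

# Last-resort estimate when no tokenizer is available (2 chars per token)
approx = estimate_tokens_from_chars(sum(len(m["content"]) for m in messages))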
rnow/core/__init__.py
ADDED
@@ -0,0 +1,33 @@
"""
ReinforceNow Core - entry points for the reward and tool decorators.

Users only need:
- the @reward decorator for defining reward functions
- the @tool decorator for defining tool functions
- RewardArgs for type hints in reward functions
"""

from rnow.models import RewardArgs

from .reward import (
    REWARD_REGISTRY,
    clear_reward_registry,
    compute_total_reward,
    is_precondition,
    reward,
)
from .tool import TOOL_REGISTRY, clear_tool_registry, tool

__all__ = [
    # User-facing API
    "reward",
    "tool",
    "RewardArgs",
    # Registries (used by the CLI and trainer)
    "REWARD_REGISTRY",
    "TOOL_REGISTRY",
    "clear_reward_registry",
    "clear_tool_registry",
    "is_precondition",
    "compute_total_reward",
]
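
For reference, a minimal sketch of how a user's rewards.py would consume this re-exported API; the reward body is an illustrative placeholder, and it assumes RewardArgs exposes a metadata dict, as in the decorator's own docstring example.

from rnow.core import RewardArgs, reward


@reward
def exact_match(args: RewardArgs, messages: list) -> float:
    # Illustrative placeholder: compare the last message against a metadata field
    expected = args.metadata.get("answer", "")
    return 1.0 if expected and expected in messages[-1].get("content", "") else 0.0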
rnow/core/reward.py
ADDED
@@ -0,0 +1,333 @@
"""
Reward entry point for ReinforceNow with validation.

Validates at decorator-time:
- Function has the correct signature: (args: RewardArgs, messages: list) -> float

Both sync and async functions are supported. Execution strategy is
determined automatically at runtime.
"""

import inspect
from collections.abc import Callable
from typing import get_origin, get_type_hints

from rnow.models import RewardArgs

# Global registry for reward functions
REWARD_REGISTRY: dict[str, Callable] = {}


def clear_reward_registry() -> None:
    """Clear the reward registry (useful for testing multiple projects)."""
    REWARD_REGISTRY.clear()


def is_precondition(name: str) -> bool:
    """Check whether a registered reward function is marked as a precondition."""
    fn = REWARD_REGISTRY.get(name)
    if fn is None:
        return False
    return getattr(fn, "_is_precondition", False)


def compute_total_reward(reward_results: dict[str, float]) -> float:
    """
    Compute the total reward with precondition logic.

    Uses the _is_precondition attribute on registered reward functions.
    If any precondition reward is 0, the total reward is 0.
    Otherwise, the total reward is the sum of all rewards.

    Args:
        reward_results: Dict mapping reward name to its value

    Returns:
        Total reward value
    """
    precondition_sum = 0.0
    other_sum = 0.0

    for name, value in reward_results.items():
        if is_precondition(name):
            # If any precondition fails (returns 0), the total reward is 0
            if value == 0.0:
                return 0.0
            precondition_sum += value
        else:
            other_sum += value

    return precondition_sum + other_sum


def _validate_reward_signature(func: Callable) -> None:
    """
    Validate that a reward function has the correct signature.

    Expected signature:
        def reward_fn(args: RewardArgs, messages: list) -> float

    Both sync and async functions are supported.

    Raises:
        TypeError: If the signature doesn't match the expected format
    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())

    # Check the parameter count (must be exactly 2: args and messages)
    if len(params) != 2:
        raise TypeError(
            f"Reward '{func.__name__}' must have exactly 2 parameters: "
            f"(args: RewardArgs, messages: list). Got {len(params)} parameters."
        )

    # Get type hints
    hints = get_type_hints(func)

    # Check the first parameter (args: RewardArgs)
    first_param = params[0]
    if first_param.name not in hints:
        raise TypeError(
            f"Reward '{func.__name__}': parameter '{first_param.name}' must have "
            "type hint 'RewardArgs'."
        )
    first_type = hints[first_param.name]
    if first_type is not RewardArgs:
        raise TypeError(
            f"Reward '{func.__name__}': first parameter must be typed as 'RewardArgs', "
            f"got '{first_type}'."
        )

    # Check the second parameter (messages: list)
    second_param = params[1]
    if second_param.name not in hints:
        raise TypeError(
            f"Reward '{func.__name__}': parameter '{second_param.name}' must have "
            "type hint 'list'."
        )
    second_type = hints[second_param.name]
    # Allow bare 'list' as well as parameterized forms like 'list[dict]' or typing.List
    second_origin = get_origin(second_type) or second_type
    if second_origin is not list:
        raise TypeError(
            f"Reward '{func.__name__}': second parameter must be typed as 'list', "
            f"got '{second_type}'."
        )

    # Check the return type
    if "return" not in hints:
        raise TypeError(f"Reward '{func.__name__}' must declare return type '-> float'.")
    return_type = hints["return"]
    if return_type is not float:
        raise TypeError(f"Reward '{func.__name__}' must return 'float', got '{return_type}'.")


def reward(fn: Callable | None = None, *, precondition: bool = False) -> Callable:
    """
    Decorator to register reward functions with validation.

    Validates at decorator-time:
    - Signature is (args: RewardArgs, messages: list) -> float

    Both sync and async functions are supported. Execution strategy
    is determined automatically at runtime.

    Usage:
        @reward
        def accuracy(args: RewardArgs, messages: list) -> float:
            ground_truth = args.metadata.get("ground_truth")
            response = messages[-1].get("content", "")
            return 1.0 if ground_truth in response else 0.0

        @reward(precondition=True)  # Mark as a precondition reward
        def format_check(args: RewardArgs, messages: list) -> float:
            # If this returns 0, the total reward is 0
            # If this returns 1, the total reward is 1 + sum(other rewards)
            return 1.0 if valid_format else 0.0

    Args:
        precondition: If True, this reward acts as a gate:
            - If the precondition reward is 0, the total reward is 0
            - If the precondition reward is 1, the total reward is 1 + sum(other rewards)
    """

    def decorator(func):
        # Validate the signature (sync or async, correct params, return type)
        try:
            _validate_reward_signature(func)
        except TypeError as e:
            raise TypeError(f"Reward registration failed: {e}") from e

        # Warn if overwriting an existing reward
        if func.__name__ in REWARD_REGISTRY:
            import warnings

            warnings.warn(
                f"Reward '{func.__name__}' is being overwritten in the registry.",
                UserWarning,
                stacklevel=2,
            )

        # Store metadata
        func._is_reward = True
        func._reward_name = func.__name__
        func._is_precondition = precondition

        # Register the function
        REWARD_REGISTRY[func.__name__] = func
        return func

    # Support both @reward and @reward(precondition=True)
    return decorator(fn) if fn else decorator


def validate_rewards_file(filepath) -> list:
    """
    Validate a rewards.py file without executing it.

    Parses the AST to find @reward-decorated functions and checks that each:
    - has exactly 2 annotated parameters and a '-> float' return annotation
    - returns only constant values between 0.0 and 1.0 (where statically checkable)

    Both sync and async functions are supported.

    Returns a list of error messages (empty if valid).
    """
    import ast
    from pathlib import Path

    errors = []
    filepath = Path(filepath)

    try:
        source = filepath.read_text()
        tree = ast.parse(source, filename=str(filepath))
    except SyntaxError as e:
        return [f"Syntax error in {filepath.name}: {e}"]

    def get_numeric_value(node) -> float | None:
        """Extract the numeric value from an AST node if it's a constant."""
        # Python 3.8+: ast.Constant
        if isinstance(node, ast.Constant) and isinstance(node.value, int | float):
            return float(node.value)
        # Python 3.7: ast.Num
        if hasattr(ast, "Num") and isinstance(node, ast.Num):
            return float(node.n)
        # Handle negative numbers: ast.UnaryOp with ast.USub
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            inner = get_numeric_value(node.operand)
            if inner is not None:
                return -inner
        return None

    def check_return_values(func_node, func_name: str):
        """Check all return statements in a function for valid reward values."""
        for child in ast.walk(func_node):
            if isinstance(child, ast.Return) and child.value is not None:
                # Check direct numeric returns
                value = get_numeric_value(child.value)
                if value is not None and (value < 0.0 or value > 1.0):
                    errors.append(
                        f"Reward '{func_name}' returns {value}, must be between 0.0 and 1.0 "
                        f"(line {child.lineno})"
                    )

                # Check ternary expressions: x if cond else y
                if isinstance(child.value, ast.IfExp):
                    for branch in [child.value.body, child.value.orelse]:
                        value = get_numeric_value(branch)
                        if value is not None and (value < 0.0 or value > 1.0):
                            errors.append(
                                f"Reward '{func_name}' returns {value}, must be between 0.0 and 1.0 "
                                f"(line {child.lineno})"
                            )

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
            # Check whether the function has a @reward decorator
            is_reward = False
            for decorator in node.decorator_list:
                if (isinstance(decorator, ast.Name) and decorator.id == "reward") or (
                    isinstance(decorator, ast.Call)
                    and isinstance(decorator.func, ast.Name)
                    and decorator.func.id == "reward"
                ):
                    is_reward = True

            if is_reward:
                # Both async and sync functions are allowed

                # Check the parameter count
                params = node.args.args
                if len(params) != 2:
                    errors.append(
                        f"Reward '{node.name}' must have exactly 2 parameters: "
                        f"(args: RewardArgs, messages: list). Got {len(params)} parameters."
                    )

                # Check the return type annotation
                if node.returns is None:
                    errors.append(
                        f"Reward '{node.name}' must have a return type annotation '-> float'."
                    )

                # Check parameter type annotations
                for i, arg in enumerate(params):
                    if arg.annotation is None:
                        param_hint = "RewardArgs" if i == 0 else "list"
                        errors.append(
                            f"Reward '{node.name}': parameter '{arg.arg}' must have "
                            f"type hint '{param_hint}'."
                        )

                # Check that constant return values are between 0.0 and 1.0
                check_return_values(node, node.name)

    return errors


def get_reward_names_from_file(filepath) -> set[str]:
    """
    Extract reward function names from a rewards.py file without executing it.

    Parses the AST to find @reward-decorated functions and returns their names.

    Returns:
        Set of reward function names defined in the file.
    """
    import ast
    from pathlib import Path

    names = set()
    filepath = Path(filepath)

    try:
        source = filepath.read_text()
        tree = ast.parse(source, filename=str(filepath))
    except SyntaxError:
        return names

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
            # Check whether the function has a @reward decorator
            for decorator in node.decorator_list:
                if (isinstance(decorator, ast.Name) and decorator.id == "reward") or (
                    isinstance(decorator, ast.Call)
                    and isinstance(decorator.func, ast.Name)
                    and decorator.func.id == "reward"
                ):
                    names.add(node.name)

    return names