rnow 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. rnow/__init__.py +5 -0
  2. rnow/__main__.py +7 -0
  3. rnow/cli/__init__.py +6 -0
  4. rnow/cli/auth.py +67 -0
  5. rnow/cli/blob.py +98 -0
  6. rnow/cli/commands.py +2311 -0
  7. rnow/cli/common.py +28 -0
  8. rnow/cli/cube.py +255 -0
  9. rnow/cli/main.py +49 -0
  10. rnow/cli/test.py +728 -0
  11. rnow/cli/token_count.py +295 -0
  12. rnow/core/__init__.py +33 -0
  13. rnow/core/reward.py +333 -0
  14. rnow/core/tool.py +494 -0
  15. rnow/models.py +295 -0
  16. rnow/templates/deepseek-aha/config.yml +26 -0
  17. rnow/templates/deepseek-aha/rewards.py +36 -0
  18. rnow/templates/deepseek-aha/train.jsonl +1000 -0
  19. rnow/templates/mcp-tavily/config.yml +29 -0
  20. rnow/templates/mcp-tavily/requirements.txt +1 -0
  21. rnow/templates/mcp-tavily/rewards.py +25 -0
  22. rnow/templates/mcp-tavily/train.jsonl +500 -0
  23. rnow/templates/new/config.yml +26 -0
  24. rnow/templates/new/requirements.txt +1 -0
  25. rnow/templates/new/rewards.py +0 -0
  26. rnow/templates/new/train.jsonl +0 -0
  27. rnow/templates/rl-nextjs/config.yml +27 -0
  28. rnow/templates/rl-nextjs/requirements.txt +2 -0
  29. rnow/templates/rl-nextjs/rewards.py +446 -0
  30. rnow/templates/rl-nextjs/train.jsonl +1000 -0
  31. rnow/templates/rl-single/config.yml +27 -0
  32. rnow/templates/rl-single/requirements.txt +1 -0
  33. rnow/templates/rl-single/rewards.py +14 -0
  34. rnow/templates/rl-single/train.jsonl +1000 -0
  35. rnow/templates/rl-tools/config.yml +27 -0
  36. rnow/templates/rl-tools/env.py +38 -0
  37. rnow/templates/rl-tools/requirements.txt +3 -0
  38. rnow/templates/rl-tools/rewards.py +25 -0
  39. rnow/templates/rl-tools/train.jsonl +500 -0
  40. rnow/templates/sft/config.yml +20 -0
  41. rnow/templates/sft/train.jsonl +100 -0
  42. rnow/templates/tutorial-reward/config.yml +27 -0
  43. rnow/templates/tutorial-reward/requirements.txt +1 -0
  44. rnow/templates/tutorial-reward/rewards.py +15 -0
  45. rnow/templates/tutorial-reward/train.jsonl +1000 -0
  46. rnow/templates/tutorial-tool/config.yml +27 -0
  47. rnow/templates/tutorial-tool/env.py +7 -0
  48. rnow/templates/tutorial-tool/requirements.txt +3 -0
  49. rnow/templates/tutorial-tool/rewards.py +7 -0
  50. rnow/templates/tutorial-tool/train.jsonl +1266 -0
  51. rnow-0.2.4.dist-info/METADATA +135 -0
  52. rnow-0.2.4.dist-info/RECORD +56 -0
  53. rnow-0.2.4.dist-info/WHEEL +5 -0
  54. rnow-0.2.4.dist-info/entry_points.txt +2 -0
  55. rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
  56. rnow-0.2.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,295 @@
1
+ """
2
+ Token counting utilities for accurate context window validation.
3
+
4
+ Supports:
5
+ - gpt-oss: Uses openai-harmony render_conversation_for_completion (canonical API)
6
+ - Qwen/other HF models: Uses HuggingFace tokenizers with chat template approximation
7
+ - Fallback: Conservative char-based estimate
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import heapq
13
+ import json
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ # Tokenizer cache to avoid reloading
18
+ _tokenizer_cache: dict[str, Any] = {}
19
+
20
+
21
def get_tokenizer_for_model(model_path: str) -> tuple[str, Any] | None:
    """
    Look up (and memoise) the tokenizer that matches a model path.

    Model paths containing "gpt-oss" (case-insensitive) are served by the
    openai-harmony encoding; every other path is loaded through the
    HuggingFace ``tokenizers`` library.

    Returns:
        ``("harmony", encoding)`` or ``("hf", tokenizer)`` on success, or
        ``None`` when no tokenizer could be loaded. Both outcomes are stored
        in the module cache, so a failed load is not attempted again.
    """
    try:
        return _tokenizer_cache[model_path]
    except KeyError:
        pass  # first request for this model path

    loaded: tuple[str, Any] | None
    if "gpt-oss" in model_path.lower():
        try:
            from openai_harmony import HarmonyEncodingName, load_harmony_encoding

            loaded = ("harmony", load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS))
        except ImportError:
            # openai-harmony is not installed
            loaded = None
    else:
        try:
            from tokenizers import Tokenizer

            loaded = ("hf", Tokenizer.from_pretrained(model_path))
        except Exception:
            # library missing, tokenizer.json not available, or repo requires auth
            loaded = None

    _tokenizer_cache[model_path] = loaded
    return loaded
56
+
57
+
58
def count_tokens_for_gpt_oss(
    messages: list[dict], tools: list[dict], instructions: str = ""
) -> int:
    """
    Count prompt tokens for gpt-oss by rendering the conversation with Harmony.

    The conversation is rendered with ``render_conversation_for_completion``
    and the number of resulting tokens is returned.

    System messages from the dataset are merged into the developer
    instructions rather than sent as a Harmony system message, because
    Harmony's SystemContent is for model metadata, not user system prompts.
    Roles are compared case-insensitively both when collecting system text
    and when skipping system messages, so a message is never dropped by one
    step without being picked up by the other.

    Args:
        messages: Chat messages as dicts with "role" and "content" keys.
        tools: Tool definitions as dicts with "name", "description", "schema".
        instructions: Extra developer instructions placed before the system text.

    Returns:
        The token count, or a character-based over-estimate when
        openai-harmony is not installed.
    """
    try:
        from openai_harmony import (
            Conversation,
            DeveloperContent,
            HarmonyEncodingName,
            Message,
            Role,
            SystemContent,
            ToolDescription,
            load_harmony_encoding,
        )
    except ImportError:
        # Conservative fallback: count one token per character of the JSON
        # form of every message and tool, plus the extra instructions so they
        # are not silently left out of the estimate.
        return (
            sum(len(json.dumps(m, ensure_ascii=False)) for m in messages)
            + sum(len(json.dumps(t, ensure_ascii=False)) for t in tools)
            + len(instructions)
        )

    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

    def role_of(message: dict) -> str:
        """Return the message role lower-cased ('' when missing or None)."""
        return (message.get("role") or "").lower()

    # Extract system messages from the dataset - these become developer instructions
    sys_texts = [
        m.get("content", "")
        for m in messages
        if role_of(m) == "system" and isinstance(m.get("content"), str)
    ]
    sys_text = "\n\n".join(t.strip() for t in sys_texts if t.strip())

    # Combine with any additional instructions (explicit instructions first)
    combined_instructions = "\n\n".join(s for s in [instructions.strip(), sys_text] if s).strip()

    # Build tool descriptions; entries that are not dicts or that Harmony
    # rejects are skipped instead of aborting the whole count.
    tool_descs = []
    for t in tools:
        if not isinstance(t, dict):
            continue
        try:
            td = ToolDescription.new(
                t.get("name", "unknown"),
                t.get("description", "") or "",
                parameters=t.get("schema") or {},
            )
        except Exception:
            continue
        tool_descs.append(td)

    # Build DeveloperContent with instructions and tools
    dev = DeveloperContent.new()
    if combined_instructions:
        dev = dev.with_instructions(combined_instructions)
    if tool_descs:
        dev = dev.with_function_tools(tool_descs)

    convo_msgs = [
        Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
        Message.from_role_and_content(Role.DEVELOPER, dev),
    ]

    # Add user/assistant turns; system turns already live in DeveloperContent
    # and any other role is not rendered.
    for m in messages:
        role = role_of(m)
        if role == "system":
            continue

        content = m.get("content", "") or ""

        if role == "user":
            convo_msgs.append(Message.from_role_and_content(Role.USER, content))
        elif role == "assistant":
            convo_msgs.append(Message.from_role_and_content(Role.ASSISTANT, content))

    convo = Conversation.from_messages(convo_msgs)
    tokens = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
    return len(tokens)
146
+
147
+
148
def count_tokens_for_hf_model(messages: list[dict], tools: list[dict], model_path: str) -> int:
    """
    Count tokens for HuggingFace models using the tokenizers library.

    The prompt is rendered with a ChatML-style template
    (``<|im_start|>role\\n...<|im_end|>``) as an approximation, since the
    model's exact chat template is not available here. Tool definitions are
    rendered as a leading system block, one line per field.

    Args:
        messages: Chat messages as dicts with "role" and "content" keys.
        tools: Tool definitions as dicts with "name", "description", "schema".
        model_path: Model path used to select the tokenizer.

    Returns:
        The encoded length of the rendered prompt, or its character length
        when no HuggingFace tokenizer is available or encoding fails.
    """
    tokenizer_info = get_tokenizer_for_model(model_path)

    prompt_parts = []

    # Add tools as system content
    if tools:
        tool_lines = ["You have access to the following tools:\n"]
        for tool in tools:
            if not isinstance(tool, dict):
                # Skip malformed entries, matching the gpt-oss counter
                continue
            name = tool.get("name", "unknown")
            desc = tool.get("description", "")
            schema = tool.get("schema", {})
            tool_lines.append(f"### {name}")
            if desc:
                tool_lines.append(f"{desc}")
            if schema:
                tool_lines.append(f"Parameters: {json.dumps(schema)}")
            tool_lines.append("")
        # Join with newlines so each field sits on its own line; joining with
        # an empty string fused name, description and parameters together.
        tools_text = "\n".join(tool_lines)
        prompt_parts.append(f"<|im_start|>system\n{tools_text}<|im_end|>")

    # Add messages
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")

    full_prompt = "\n".join(prompt_parts)

    # Character-based estimate: 1 char = 1 token
    char_estimate = len(full_prompt)

    if tokenizer_info is None:
        return char_estimate

    tokenizer_type, tokenizer = tokenizer_info
    if tokenizer_type != "hf":
        # Shouldn't happen, but handle gracefully
        return char_estimate

    try:
        return len(tokenizer.encode(full_prompt).ids)
    except Exception:
        return char_estimate
199
+
200
+
201
def count_tokens_for_prompt(messages: list[dict], tools: list[dict], model_path: str) -> int:
    """
    Dispatch prompt token counting to the counter that fits the model.

    A model path containing "gpt-oss" (case-insensitive) is counted with the
    Harmony-based counter; anything else goes through the HuggingFace counter,
    which applies its own character-based fallback.
    """
    if "gpt-oss" not in model_path.lower():
        return count_tokens_for_hf_model(messages, tools, model_path)
    return count_tokens_for_gpt_oss(messages, tools)
215
+
216
+
217
def get_max_prompt_tokens(
    path: Path,
    tools: list[dict],
    model_path: str = "",
    sample_size: int = 0,
    top_k: int = 20,
) -> int:
    """
    Scan a JSONL dataset and return the largest prompt token count found.

    Uses a top-K strategy: the K records with the most message characters are
    kept while streaming the file, and only those are token-counted. This is a
    heuristic; a record outside the top K by characters is never counted.

    Lines that are blank, not valid JSON, not a JSON object, or whose
    "messages" value is not a list are skipped individually. Message entries
    that are not dicts are dropped before counting.

    Args:
        path: Path to train.jsonl
        tools: List of tool definitions to include in token count
        model_path: Model path for tokenizer selection
        sample_size: Max lines to scan (0 = all)
        top_k: Number of candidates to keep for token counting (values below
            1 are treated as 1)

    Returns:
        The maximum token count among the candidates, or 0 when the file
        cannot be read or holds no usable record.
    """
    # A non-positive top_k would otherwise index an empty heap below.
    top_k = max(1, top_k)

    # Min-heap of (char_count, index, messages) - keeps smallest at top.
    # The line index is unique, so tuple comparison never reaches the messages.
    heap: list[tuple[int, int, list[dict]]] = []

    try:
        with open(path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if sample_size and i >= sample_size:
                    break

                line = line.strip()
                if not line:
                    continue

                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue

                # Valid JSON that is not an object (list, string, number) has
                # no "messages" key; skip the line rather than abort the scan.
                if not isinstance(rec, dict):
                    continue

                raw_msgs = rec.get("messages")
                if not isinstance(raw_msgs, list):
                    continue

                # The token counters call .get() on every message
                msgs = [m for m in raw_msgs if isinstance(m, dict)]

                # Calculate char count for this example (string contents only)
                char_count = sum(
                    len(m["content"]) for m in msgs if isinstance(m.get("content"), str)
                )

                # Keep top_k largest by char count
                if len(heap) < top_k:
                    heapq.heappush(heap, (char_count, i, msgs))
                elif char_count > heap[0][0]:
                    heapq.heappushpop(heap, (char_count, i, msgs))
    except (OSError, UnicodeError):
        # Missing/unreadable file or bytes that are not valid UTF-8
        return 0

    # Token-count all candidates and return the maximum
    max_tokens = 0
    for _chars, _idx, msgs in heap:
        tokens = count_tokens_for_prompt(msgs, tools, model_path)
        max_tokens = max(max_tokens, tokens)

    return max_tokens
288
+
289
+
290
def estimate_tokens_from_chars(char_count: int) -> int:
    """
    Estimate a token count from a character count.

    Assumes two characters per token and rounds down.
    """
    chars_per_token = 2
    return char_count // chars_per_token
rnow/core/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ ReinforceNow Core - Entry points for reward and tool decorators.
3
+
4
+ Users only need:
5
+ - @reward decorator for defining reward functions
6
+ - @tool decorator for defining tool functions
7
+ - RewardArgs for type hints in reward functions
8
+ """
9
+
10
+ from rnow.models import RewardArgs
11
+
12
+ from .reward import (
13
+ REWARD_REGISTRY,
14
+ clear_reward_registry,
15
+ compute_total_reward,
16
+ is_precondition,
17
+ reward,
18
+ )
19
+ from .tool import TOOL_REGISTRY, clear_tool_registry, tool
20
+
21
+ __all__ = [
22
+ # User-facing API
23
+ "reward",
24
+ "tool",
25
+ "RewardArgs",
26
+ # Registries (used by CLI and trainer)
27
+ "REWARD_REGISTRY",
28
+ "TOOL_REGISTRY",
29
+ "clear_reward_registry",
30
+ "clear_tool_registry",
31
+ "is_precondition",
32
+ "compute_total_reward",
33
+ ]
rnow/core/reward.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Reward entry point for ReinforceNow with validation.
3
+
4
+ Validates at decorator-time:
5
+ - Function has correct signature: (args: RewardArgs, messages: list) -> float
6
+ - Function has docstring or description
7
+
8
+ Both sync and async functions are supported. Execution strategy is
9
+ determined automatically at runtime.
10
+ """
11
+
12
+ import inspect
13
+ from collections.abc import Callable
14
+ from typing import get_type_hints
15
+
16
+ from rnow.models import RewardArgs
17
+
18
+ # Global registry for reward functions
19
+ REWARD_REGISTRY: dict[str, Callable] = {}
20
+
21
+
22
def clear_reward_registry() -> None:
    """Empty the global reward registry in place (useful for testing multiple projects)."""
    REWARD_REGISTRY.clear()
25
+
26
+
27
def is_precondition(name: str) -> bool:
    """Report whether the reward registered under ``name`` is flagged as a precondition.

    Unknown names, and functions without the ``_is_precondition`` attribute,
    yield ``False``.
    """
    registered = REWARD_REGISTRY.get(name)
    return False if registered is None else getattr(registered, "_is_precondition", False)
33
+
34
+
35
def compute_total_reward(reward_results: dict[str, float]) -> float:
    """
    Combine individual reward values into one total, honouring preconditions.

    A reward counts as a precondition when its registered function carries a
    truthy ``_is_precondition`` attribute. As soon as a precondition with the
    value 0 is encountered the total is 0; otherwise the total is the sum of
    the precondition values plus the sum of all other values.

    Args:
        reward_results: Dict mapping reward name to its value

    Returns:
        Total reward value
    """
    gate_total = 0.0
    regular_total = 0.0

    for reward_name in reward_results:
        score = reward_results[reward_name]

        if not is_precondition(reward_name):
            regular_total += score
            continue

        # A failed precondition (value 0) zeroes the whole reward
        if score == 0.0:
            return 0.0
        gate_total += score

    return gate_total + regular_total
62
+
63
+
64
def _validate_reward_signature(func: Callable) -> None:
    """
    Validate that a reward function has the correct signature.

    Expected signature:
        def reward_fn(args: RewardArgs, messages: list) -> float

    Checks, in order: exactly two parameters; the first is annotated as
    ``RewardArgs``; the second is annotated as ``list`` (bare or
    parameterised, e.g. ``list[dict]``); the return annotation is ``float``.
    Nothing here distinguishes sync from async functions, so both pass.

    Raises:
        TypeError: If signature doesn't match expected format
    """
    from typing import get_origin

    sig = inspect.signature(func)
    params = list(sig.parameters.values())

    # Check parameter count (should be exactly 2: args and messages)
    if len(params) != 2:
        # Both halves are f-strings so the actual count is interpolated
        raise TypeError(
            f"Reward '{func.__name__}' must have exactly 2 parameters: "
            f"(args: RewardArgs, messages: list). Got {len(params)} parameters."
        )

    # Resolved annotations (string annotations are evaluated here)
    hints = get_type_hints(func)

    # Check first parameter (args: RewardArgs)
    first_param = params[0]
    if first_param.name not in hints:
        raise TypeError(
            f"Reward '{func.__name__}': parameter '{first_param.name}' must have "
            "type hint 'RewardArgs'."
        )
    first_type = hints[first_param.name]
    if first_type is not RewardArgs:
        raise TypeError(
            f"Reward '{func.__name__}': first parameter must be typed as 'RewardArgs', "
            f"got '{first_type}'."
        )

    # Check second parameter (messages: list)
    second_param = params[1]
    if second_param.name not in hints:
        raise TypeError(
            f"Reward '{func.__name__}': parameter '{second_param.name}' must have "
            "type hint 'list'."
        )
    second_type = hints[second_param.name]
    # get_origin() maps list[...] and typing.List[...] to list; a bare `list`
    # has no origin, so fall back to the annotation itself.
    second_origin = get_origin(second_type) or second_type
    if second_origin is not list:
        raise TypeError(
            f"Reward '{func.__name__}': second parameter must be typed as 'list', "
            f"got '{second_type}'."
        )

    # Check return type
    if "return" not in hints:
        raise TypeError(f"Reward '{func.__name__}' must declare return type '-> float'.")
    return_type = hints["return"]
    if return_type is not float:
        raise TypeError(f"Reward '{func.__name__}' must return 'float', got '{return_type}'.")
128
+
129
+
130
def reward(fn: Callable = None, *, precondition: bool = False) -> Callable:
    """
    Register a function as a reward, validating its signature first.

    At decoration time the function must match
    ``(args: RewardArgs, messages: list) -> float``; otherwise a TypeError
    prefixed with "Reward registration failed" is raised. The function is
    returned unchanged apart from three marker attributes
    (``_is_reward``, ``_reward_name``, ``_is_precondition``) and is stored in
    REWARD_REGISTRY under its ``__name__``. Re-registering a name emits a
    UserWarning and replaces the earlier entry.

    Usage:
        @reward
        def accuracy(args: RewardArgs, messages: list) -> float:
            ground_truth = args.metadata.get("ground_truth")
            response = messages[-1].get("content", "")
            return 1.0 if ground_truth in response else 0.0

        @reward(precondition=True)
        def format_check(args: RewardArgs, messages: list) -> float:
            return 1.0 if valid_format else 0.0

    Args:
        fn: The function, when used as a bare ``@reward``; omitted when the
            decorator is called with keyword arguments.
        precondition: If True, the reward is flagged as a gate: a value of 0
            makes the total reward 0, any other value is added to the total
            like a normal reward.
    """

    def decorator(func):
        # Reject bad signatures before anything is registered
        try:
            _validate_reward_signature(func)
        except TypeError as e:
            raise TypeError(f"Reward registration failed: {e}") from e

        reward_name = func.__name__

        # Warn if overwriting existing reward
        if reward_name in REWARD_REGISTRY:
            import warnings

            warnings.warn(
                f"Reward '{reward_name}' is being overwritten in the registry.",
                UserWarning,
                stacklevel=2,
            )

        # Store metadata on the function itself
        func._is_reward = True
        func._reward_name = reward_name
        func._is_precondition = precondition

        REWARD_REGISTRY[reward_name] = func
        return func

    # Support both @reward and @reward(precondition=False)
    if fn:
        return decorator(fn)
    return decorator
187
+
188
+
189
+ def validate_rewards_file(filepath) -> list:
190
+ """
191
+ Validate a rewards.py file without executing it.
192
+
193
+ Parses the AST to find @reward decorated functions and checks:
194
+ - Has correct number of parameters
195
+ - Return values are between 0.0 and 1.0
196
+
197
+ Both sync and async functions are supported.
198
+
199
+ Returns a list of error messages (empty if valid).
200
+ """
201
+ import ast
202
+ from pathlib import Path
203
+
204
+ errors = []
205
+ filepath = Path(filepath)
206
+
207
+ try:
208
+ source = filepath.read_text()
209
+ tree = ast.parse(source, filename=str(filepath))
210
+ except SyntaxError as e:
211
+ return [f"Syntax error in {filepath.name}: {e}"]
212
+
213
+ def get_numeric_value(node) -> float | None:
214
+ """Extract numeric value from AST node if it's a constant."""
215
+ # Python 3.8+: ast.Constant
216
+ if isinstance(node, ast.Constant) and isinstance(node.value, int | float):
217
+ return float(node.value)
218
+ # Python 3.7: ast.Num
219
+ if hasattr(ast, "Num") and isinstance(node, ast.Num):
220
+ return float(node.n)
221
+ # Handle negative numbers: ast.UnaryOp with ast.USub
222
+ if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
223
+ inner = get_numeric_value(node.operand)
224
+ if inner is not None:
225
+ return -inner
226
+ return None
227
+
228
+ def check_return_values(func_node, func_name: str):
229
+ """Check all return statements in a function for valid reward values."""
230
+ for child in ast.walk(func_node):
231
+ if isinstance(child, ast.Return) and child.value is not None:
232
+ # Check direct numeric returns
233
+ value = get_numeric_value(child.value)
234
+ if value is not None and (value < 0.0 or value > 1.0):
235
+ errors.append(
236
+ f"Reward '{func_name}' returns {value}, must be between 0.0 and 1.0 "
237
+ f"(line {child.lineno})"
238
+ )
239
+
240
+ # Check ternary expressions: x if cond else y
241
+ if isinstance(child.value, ast.IfExp):
242
+ for branch in [child.value.body, child.value.orelse]:
243
+ value = get_numeric_value(branch)
244
+ if value is not None and (value < 0.0 or value > 1.0):
245
+ errors.append(
246
+ f"Reward '{func_name}' returns {value}, must be between 0.0 and 1.0 "
247
+ f"(line {child.lineno})"
248
+ )
249
+
250
+ for node in ast.walk(tree):
251
+ if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
252
+ # Check if function has @reward decorator
253
+ is_reward = False
254
+ for decorator in node.decorator_list:
255
+ if (
256
+ isinstance(decorator, ast.Name)
257
+ and decorator.id == "reward"
258
+ or (
259
+ isinstance(decorator, ast.Call)
260
+ and isinstance(decorator.func, ast.Name)
261
+ and decorator.func.id == "reward"
262
+ )
263
+ ):
264
+ is_reward = True
265
+
266
+ if is_reward:
267
+ # Both async and sync functions are allowed
268
+
269
+ # Check parameter count
270
+ params = node.args.args
271
+ if len(params) != 2:
272
+ errors.append(
273
+ f"Reward '{node.name}' must have exactly 2 parameters: "
274
+ f"(args: RewardArgs, messages: list). Got {len(params)} parameters."
275
+ )
276
+
277
+ # Check return type annotation
278
+ if node.returns is None:
279
+ errors.append(
280
+ f"Reward '{node.name}' must have a return type annotation '-> float'."
281
+ )
282
+
283
+ # Check parameter type annotations
284
+ for i, arg in enumerate(params):
285
+ if arg.annotation is None:
286
+ param_hint = "RewardArgs" if i == 0 else "list"
287
+ errors.append(
288
+ f"Reward '{node.name}': parameter '{arg.arg}' must have type hint '{param_hint}'."
289
+ )
290
+
291
+ # Check return values are between 0.0 and 1.0
292
+ check_return_values(node, node.name)
293
+
294
+ return errors
295
+
296
+
297
def get_reward_names_from_file(filepath) -> set[str]:
    """
    Extract reward function names from a rewards.py file without executing it.

    Parses the AST and collects the name of every sync or async function
    decorated with ``@reward`` or ``@reward(...)``.

    Args:
        filepath: Path (str or Path) of the rewards file to inspect.

    Returns:
        Set of reward function names defined in the file; empty when the
        file does not parse.
    """
    import ast
    from pathlib import Path

    names = set()
    filepath = Path(filepath)

    try:
        # Parse raw bytes so the compiler applies Python's own source-encoding
        # rules (UTF-8 default, BOM, coding cookie), not the locale encoding.
        tree = ast.parse(filepath.read_bytes(), filename=str(filepath))
    except SyntaxError:
        return names

    def is_reward_decorator(decorator) -> bool:
        """True for a bare `@reward` or a call `@reward(...)`."""
        if isinstance(decorator, ast.Call):
            decorator = decorator.func
        return isinstance(decorator, ast.Name) and decorator.id == "reward"

    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and any(
            is_reward_decorator(d) for d in node.decorator_list
        ):
            names.add(node.name)

    return names