rnow-0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnow/__init__.py +5 -0
- rnow/__main__.py +7 -0
- rnow/cli/__init__.py +6 -0
- rnow/cli/auth.py +67 -0
- rnow/cli/blob.py +98 -0
- rnow/cli/commands.py +2311 -0
- rnow/cli/common.py +28 -0
- rnow/cli/cube.py +255 -0
- rnow/cli/main.py +49 -0
- rnow/cli/test.py +728 -0
- rnow/cli/token_count.py +295 -0
- rnow/core/__init__.py +33 -0
- rnow/core/reward.py +333 -0
- rnow/core/tool.py +494 -0
- rnow/models.py +295 -0
- rnow/templates/deepseek-aha/config.yml +26 -0
- rnow/templates/deepseek-aha/rewards.py +36 -0
- rnow/templates/deepseek-aha/train.jsonl +1000 -0
- rnow/templates/mcp-tavily/config.yml +29 -0
- rnow/templates/mcp-tavily/requirements.txt +1 -0
- rnow/templates/mcp-tavily/rewards.py +25 -0
- rnow/templates/mcp-tavily/train.jsonl +500 -0
- rnow/templates/new/config.yml +26 -0
- rnow/templates/new/requirements.txt +1 -0
- rnow/templates/new/rewards.py +0 -0
- rnow/templates/new/train.jsonl +0 -0
- rnow/templates/rl-nextjs/config.yml +27 -0
- rnow/templates/rl-nextjs/requirements.txt +2 -0
- rnow/templates/rl-nextjs/rewards.py +446 -0
- rnow/templates/rl-nextjs/train.jsonl +1000 -0
- rnow/templates/rl-single/config.yml +27 -0
- rnow/templates/rl-single/requirements.txt +1 -0
- rnow/templates/rl-single/rewards.py +14 -0
- rnow/templates/rl-single/train.jsonl +1000 -0
- rnow/templates/rl-tools/config.yml +27 -0
- rnow/templates/rl-tools/env.py +38 -0
- rnow/templates/rl-tools/requirements.txt +3 -0
- rnow/templates/rl-tools/rewards.py +25 -0
- rnow/templates/rl-tools/train.jsonl +500 -0
- rnow/templates/sft/config.yml +20 -0
- rnow/templates/sft/train.jsonl +100 -0
- rnow/templates/tutorial-reward/config.yml +27 -0
- rnow/templates/tutorial-reward/requirements.txt +1 -0
- rnow/templates/tutorial-reward/rewards.py +15 -0
- rnow/templates/tutorial-reward/train.jsonl +1000 -0
- rnow/templates/tutorial-tool/config.yml +27 -0
- rnow/templates/tutorial-tool/env.py +7 -0
- rnow/templates/tutorial-tool/requirements.txt +3 -0
- rnow/templates/tutorial-tool/rewards.py +7 -0
- rnow/templates/tutorial-tool/train.jsonl +1266 -0
- rnow-0.2.4.dist-info/METADATA +135 -0
- rnow-0.2.4.dist-info/RECORD +56 -0
- rnow-0.2.4.dist-info/WHEEL +5 -0
- rnow-0.2.4.dist-info/entry_points.txt +2 -0
- rnow-0.2.4.dist-info/licenses/LICENSE +21 -0
- rnow-0.2.4.dist-info/top_level.txt +1 -0
rnow/models.py
ADDED
@@ -0,0 +1,295 @@
"""
ReinforceNow Models - User-facing types for the rnow CLI.

This module contains ONLY the types that users need:
- RewardArgs for reward function signatures
- ProjectConfig and related configs for config.yml
- Basic enums

Trainer-internal types (Env, StepResult, Observation) live in docker/trainer/
where tinker is available.
"""

from enum import Enum
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator


class RunStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"


class ModelType(str, Enum):
    QWEN3_8B = "qwen3-8b"
    GLM4_9B = "glm4-9b"


class OrgRole(str, Enum):
    OWNER = "owner"
    ADMIN = "admin"
    MEMBER = "member"
    VIEWER = "viewer"


class DatasetType(str, Enum):
    SFT = "sft"  # Supervised Fine-Tuning
    RL = "rl"  # Reinforcement Learning


class LossFunction(str, Enum):
    PPO = "ppo"  # Proximal Policy Optimization
    IS = "importance_sampling"  # Importance Sampling


class AdvantageEstimator(str, Enum):
    GRPO = "grpo"  # Group Relative Policy Optimization
    GAE = "gae"  # Generalized Advantage Estimation
    REINFORCE = "reinforce"  # REINFORCE algorithm


class TerminationPolicy(str, Enum):
    MAX_TURNS = "max_turns"  # Episode ends when max_turns is exhausted
    LAST_TOOL = "last_tool"  # Episode ends when assistant responds without a tool call


class RewardArgs(BaseModel):
    """Arguments passed to reward functions containing context about the sample."""

    metadata: dict = Field(default_factory=dict)
    variables: dict = Field(default_factory=dict)

    class Config:
        arbitrary_types_allowed = True


class DeviceCode(BaseModel):
    device_code: str
    user_code: str
    verification_uri: str
    expires_in: int = 1800
    interval: int = 5


class Token(BaseModel):
    access_token: str
    organization_id: str | None = None


class TokenError(BaseModel):
    error: str


class Organization(BaseModel):
    id: str
    name: str
    role: OrgRole


class Organizations(BaseModel):
    organizations: list[Organization]
    active_organization_id: str | None = None


# Supported model IDs
SUPPORTED_MODELS = Literal[
    # Qwen models
    "Qwen/Qwen3-235B-A22B-Instruct-2507",
    "Qwen/Qwen3-30B-A3B-Instruct-2507",
    "Qwen/Qwen3-30B-A3B",
    "Qwen/Qwen3-30B-A3B-Base",
    "Qwen/Qwen3-32B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen3-8B-Base",
    "Qwen/Qwen3-4B-Instruct-2507",
    # OpenAI models
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    # DeepSeek models
    "deepseek-ai/DeepSeek-V3.1",
    "deepseek-ai/DeepSeek-V3.1-Base",
    # Meta Llama models
    "meta-llama/Llama-3.1-70B",
    "meta-llama/Llama-3.3-70B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-1B",
]

# Maximum context window for all supported models
MAX_CONTEXT_WINDOW = 32768

# Conservative max_tokens limit (leaves room for prompts)
MAX_GENERATION_TOKENS = 30000

# Maximum LoRA rank per model
# Models not listed here default to 128
MODEL_MAX_LORA_RANK: dict[str, int] = {
    # Max 32
    "openai/gpt-oss-120b": 32,
    "openai/gpt-oss-20b": 32,
    # Max 64
    "Qwen/Qwen3-235B-A22B-Instruct-2507": 64,
    "Qwen/Qwen3-30B-A3B-Instruct-2507": 64,
    "Qwen/Qwen3-30B-A3B": 64,
    "Qwen/Qwen3-30B-A3B-Base": 64,
    "deepseek-ai/DeepSeek-V3.1": 64,
    "deepseek-ai/DeepSeek-V3.1-Base": 64,
    # Max 128 (default for all others)
    # Qwen/Qwen3-32B, Qwen/Qwen3-8B*, Qwen/Qwen3-4B-Instruct-2507, all meta-llama/*
}

# Default max LoRA rank for models not in the dict
DEFAULT_MAX_LORA_RANK = 128


def get_max_lora_rank(model_path: str) -> int:
    """Get the maximum LoRA rank for a given model."""
    return MODEL_MAX_LORA_RANK.get(model_path, DEFAULT_MAX_LORA_RANK)


class DataConfig(BaseModel):
    """Data configuration for training."""

    model_config = ConfigDict(extra="forbid")

    train_file: str = "train.jsonl"
    batch_size: int = Field(..., gt=0, le=32)  # Max 32
    group_size: int = Field(default=4, gt=0, le=64)  # Max 64, RL only
    val_split: float | None = Field(default=None, ge=0, le=1)  # Validation split ratio (0.0-1.0)

    @model_validator(mode="after")
    def _check_batch_group_product(self):
        """Validate batch_size * group_size <= 2048 (sandbox concurrency limit)."""
        prod = self.batch_size * self.group_size
        if prod > 2048:
            raise ValueError(
                f"batch_size * group_size must be <= 2048 (got {self.batch_size} * {self.group_size} = {prod})"
            )
        return self


class ModelConfig(BaseModel):
    """Model configuration.

    The `path` field accepts either:
    - A supported base model name (e.g., "Qwen/Qwen3-4B-Instruct-2507")
    - A ReinforceNow model ID (e.g., "acfa2862-23a9-4e65-ab68-b9b2698b0e75") to resume from a finetuned model
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    path: str = Field(
        ...,
        description="Base model name (e.g., 'Qwen/Qwen3-8B') or a ReinforceNow model ID to resume from",
    )
    qlora_rank: int = Field(default=32, ge=1)
    qlora_alpha: int | None = Field(default=None, ge=1)  # Defaults to qlora_rank * 2
    name: str | None = None  # Custom name for the output model (default: auto-generated)
    description: str | None = None  # Custom description for the output model

    # Internal fields resolved by the server (not set by users)
    resolved_checkpoint_path: str | None = Field(
        default=None,
        alias="_resolvedCheckpointPath",
        description="Internal: Tinker checkpoint path resolved from model ID",
    )
    resolved_base_model: str | None = Field(
        default=None,
        alias="_baseModelName",
        description="Internal: Base model name resolved from model ID",
    )


class AlgorithmConfig(BaseModel):
    """Algorithm configuration for RL training."""

    model_config = ConfigDict(extra="forbid")

    loss_fn: Literal["ppo", "importance_sampling"] = "ppo"
    adv_estimator: Literal["grpo", "gae", "reinforce"] = "grpo"
    kl_penalty_coef: float = Field(default=0.01, ge=0)


class RolloutConfig(BaseModel):
    """Rollout configuration for RL training."""

    model_config = ConfigDict(extra="forbid")

    max_turns: int = Field(default=1, gt=0)
    max_tokens: int = Field(default=2048, gt=0)
    termination_policy: Literal["max_turns", "last_tool"] = "last_tool"
    thinking_mode: Literal["disabled", "easy", "medium", "hard"] | None = (
        None  # None = model default
    )
    mcp_url: str | list[str] | None = Field(
        default=None,
        description="MCP server URL(s) for tools. Can be a single URL or a list of URLs. Can be used alongside env.py to combine both tool sources.",
    )
    max_tool_response_chars: int | None = Field(
        default=4000,
        gt=0,
        description="Maximum characters for tool responses. Longer responses are truncated. Set to null/None to disable truncation.",
    )


class TrainerConfig(BaseModel):
    """Trainer configuration."""

    model_config = ConfigDict(extra="forbid")

    num_epochs: int = Field(..., gt=0)
    learning_rate: float = Field(default=0.0001, gt=0)
    save_step: int = Field(default=0, ge=0)  # Save checkpoint every N steps (0 = end of epoch only)
    eval_step: int = Field(default=0, ge=0)  # Evaluate every N steps (0 = end of epoch only)


class ProjectConfig(BaseModel):
    """Full project configuration."""

    model_config = ConfigDict(extra="forbid")

    project_id: str = Field(default="")
    project_name: str = Field(default="")
    dataset_id: str = Field(default="")
    dataset_name: str | None = None
    dataset_type: DatasetType = Field(...)
    organization_id: str | None = None

    # Nested config sections
    data: DataConfig = Field(...)
    model: ModelConfig = Field(...)
    trainer: TrainerConfig = Field(...)
    algorithm: AlgorithmConfig | None = None  # RL only
    rollout: RolloutConfig | None = None  # RL only

    @model_validator(mode="after")
    def validate_config(self):
        """Set defaults and validate based on dataset_type."""
        if self.dataset_type == DatasetType.RL:
            # Set RL defaults if not specified
            if self.algorithm is None:
                self.algorithm = AlgorithmConfig()
            if self.rollout is None:
                self.rollout = RolloutConfig()
        else:  # SFT
            # Clear RL-specific configs for SFT
            self.algorithm = None
            self.rollout = None

        # Validate qlora_rank against model-specific limits
        # Only validate for known models (model IDs for finetuned models are validated server-side)
        model_path = self.model.path
        if "/" in model_path:  # Standard model path (not a model ID)
            max_rank = get_max_lora_rank(model_path)
            if self.model.qlora_rank > max_rank:
                raise ValueError(
                    f"qlora_rank {self.model.qlora_rank} exceeds maximum {max_rank} for model {model_path}"
                )

        return self
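As a quick illustration of how these models compose, the sketch below (not part of the package) builds a ProjectConfig directly in Python: the validate_config hook fills in the RL defaults when algorithm/rollout are omitted, and rejects a qlora_rank above the per-model cap. Field names come from models.py above; the script itself is hypothetical.

from rnow.models import DataConfig, DatasetType, ModelConfig, ProjectConfig, TrainerConfig

# RL config with algorithm/rollout omitted: the after-validator supplies
# AlgorithmConfig() and RolloutConfig() defaults.
cfg = ProjectConfig(
    dataset_type=DatasetType.RL,
    data=DataConfig(batch_size=8, group_size=4),
    model=ModelConfig(path="Qwen/Qwen3-8B", qlora_rank=64),  # Qwen3-8B allows up to rank 128
    trainer=TrainerConfig(num_epochs=1),
)
assert cfg.algorithm is not None and cfg.rollout is not None

# A qlora_rank above the MODEL_MAX_LORA_RANK cap (32 for openai/gpt-oss-20b) fails;
# pydantic surfaces the validator's ValueError as a ValidationError.
try:
    ProjectConfig(
        dataset_type=DatasetType.RL,
        data=DataConfig(batch_size=8),
        model=ModelConfig(path="openai/gpt-oss-20b", qlora_rank=64),
        trainer=TrainerConfig(num_epochs=1),
    )
except ValueError as exc:  # ValidationError subclasses ValueError
    print(exc)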
rnow/templates/deepseek-aha/config.yml
ADDED
@@ -0,0 +1,26 @@
project_id: ""
project_name: "DeepSeek R1 Aha Moment - Countdown Game"
dataset_id: ""
dataset_name: "countdown-puzzles"
dataset_type: rl
organization_id: ""
data:
  train_file: train.jsonl
  batch_size: 32
  group_size: 16
model:
  path: openai/gpt-oss-20b
  qlora_rank: 32
  name: "Countdown Reasoning Model"
  description: "Reproduces DeepSeek R1 aha moment using GRPO on the Countdown game"
algorithm:
  loss_fn: ppo
  adv_estimator: grpo
  kl_penalty_coef: 0.001
rollout:
  max_turns: 1
  max_tokens: 16384
trainer:
  num_epochs: 4
  learning_rate: 0.00005
  save_step: 333
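For reference, this template validates cleanly against the ProjectConfig model shipped in rnow/models.py. A minimal loading sketch, assuming PyYAML is available and the script runs from the template directory (the CLI's real loading path lives in rnow/cli and is not reproduced here):

import yaml

from rnow.models import ProjectConfig

# Parse the template's config.yml and validate it with the pydantic model.
with open("config.yml") as f:
    cfg = ProjectConfig.model_validate(yaml.safe_load(f))

print(cfg.model.path)                             # openai/gpt-oss-20b
print(cfg.data.batch_size * cfg.data.group_size)  # 512, within the 2048 sandbox limit
print(cfg.algorithm.adv_estimator)                # grpo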
rnow/templates/deepseek-aha/rewards.py
ADDED
@@ -0,0 +1,36 @@
import re

from rnow.core import RewardArgs, reward


@reward(precondition=True)
def format(args: RewardArgs, messages: list) -> float:
    """Check for \\boxed{} format."""
    response = messages[-1]["content"]
    return 1.0 if re.search(r"\\boxed\{", response) else 0.0


@reward
def accuracy(args: RewardArgs, messages: list) -> float:
    """Check if equation equals target and uses all numbers exactly once."""
    response = messages[-1]["content"]
    target = args.metadata["target"]
    numbers = args.metadata["numbers"]

    # Extract equation from \boxed{} (take the last one)
    matches = re.findall(r"\\boxed\{([^}]*)\}", response)
    if not matches:
        return 0.0

    equation = matches[-1].strip()

    # Check all numbers used exactly once
    used = [int(n) for n in re.findall(r"\d+", equation)]
    if sorted(used) != sorted(numbers):
        return 0.0

    try:
        result = eval(equation)
        return 1.0 if abs(result - target) < 0.0001 else 0.0
    except:
        return 0.0