textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,387 @@
# textpolicy/rewards/adapters.py
"""
Core configuration models and utility functions for MLX-optimized reward system.

This module provides essential patterns for building modular reward systems:
- Configuration models for reward system setup
- Sample types for data handling
- Math utilities for text processing
- Async patterns for external model integration
"""

import mlx.core as mx
import asyncio
from typing import Dict, List, Any, Optional, Union, Callable
from dataclasses import dataclass, field
from enum import Enum
import re

# Optional dependencies
try:
    import aiohttp  # type: ignore
    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False
    aiohttp = None

try:
    from pydantic import BaseModel, Field
    HAS_PYDANTIC = True
except ImportError:
    HAS_PYDANTIC = False

    # Fallback simple config class
    class BaseModel:
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    def Field(default=None, **kwargs):
        return default

# ==========================================
# CONFIGURATION MODELS
# ==========================================

class MLXRewardConfig(BaseModel):
    """Configuration for individual reward functions."""
    weight: float = Field(1.0, description="Weight of this reward function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters for the reward function")
    verifiers: Optional[List[str]] = Field(None, description="List of verifier names")
    verifier_penalty: float = Field(0.0, description="Penalty if verifiers fail")
    enable_mlx_compilation: bool = Field(True, description="Enable MLX compilation")


class MLXRewardSystemConfig(BaseModel):
    """Configuration for complete reward system setup."""
    reward_configs: Dict[str, MLXRewardConfig] = Field(
        default_factory=dict,
        description="MLX-optimized reward configurations"
    )
    batch_size: int = Field(32, description="Batch size for MLX processing")
    max_workers: int = Field(4, description="Max workers for async processing")
    enable_external_rewards: bool = Field(False, description="Enable external reward models")

# ==========================================
# SAMPLE TYPES
# ==========================================

@dataclass
class MLXSample:
    """Lightweight sample type for text generation data."""
    prompt: str = ""
    response: str = ""
    label: Optional[str] = None
    reward: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    class Status(Enum):
        PENDING = "pending"
        COMPLETED = "completed"
        FAILED = "failed"

    status: Status = Status.PENDING

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for processing."""
        return {
            "prompt": self.prompt,
            "completion": self.response,  # Map to our naming convention
            "example": {
                "label": self.label,
                "metadata": self.metadata,
                **self.metadata  # Flatten metadata
            }
        }

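As a usage sketch (not part of the package source), constructing a sample and inspecting its dictionary form could look like the following; the field values are illustrative:

from textpolicy.rewards.adapters import MLXSample

sample = MLXSample(
    prompt="What is 2 + 2?",
    response="The answer is \\boxed{4}.",
    label="4",
    metadata={"source": "toy"},  # illustrative metadata key
)
d = sample.to_dict()
# to_dict() renames `response` to `completion` and flattens metadata into `example`:
# d["completion"] == "The answer is \\boxed{4}."
# d["example"]["label"] == "4"; d["example"]["source"] == "toy"
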
# ==========================================
# MATH UTILITIES
# ==========================================

def extract_boxed_answer(text: str) -> Optional[str]:
    """Extract LaTeX boxed answer from text."""
    idx = text.rfind("\\boxed{")
    if idx < 0:
        return None

    i = idx
    brace_count = 0
    while i < len(text):
        if text[i] == "{":
            brace_count += 1
        elif text[i] == "}":
            brace_count -= 1
            if brace_count == 0:
                return text[idx:i+1]
        i += 1
    return None


def normalize_math_answer(answer: str) -> str:
    """Normalize mathematical answer for comparison."""
    answer = answer.split("=")[-1].strip()

    # Remove common expressions
    removals = ["square", "dollars", "units", "\\text{}", "^\\circ"]
    for removal in removals:
        answer = answer.replace(removal, "")

    # Normalize fractions and roots
    answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", answer)
    answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", answer)
    answer = answer.replace("$", "").strip()

    return answer


def compute_f1_score(prediction: str, ground_truth: str) -> float:
    """Compute F1 score between prediction and ground truth."""
    pred_tokens = prediction.lower().split()
    truth_tokens = ground_truth.lower().split()

    if not truth_tokens:
        return 1.0 if not pred_tokens else 0.0

    common = set(pred_tokens) & set(truth_tokens)
    if not common:
        return 0.0

    precision = len(common) / len(pred_tokens) if pred_tokens else 0.0
    recall = len(common) / len(truth_tokens)

    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)

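The utilities above are pure string functions, so their behaviour is easy to illustrate; a quick sketch (not part of the diff), with outputs noted in comments:

from textpolicy.rewards.adapters import (
    extract_boxed_answer, normalize_math_answer, compute_f1_score,
)

extract_boxed_answer("so the result is \\boxed{42}")  # -> "\\boxed{42}"
extract_boxed_answer("no boxed answer here")          # -> None
normalize_math_answer("x = 12 dollars")               # -> "12"
compute_f1_score("the cat sat", "the cat")            # -> 0.8 (precision 2/3, recall 1.0)
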
# ==========================================
# EXTERNAL REWARD MODELS
# ==========================================

class MLXExternalRewardModel:
    """Async client for external reward model APIs."""

    def __init__(self, url: str, timeout: float = 30.0):
        self.url = url
        self.timeout = timeout

    async def get_reward(self, sample: MLXSample) -> float:
        """Get reward from external model."""
        if not HAS_AIOHTTP or aiohttp is None:
            raise ImportError("aiohttp is required for external reward models. Install with: uv add aiohttp")

        payload = {
            "prompt": sample.prompt,
            "response": sample.response,
            "label": sample.label,
            "metadata": sample.metadata
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.timeout)
                ) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        return float(result.get("reward", 0.0))
                    else:
                        return 0.0
        except Exception:
            return 0.0

    async def get_batch_rewards(self, samples: List[MLXSample]) -> List[float]:
        """Get rewards for batch of samples."""
        tasks = [self.get_reward(sample) for sample in samples]
        results = await asyncio.gather(*tasks, return_exceptions=False)
        return list(results)

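A minimal async usage sketch, not part of the package: it assumes aiohttp is installed and a reward endpoint at a hypothetical URL that returns JSON like {"reward": 0.7}; non-200 responses and request failures score 0.0 per the code above.

import asyncio
from textpolicy.rewards.adapters import MLXExternalRewardModel, MLXSample

model = MLXExternalRewardModel(url="http://localhost:8000/score", timeout=10.0)  # hypothetical endpoint
samples = [
    MLXSample(prompt="Hi", response="Hello!"),
    MLXSample(prompt="2 + 2?", response="4"),
]
rewards = asyncio.run(model.get_batch_rewards(samples))  # e.g. [0.7, 0.9] if the server responds
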
# ==========================================
# MLX-NATIVE BRIDGE ADAPTERS
# ==========================================

def create_mlx_batch_adapter(
    reward_configs: Dict[str, MLXRewardConfig],
    external_models: Optional[Dict[str, MLXExternalRewardModel]] = None
) -> Union[Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array],
           Callable[[List[str], List[str], List[Dict[str, Any]]], Any]]:
    """
    Create MLX-native batch adapter for processing reward configurations.

    This function takes configuration patterns and creates MLX-optimized processors.
    """
    from .registry import create_configured_reward_function, RewardConfig

    # Convert MLX configs to our internal format
    internal_configs = []
    for name, config in reward_configs.items():
        internal_config = RewardConfig(
            name=name,
            weight=config.weight,
            params=config.params,
            verifiers=config.verifiers or [],
            verifier_penalty=config.verifier_penalty
        )
        internal_configs.append(internal_config)

    # Create base MLX processor
    from .mlx_batch_processor import create_mlx_optimized_batch_processor
    base_processor = create_mlx_optimized_batch_processor(internal_configs)

    # Add external model support if needed
    if external_models:
        async def async_processor(
            prompts: List[str],
            completions: List[str],
            examples: List[Dict[str, Any]]
        ) -> mx.array:
            # Process local rewards
            local_rewards = base_processor(prompts, completions, examples)

            # Process external rewards
            external_rewards = []
            samples = [
                MLXSample(prompt=p, response=c, metadata=e)
                for p, c, e in zip(prompts, completions, examples)
            ]

            for name, model in external_models.items():
                batch_rewards = await model.get_batch_rewards(samples)
                external_rewards.append(mx.array(batch_rewards))

            # Combine all rewards
            if external_rewards:
                all_rewards = mx.stack([local_rewards] + external_rewards, axis=1)
                return mx.mean(all_rewards, axis=1)  # Simple average
            else:
                return local_rewards

        return async_processor
    else:
        return base_processor


def samples_to_mlx_format(samples: List[MLXSample]) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Convert MLXSample list to our MLX processing format."""
    prompts = [s.prompt for s in samples]
    completions = [s.response for s in samples]
    examples = [s.to_dict()["example"] for s in samples]
    return prompts, completions, examples


def mlx_format_to_samples(
    prompts: List[str],
    completions: List[str],
    examples: List[Dict[str, Any]],
    rewards: mx.array
) -> List[MLXSample]:
    """Convert MLX format back to MLXSample list."""
    samples = []
    for i, (prompt, completion, example) in enumerate(zip(prompts, completions, examples)):
        sample = MLXSample(
            prompt=prompt,
            response=completion,
            label=example.get("label"),
            reward=float(rewards[i]),
            metadata=example.get("metadata", {}),
            status=MLXSample.Status.COMPLETED
        )
        samples.append(sample)
    return samples

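A small round-trip sketch for the two converters, again illustrative rather than part of the diff; the reward value is made up:

import mlx.core as mx
from textpolicy.rewards.adapters import (
    MLXSample, samples_to_mlx_format, mlx_format_to_samples,
)

batch = [MLXSample(prompt="2 + 2?", response="4", label="4")]
prompts, completions, examples = samples_to_mlx_format(batch)
scored = mlx_format_to_samples(prompts, completions, examples, rewards=mx.array([1.0]))
# scored[0].reward == 1.0 and scored[0].status is MLXSample.Status.COMPLETED
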
# ==========================================
# ESSENTIAL MATH REWARD FUNCTIONS
# ==========================================

from .registry import reward

@reward(name="math_accuracy")
def math_accuracy_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    extract_boxed: bool = True,
    **kwargs
) -> float:
    """Math accuracy reward using boxed answer extraction."""
    ground_truth = example.get("label") or example.get("ground_truth")
    if not ground_truth:
        return 0.0

    # Extract answer if needed
    prediction = completion
    if extract_boxed:
        boxed = extract_boxed_answer(completion)
        if boxed:
            # Remove \boxed{} wrapper
            prediction = boxed[7:-1] if boxed.startswith("\\boxed{") and boxed.endswith("}") else boxed

    # Normalize both answers
    pred_normalized = normalize_math_answer(prediction)
    truth_normalized = normalize_math_answer(ground_truth)

    # Exact match
    if pred_normalized == truth_normalized:
        return 1.0

    # Fallback to F1 score
    return compute_f1_score(pred_normalized, truth_normalized)


@reward(name="f1_score")
def f1_score_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    **kwargs
) -> float:
    """F1 score reward for text overlap measurement."""
    ground_truth = example.get("label") or example.get("ground_truth", "")
    return compute_f1_score(completion, ground_truth)

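For illustration only (assuming the @reward decorator returns the function unchanged after registering it), calling the math reward directly with a toy example:

from textpolicy.rewards.adapters import math_accuracy_reward

score = math_accuracy_reward(
    prompt="What is 6 * 7?",
    completion="We multiply to get \\boxed{42}.",
    example={"label": "42"},
)
# Exact match after \boxed{} extraction and normalization, so score == 1.0
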
# ==========================================
# CONFIGURATION HELPERS
# ==========================================

def create_mlx_system_from_config(config_dict: Dict[str, Any]) -> Callable:
    """
    Create complete MLX reward system from configuration dictionary.

    This function creates MLX-optimized reward processors from configuration.
    """
    # Parse configuration
    reward_configs = {}
    external_models = {}

    for name, config_data in config_dict.items():
        if config_data.get("external_url"):
            # External model
            external_models[name] = MLXExternalRewardModel(
                url=config_data["external_url"],
                timeout=config_data.get("timeout", 30.0)
            )
        else:
            # Local MLX reward
            reward_configs[name] = MLXRewardConfig(**config_data)

    # Create adapter
    return create_mlx_batch_adapter(reward_configs, external_models if external_models else None)


def get_available_adapters() -> Dict[str, List[str]]:
    """List available adapter components."""
    return {
        "reward_functions": ["math_accuracy", "f1_score"],
        "math_utilities": ["extract_boxed_answer", "normalize_math_answer", "compute_f1_score"],
        "external_models": ["MLXExternalRewardModel"],
        "config_models": ["MLXRewardConfig", "MLXRewardSystemConfig"],
        "sample_types": ["MLXSample"]
    }

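A hedged sketch of driving the whole adapter from a plain config dict; the URL and weights are invented, and the exact behaviour of the returned callable depends on the registry and mlx_batch_processor modules added elsewhere in this release:

from textpolicy.rewards.adapters import create_mlx_system_from_config

reward_fn = create_mlx_system_from_config({
    "math_accuracy": {"weight": 1.0, "params": {"extract_boxed": True}},
    "judge": {"external_url": "http://localhost:8000/score", "timeout": 10.0},  # hypothetical endpoint
})
# With an external model configured, reward_fn is an async callable:
#   rewards = await reward_fn(prompts, completions, examples)
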
@@ -0,0 +1,214 @@
# textpolicy/rewards/basic.py
"""
Basic pure reward functions for text generation following retrain's patterns.

All functions follow the signature: (prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float
All functions are pure - no side effects, deterministic output.
All functions are MLX compilation compatible and registered via decorators.
"""

from typing import List, Optional, Dict, Any
from .registry import reward

@reward
def length_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    target_length: int = 50,
    tolerance: float = 0.2,
    **kwargs
) -> float:
    """
    Pure function rewarding responses close to target length.

    Args:
        prompt: Input prompt (not used but kept for signature consistency)
        completion: Generated response text
        example: Example data context (not used here)
        target_length: Target length in words
        tolerance: Tolerance for length deviation (0.2 = 20%)
        **kwargs: Additional parameters

    Returns:
        Reward between 0.0 and 1.0
    """
    if not completion.strip():
        return 0.0  # no text, no reward

    actual_length = len(completion.split())
    deviation = abs(actual_length - target_length) / target_length

    if deviation <= tolerance:
        return 1.0
    else:
        # Linear decay beyond tolerance
        return max(0.0, 1.0 - (deviation - tolerance) / (1.0 - tolerance))

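A quick sketch of the length reward on a ten-word completion (not part of the package; it assumes @reward returns the function unchanged):

from textpolicy.rewards.basic import length_reward

text = "This answer is deliberately written to contain exactly ten words."  # 10 words
length_reward("prompt", text, {}, target_length=10)  # -> 1.0 (within 20% tolerance)
length_reward("prompt", text, {}, target_length=20)  # -> 0.625 (deviation 0.5, linear decay)
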
@reward
def keyword_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    keywords: Optional[List[str]] = None,
    bonus_multiplier: float = 1.0,
    **kwargs
) -> float:
    """
    Pure function rewarding keyword usage.

    Args:
        prompt: Input prompt (analyzed for required keywords)
        completion: Generated response text
        example: Example data context (may contain keywords if not provided)
        keywords: Keywords to encourage (can be None to use from example)
        bonus_multiplier: Multiplier for bonus points
        **kwargs: Additional parameters

    Returns:
        Reward between 0.0 and 1.0 (bonus points are capped at 1.0)
    """
    if not completion.strip():
        return 0.0

    # Get keywords from parameter or example context
    if keywords is None:
        keywords = example.get('keywords', [])

    if not keywords:
        return 0.0

    completion_lower = completion.lower()

    # Count keyword matches
    matches = sum(1 for kw in keywords if kw.lower() in completion_lower)
    base_reward = matches / len(keywords) if keywords else 0.0

    # Bonus for using keywords not in prompt
    prompt_lower = prompt.lower()
    bonus_keywords = [kw for kw in keywords if kw.lower() not in prompt_lower]
    bonus_matches = sum(1 for kw in bonus_keywords if kw.lower() in completion_lower)
    bonus_reward = (bonus_matches / len(bonus_keywords) if bonus_keywords else 0.0) * bonus_multiplier

    return min(1.0, base_reward + bonus_reward)

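An illustrative call, not part of the diff, showing the base and bonus components:

from textpolicy.rewards.basic import keyword_reward

keyword_reward(
    prompt="Write a short note about cats.",
    completion="Cats make wonderful pets and loyal companions.",
    example={},
    keywords=["cats", "pets"],
)
# -> 1.0 (capped): both keywords appear, and "pets" is absent from the prompt so it also earns the bonus
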
@reward
def perplexity_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    model=None,  # Optional MLX model for perplexity computation
    max_perplexity: float = 100.0,
    **kwargs
) -> float:
    """
    Pure function rewarding low perplexity (high fluency).

    If no model is provided, simple heuristics are used.
    With a model, actual perplexity would be computed using MLX;
    this is not yet implemented, so the heuristic fallback is always used.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain model reference)
        model: Optional MLX model for perplexity computation
        max_perplexity: Maximum perplexity for normalization
        **kwargs: Additional parameters

    Returns:
        Reward between 0.0 and 1.0 (higher = more fluent)
    """
    if not completion.strip():
        return 0.0

    # Use model from example if not provided
    if model is None:
        model = example.get('model')

    if model is not None:
        # MLX-based perplexity computation is not yet implemented.
        # Remove model to ensure heuristic fallback is used.
        model = None

    # Fallback: simple heuristics for fluency
    words = completion.split()

    # Penalize very short responses (minimum heuristic fluency)
    if len(words) < 3:
        return 0.2

    # Penalize repetition
    unique_words = len(set(words))
    repetition_penalty = unique_words / len(words)

    # Penalize very long words (might be gibberish)
    avg_word_length = sum(len(word) for word in words) / len(words)
    length_penalty = 1.0 if avg_word_length <= 8 else max(0.3, 1.0 - (avg_word_length - 8) * 0.1)

    # Penalize lack of punctuation in longer responses
    has_punctuation = any(char in completion for char in '.!?')
    punct_penalty = 1.0 if len(words) < 10 or has_punctuation else 0.8

    fluency_score = repetition_penalty * length_penalty * punct_penalty
    return min(1.0, fluency_score)

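Two illustrative calls showing how the heuristics behave without a model (values follow from the code above; not part of the package):

from textpolicy.rewards.basic import perplexity_reward

perplexity_reward("prompt", "The model answered clearly and concisely.", {})
# -> 1.0: no repeated words, short average word length, punctuation present
perplexity_reward("prompt", "word word word word word word", {})
# -> ~0.167: repetition penalty of 1 unique word out of 6
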
@reward
def accuracy_reward(
    prompt: str,
    completion: str,
    example: Dict[str, Any],
    ground_truth: Optional[str] = None,
    similarity_threshold: float = 0.7,
    **kwargs
) -> float:
    """
    Pure function rewarding factual accuracy.

    Args:
        prompt: Input prompt
        completion: Generated response text
        example: Example data context (may contain ground truth)
        ground_truth: Optional ground truth for comparison
        similarity_threshold: Threshold for similarity matching
        **kwargs: Additional parameters

    Returns:
        Reward between 0.0 and 1.0
    """
    if not completion.strip():
        return 0.0

    # Get ground truth from parameter or example context
    if ground_truth is None:
        ground_truth = example.get('ground_truth') or example.get('label')

    if ground_truth is None:
        # Without ground truth, use simple fact-checking heuristics
        # Penalize uncertain language in factual contexts
        uncertain_phrases = [
            'i think', 'maybe', 'perhaps', 'possibly', 'not sure',
            'might be', 'could be', 'i believe', 'seems like'
        ]

        uncertainty_penalty = sum(1 for phrase in uncertain_phrases
                                  if phrase in completion.lower())
        confidence_score = max(0.3, 1.0 - uncertainty_penalty * 0.2)

        return confidence_score

    # With ground truth, compute similarity
    # Simple word overlap similarity (could be enhanced with embeddings)
    completion_words = set(completion.lower().split())
    truth_words = set(ground_truth.lower().split())

    if not truth_words:
        return 0.5

    overlap = len(completion_words & truth_words)
    similarity = overlap / len(truth_words)

    return 1.0 if similarity >= similarity_threshold else similarity / similarity_threshold

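Finally, an illustrative pair of calls for the accuracy reward, covering both the ground-truth and heuristic paths (not part of the package source):

from textpolicy.rewards.basic import accuracy_reward

accuracy_reward(
    prompt="Which planet is largest?",
    completion="Jupiter is the largest planet",
    example={"ground_truth": "Jupiter is the largest planet"},
)
# -> 1.0: full word overlap with the ground truth, above the 0.7 threshold

accuracy_reward("prompt", "I think it might be Saturn, not sure.", example={})
# -> 0.4: no ground truth, so 0.2 is subtracted per uncertain phrase (three here)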