textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,387 @@
+ # textpolicy/rewards/adapters.py
+ """
+ Core configuration models and utility functions for MLX-optimized reward system.
+
+ This module provides essential patterns for building modular reward systems:
+ - Configuration models for reward system setup
+ - Sample types for data handling
+ - Math utilities for text processing
+ - Async patterns for external model integration
+ """
+
+ import mlx.core as mx
+ import asyncio
+ from typing import Dict, List, Any, Optional, Union, Callable
+ from dataclasses import dataclass, field
+ from enum import Enum
+ import re
+
+ # Optional dependencies
+ try:
+     import aiohttp  # type: ignore
+     HAS_AIOHTTP = True
+ except ImportError:
+     HAS_AIOHTTP = False
+     aiohttp = None
+
+ try:
+     from pydantic import BaseModel, Field
+     HAS_PYDANTIC = True
+ except ImportError:
+     HAS_PYDANTIC = False
+     # Fallback simple config class
+     class BaseModel:
+         def __init__(self, **kwargs):
+             for k, v in kwargs.items():
+                 setattr(self, k, v)
+
+     def Field(default=None, **kwargs):
+         return default
+
+
+ # ==========================================
+ # CONFIGURATION MODELS
+ # ==========================================
+
+ class MLXRewardConfig(BaseModel):
+     """Configuration for individual reward functions."""
+     weight: float = Field(1.0, description="Weight of this reward function")
+     params: Dict[str, Any] = Field(default_factory=dict, description="Parameters for the reward function")
+     verifiers: Optional[List[str]] = Field(None, description="List of verifier names")
+     verifier_penalty: float = Field(0.0, description="Penalty if verifiers fail")
+     enable_mlx_compilation: bool = Field(True, description="Enable MLX compilation")
+
+
+ class MLXRewardSystemConfig(BaseModel):
+     """Configuration for complete reward system setup."""
+     reward_configs: Dict[str, MLXRewardConfig] = Field(
+         default_factory=dict,
+         description="MLX-optimized reward configurations"
+     )
+     batch_size: int = Field(32, description="Batch size for MLX processing")
+     max_workers: int = Field(4, description="Max workers for async processing")
+     enable_external_rewards: bool = Field(False, description="Enable external reward models")
+
+
+ # ==========================================
+ # SAMPLE TYPES
+ # ==========================================
+
+ @dataclass
+ class MLXSample:
+     """Lightweight sample type for text generation data."""
+     prompt: str = ""
+     response: str = ""
+     label: Optional[str] = None
+     reward: Optional[float] = None
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     class Status(Enum):
+         PENDING = "pending"
+         COMPLETED = "completed"
+         FAILED = "failed"
+
+     status: Status = Status.PENDING
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary for processing."""
+         return {
+             "prompt": self.prompt,
+             "completion": self.response,  # Map to our naming convention
+             "example": {
+                 "label": self.label,
+                 "metadata": self.metadata,
+                 **self.metadata  # Flatten metadata
+             }
+         }
+
+
+ # ==========================================
+ # MATH UTILITIES
+ # ==========================================
+
+ def extract_boxed_answer(text: str) -> Optional[str]:
+     """Extract LaTeX boxed answer from text."""
+     idx = text.rfind("\\boxed{")
+     if idx < 0:
+         return None
+
+     i = idx
+     brace_count = 0
+     while i < len(text):
+         if text[i] == "{":
+             brace_count += 1
+         elif text[i] == "}":
+             brace_count -= 1
+             if brace_count == 0:
+                 return text[idx:i+1]
+         i += 1
+     return None
+
+
+ def normalize_math_answer(answer: str) -> str:
+     """Normalize mathematical answer for comparison."""
+     answer = answer.split("=")[-1].strip()
+
+     # Remove common expressions
+     removals = ["square", "dollars", "units", "\\text{}", "^\\circ"]
+     for removal in removals:
+         answer = answer.replace(removal, "")
+
+     # Normalize fractions and roots
+     answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", answer)
+     answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", answer)
+     answer = answer.replace("$", "").strip()
+
+     return answer
+
+
+ def compute_f1_score(prediction: str, ground_truth: str) -> float:
+     """Compute F1 score between prediction and ground truth."""
+     pred_tokens = prediction.lower().split()
+     truth_tokens = ground_truth.lower().split()
+
+     if not truth_tokens:
+         return 1.0 if not pred_tokens else 0.0
+
+     common = set(pred_tokens) & set(truth_tokens)
+     if not common:
+         return 0.0
+
+     precision = len(common) / len(pred_tokens) if pred_tokens else 0.0
+     recall = len(common) / len(truth_tokens)
+
+     if precision + recall == 0:
+         return 0.0
+
+     return 2 * (precision * recall) / (precision + recall)
+
+
+ # ==========================================
+ # EXTERNAL REWARD MODELS
+ # ==========================================
+
+ class MLXExternalRewardModel:
+     """Async client for external reward model APIs."""
+
+     def __init__(self, url: str, timeout: float = 30.0):
+         self.url = url
+         self.timeout = timeout
+
+     async def get_reward(self, sample: MLXSample) -> float:
+         """Get reward from external model."""
+         if not HAS_AIOHTTP or aiohttp is None:
+             raise ImportError("aiohttp is required for external reward models. Install with: uv add aiohttp")
+
+         payload = {
+             "prompt": sample.prompt,
+             "response": sample.response,
+             "label": sample.label,
+             "metadata": sample.metadata
+         }
+
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.post(
+                     self.url,
+                     json=payload,
+                     timeout=aiohttp.ClientTimeout(total=self.timeout)
+                 ) as resp:
+                     if resp.status == 200:
+                         result = await resp.json()
+                         return float(result.get("reward", 0.0))
+                     else:
+                         return 0.0
+         except Exception:
+             return 0.0
+
+     async def get_batch_rewards(self, samples: List[MLXSample]) -> List[float]:
+         """Get rewards for batch of samples."""
+         tasks = [self.get_reward(sample) for sample in samples]
+         results = await asyncio.gather(*tasks, return_exceptions=False)
+         return list(results)
+
+
+ # ==========================================
+ # MLX-NATIVE BRIDGE ADAPTERS
+ # ==========================================
+
+ def create_mlx_batch_adapter(
+     reward_configs: Dict[str, MLXRewardConfig],
+     external_models: Optional[Dict[str, MLXExternalRewardModel]] = None
+ ) -> Union[Callable[[List[str], List[str], List[Dict[str, Any]]], mx.array],
+            Callable[[List[str], List[str], List[Dict[str, Any]]], Any]]:
+     """
+     Create MLX-native batch adapter for processing reward configurations.
+
+     This function takes configuration patterns and creates MLX-optimized processors.
+     """
+     from .registry import create_configured_reward_function, RewardConfig
+
+     # Convert MLX configs to our internal format
+     internal_configs = []
+     for name, config in reward_configs.items():
+         internal_config = RewardConfig(
+             name=name,
+             weight=config.weight,
+             params=config.params,
+             verifiers=config.verifiers or [],
+             verifier_penalty=config.verifier_penalty
+         )
+         internal_configs.append(internal_config)
+
+     # Create base MLX processor
+     from .mlx_batch_processor import create_mlx_optimized_batch_processor
+     base_processor = create_mlx_optimized_batch_processor(internal_configs)
+
+     # Add external model support if needed
+     if external_models:
+         async def async_processor(
+             prompts: List[str],
+             completions: List[str],
+             examples: List[Dict[str, Any]]
+         ) -> mx.array:
+             # Process local rewards
+             local_rewards = base_processor(prompts, completions, examples)
+
+             # Process external rewards
+             external_rewards = []
+             samples = [
+                 MLXSample(prompt=p, response=c, metadata=e)
+                 for p, c, e in zip(prompts, completions, examples)
+             ]
+
+             for name, model in external_models.items():
+                 batch_rewards = await model.get_batch_rewards(samples)
+                 external_rewards.append(mx.array(batch_rewards))
+
+             # Combine all rewards
+             if external_rewards:
+                 all_rewards = mx.stack([local_rewards] + external_rewards, axis=1)
+                 return mx.mean(all_rewards, axis=1)  # Simple average
+             else:
+                 return local_rewards
+
+         return async_processor
+     else:
+         return base_processor
+
+
+ def samples_to_mlx_format(samples: List[MLXSample]) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
+     """Convert MLXSample list to our MLX processing format."""
+     prompts = [s.prompt for s in samples]
+     completions = [s.response for s in samples]
+     examples = [s.to_dict()["example"] for s in samples]
+     return prompts, completions, examples
+
+
+ def mlx_format_to_samples(
+     prompts: List[str],
+     completions: List[str],
+     examples: List[Dict[str, Any]],
+     rewards: mx.array
+ ) -> List[MLXSample]:
+     """Convert MLX format back to MLXSample list."""
+     samples = []
+     for i, (prompt, completion, example) in enumerate(zip(prompts, completions, examples)):
+         sample = MLXSample(
+             prompt=prompt,
+             response=completion,
+             label=example.get("label"),
+             reward=float(rewards[i]),
+             metadata=example.get("metadata", {}),
+             status=MLXSample.Status.COMPLETED
+         )
+         samples.append(sample)
+     return samples
+
+
+ # ==========================================
+ # ESSENTIAL MATH REWARD FUNCTIONS
+ # ==========================================
+
+ from .registry import reward
+
+ @reward(name="math_accuracy")
+ def math_accuracy_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     extract_boxed: bool = True,
+     **kwargs
+ ) -> float:
+     """Math accuracy reward using boxed answer extraction."""
+     ground_truth = example.get("label") or example.get("ground_truth")
+     if not ground_truth:
+         return 0.0
+
+     # Extract answer if needed
+     prediction = completion
+     if extract_boxed:
+         boxed = extract_boxed_answer(completion)
+         if boxed:
+             # Remove \boxed{} wrapper
+             prediction = boxed[7:-1] if boxed.startswith("\\boxed{") and boxed.endswith("}") else boxed
+
+     # Normalize both answers
+     pred_normalized = normalize_math_answer(prediction)
+     truth_normalized = normalize_math_answer(ground_truth)
+
+     # Exact match
+     if pred_normalized == truth_normalized:
+         return 1.0
+
+     # Fallback to F1 score
+     return compute_f1_score(pred_normalized, truth_normalized)
+
+
+ @reward(name="f1_score")
+ def f1_score_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     **kwargs
+ ) -> float:
+     """F1 score reward for text overlap measurement."""
+     ground_truth = example.get("label") or example.get("ground_truth", "")
+     return compute_f1_score(completion, ground_truth)
+
+
+ # ==========================================
+ # CONFIGURATION HELPERS
+ # ==========================================
+
+ def create_mlx_system_from_config(config_dict: Dict[str, Any]) -> Callable:
+     """
+     Create complete MLX reward system from configuration dictionary.
+
+     This function creates MLX-optimized reward processors from configuration.
+     """
+     # Parse configuration
+     reward_configs = {}
+     external_models = {}
+
+     for name, config_data in config_dict.items():
+         if config_data.get("external_url"):
+             # External model
+             external_models[name] = MLXExternalRewardModel(
+                 url=config_data["external_url"],
+                 timeout=config_data.get("timeout", 30.0)
+             )
+         else:
+             # Local MLX reward
+             reward_configs[name] = MLXRewardConfig(**config_data)
+
+     # Create adapter
+     return create_mlx_batch_adapter(reward_configs, external_models if external_models else None)
+
+
+ def get_available_adapters() -> Dict[str, List[str]]:
+     """List available adapter components."""
+     return {
+         "reward_functions": ["math_accuracy", "f1_score"],
+         "math_utilities": ["extract_boxed_answer", "normalize_math_answer", "compute_f1_score"],
+         "external_models": ["MLXExternalRewardModel"],
+         "config_models": ["MLXRewardConfig", "MLXRewardSystemConfig"],
+         "sample_types": ["MLXSample"]
+     }
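
Usage note (illustrative sketch, not part of the published diff): the snippet below shows how the adapter API added in textpolicy/rewards/adapters.py could be exercised. The configuration values, prompt, and response are invented for the example, and it assumes the registry and mlx_batch_processor modules referenced above resolve the "math_accuracy" reward exactly as registered here.

    from textpolicy.rewards.adapters import (
        MLXSample,
        create_mlx_system_from_config,
        extract_boxed_answer,
        samples_to_mlx_format,
    )

    # Hypothetical configuration: one local reward, no external models.
    # Keys mirror the MLXRewardConfig fields (weight, params, ...).
    adapter = create_mlx_system_from_config({
        "math_accuracy": {"weight": 1.0, "params": {"extract_boxed": True}},
    })

    samples = [
        MLXSample(prompt="What is 2 + 2?", response="The answer is \\boxed{4}.", label="4"),
    ]
    prompts, completions, examples = samples_to_mlx_format(samples)

    # With only local rewards configured the adapter is synchronous and should
    # return an mx.array of per-sample rewards; with external models configured
    # it instead returns a coroutine that must be awaited.
    rewards = adapter(prompts, completions, examples)

    # The math utilities are plain functions and can be used on their own.
    print(extract_boxed_answer("The answer is \\boxed{4}."))  # "\boxed{4}"
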
@@ -0,0 +1,214 @@
+ # textpolicy/rewards/basic.py
+ """
+ Basic pure reward functions for text generation following retrain's patterns.
+
+ All functions follow the signature: (prompt: str, completion: str, example: Dict[str, Any], **kwargs) -> float
+ All functions are pure - no side effects, deterministic output.
+ All functions are MLX compilation compatible and registered via decorators.
+ """
+
+ from typing import List, Optional, Dict, Any
+ from .registry import reward
+
+
+ @reward
+ def length_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     target_length: int = 50,
+     tolerance: float = 0.2,
+     **kwargs
+ ) -> float:
+     """
+     Pure function rewarding responses close to target length.
+
+     Args:
+         prompt: Input prompt (not used but kept for signature consistency)
+         completion: Generated response text
+         example: Example data context (not used here)
+         target_length: Target length in words
+         tolerance: Tolerance for length deviation (0.2 = 20%)
+         **kwargs: Additional parameters
+
+     Returns:
+         Reward between 0.0 and 1.0
+     """
+     if not completion.strip():
+         return 0.0  # no text, no fluency reward
+
+     actual_length = len(completion.split())
+     deviation = abs(actual_length - target_length) / target_length
+
+     if deviation <= tolerance:
+         return 1.0
+     else:
+         # Linear decay beyond tolerance
+         return max(0.0, 1.0 - (deviation - tolerance) / (1.0 - tolerance))
+
+
+ @reward
+ def keyword_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     keywords: Optional[List[str]] = None,
+     bonus_multiplier: float = 1.0,
+     **kwargs
+ ) -> float:
+     """
+     Pure function rewarding keyword usage.
+
+     Args:
+         prompt: Input prompt (analyzed for required keywords)
+         completion: Generated response text
+         example: Example data context (may contain keywords if not provided)
+         keywords: Keywords to encourage (can be None to use from example)
+         bonus_multiplier: Multiplier for bonus points
+         **kwargs: Additional parameters
+
+     Returns:
+         Reward between 0.0 and potentially > 1.0 with bonuses
+     """
+     if not completion.strip():
+         return 0.0
+
+     # Get keywords from parameter or example context
+     if keywords is None:
+         keywords = example.get('keywords', [])
+
+     if not keywords:
+         return 0.0
+
+     completion_lower = completion.lower()
+
+     # Count keyword matches
+     matches = sum(1 for kw in keywords if kw.lower() in completion_lower)
+     base_reward = matches / len(keywords) if keywords else 0.0
+
+     # Bonus for using keywords not in prompt
+     prompt_lower = prompt.lower()
+     bonus_keywords = [kw for kw in keywords if kw.lower() not in prompt_lower]
+     bonus_matches = sum(1 for kw in bonus_keywords if kw.lower() in completion_lower)
+     bonus_reward = (bonus_matches / len(bonus_keywords) if bonus_keywords else 0.0) * bonus_multiplier
+
+     return min(1.0, base_reward + bonus_reward)
+
+
+ @reward
+ def perplexity_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     model = None,  # Optional MLX model for perplexity computation
+     max_perplexity: float = 100.0,
+     **kwargs
+ ) -> float:
+     """
+     Pure function rewarding low perplexity (high fluency).
+
+     If no model provided, uses simple heuristics.
+     With model, computes actual perplexity using MLX.
+
+     Args:
+         prompt: Input prompt
+         completion: Generated response text
+         example: Example data context (may contain model reference)
+         model: Optional MLX model for perplexity computation
+         max_perplexity: Maximum perplexity for normalization
+         **kwargs: Additional parameters
+
+     Returns:
+         Reward between 0.0 and 1.0 (higher = more fluent)
+     """
+     if not completion.strip():
+         return 0.0
+
+     # Use model from example if not provided
+     if model is None:
+         model = example.get('model')
+
+     if model is not None:
+         # MLX-based perplexity computation is not yet implemented.
+         # Remove model to ensure heuristic fallback is used.
+         model = None
+
+     # Fallback: simple heuristics for fluency
+     words = completion.split()
+
+     # Penalize very short responses (minimum heuristic fluency)
+     if len(words) < 3:
+         return 0.2
+
+     # Penalize repetition
+     unique_words = len(set(words))
+     repetition_penalty = unique_words / len(words)
+
+     # Penalize very long words (might be gibberish)
+     avg_word_length = sum(len(word) for word in words) / len(words)
+     length_penalty = 1.0 if avg_word_length <= 8 else max(0.3, 1.0 - (avg_word_length - 8) * 0.1)
+
+     # Penalize lack of punctuation in longer responses
+     has_punctuation = any(char in completion for char in '.!?')
+     punct_penalty = 1.0 if len(words) < 10 or has_punctuation else 0.8
+
+     fluency_score = repetition_penalty * length_penalty * punct_penalty
+     return min(1.0, fluency_score)
+
+
+ @reward
+ def accuracy_reward(
+     prompt: str,
+     completion: str,
+     example: Dict[str, Any],
+     ground_truth: Optional[str] = None,
+     similarity_threshold: float = 0.7,
+     **kwargs
+ ) -> float:
+     """
+     Pure function rewarding factual accuracy.
+
+     Args:
+         prompt: Input prompt
+         completion: Generated response text
+         example: Example data context (may contain ground truth)
+         ground_truth: Optional ground truth for comparison
+         similarity_threshold: Threshold for similarity matching
+         **kwargs: Additional parameters
+
+     Returns:
+         Reward between 0.0 and 1.0
+     """
+     if not completion.strip():
+         return 0.0
+
+     # Get ground truth from parameter or example context
+     if ground_truth is None:
+         ground_truth = example.get('ground_truth') or example.get('label')
+
+     if ground_truth is None:
+         # Without ground truth, use simple fact-checking heuristics
+         # Penalize uncertain language in factual contexts
+         uncertain_phrases = [
+             'i think', 'maybe', 'perhaps', 'possibly', 'not sure',
+             'might be', 'could be', 'i believe', 'seems like'
+         ]
+
+         uncertainty_penalty = sum(1 for phrase in uncertain_phrases
+                                   if phrase in completion.lower())
+         confidence_score = max(0.3, 1.0 - uncertainty_penalty * 0.2)
+
+         return confidence_score
+
+     # With ground truth, compute similarity
+     # Simple word overlap similarity (could be enhanced with embeddings)
+     completion_words = set(completion.lower().split())
+     truth_words = set(ground_truth.lower().split())
+
+     if not truth_words:
+         return 0.5
+
+     overlap = len(completion_words & truth_words)
+     similarity = overlap / len(truth_words)
+
+     return 1.0 if similarity >= similarity_threshold else similarity / similarity_threshold
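
Usage note (illustrative sketch, not part of the published diff): the functions in textpolicy/rewards/basic.py are pure and take plain Python arguments, so they can be called directly. The prompt, completion, and example values below are invented, and the example assumes the @reward decorator from the registry leaves the decorated function directly callable.

    from textpolicy.rewards.basic import accuracy_reward, keyword_reward, length_reward

    prompt = "Explain photosynthesis in one sentence."
    completion = "Photosynthesis lets plants turn sunlight into chemical energy using chlorophyll."
    example = {
        "keywords": ["photosynthesis", "chlorophyll"],
        "ground_truth": "Plants convert sunlight into chemical energy.",
    }

    # Each call returns a float in the range documented in its docstring.
    print(length_reward(prompt, completion, example, target_length=12))
    print(keyword_reward(prompt, completion, example))
    print(accuracy_reward(prompt, completion, example, similarity_threshold=0.5))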