eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reward functions for evaluating repetition in model responses.
|
|
3
|
+
|
|
4
|
+
This module provides reward functions that penalize repetitive text in model responses,
|
|
5
|
+
encouraging more diverse and information-rich outputs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
10
|
+
|
|
11
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
12
|
+
from ..typed_interface import reward_function
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_ngrams(text: str, n: int, language: str = "en") -> Tuple[List[Tuple[str, ...]], int]:
    """
    Extract contiguous n-grams from text based on language.

    Tokenization:
      * ``"zh"``: segmented with ``jieba`` when it can be imported, otherwise
        split into individual characters. Whitespace-only tokens are dropped so
        spaces/newlines are not counted as words (they would otherwise inflate
        the apparent repetition of the text).
      * any other value (including ``"en"``): lower-cased and split on
        whitespace.

    Args:
        text: The text to extract n-grams from.
        n: Size of the n-grams. Values below 1 yield no n-grams.
        language: Language of the text (affects tokenization).

    Returns:
        Tuple of (list of n-grams in order of appearance, total n-gram count).
    """
    if n < 1:
        # n == 0 would otherwise produce len(words) + 1 empty tuples, which the
        # reward functions interpret as (almost) total repetition.
        return [], 0

    if language == "zh":
        try:
            import jieba

            raw_tokens = list(jieba.cut(text))
        except ImportError:
            raw_tokens = list(text)
        words = [token for token in raw_tokens if token.strip()]
    else:
        words = text.lower().split()

    ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
    return ngrams, len(ngrams)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@reward_function
def repetition_penalty_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    ngram_size: int = 3,
    max_penalty: float = 0.5,
    language: str = "en",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that penalizes repetitive text in model responses.
    The model's response is assumed to be the last message in the `messages` list.

    This function computes repetition by examining unique n-grams in the response
    and penalizes texts with a high proportion of repeated phrases:

        repetition_ratio = 1 - unique_ngrams / total_ngrams
        score = max(0, 1 - repetition_ratio * max_penalty)

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
            Each entry may be a `Message` or a plain dict with "role"/"content" keys.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this reward.
        ngram_size: Size of n-grams to check for repetition.
        max_penalty: Maximum penalty to apply for repetitive text.
        language: Language of the text (affects tokenization).
        **kwargs: Additional arguments.

    Returns:
        EvaluateResult with score penalizing repetition. Empty responses and
        responses shorter than `ngram_size` tokens score 1.0; a missing or
        non-assistant last message scores 0.0.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"repetition": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    if isinstance(response, Message):
        if response.role != "assistant":
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "repetition": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant",
                    )
                },
            )
        text = response.content or ""
    elif isinstance(response, dict):
        if response.get("role") != "assistant":
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "repetition": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant",
                    )
                },
            )
        # "content" may be present but None; treat it like an empty response
        # (matching the Message branch) instead of crashing on .strip().
        text = response.get("content") or ""
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "repetition": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if not text.strip():
        return EvaluateResult(
            score=1.0,
            reason="Empty response, no repetition to penalize",
            metrics={
                "repetition": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="Empty response",
                ),
                "unique_ngram_ratio": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="Empty response",
                ),
                "repetition_penalty": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="No penalty applied to empty response",
                ),
            },
        )

    ngrams, total = get_ngrams(text, ngram_size, language)

    if total < 1:
        return EvaluateResult(
            score=1.0,
            reason=f"Text too short for {ngram_size}-gram analysis",
            metrics={
                "repetition": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason=f"Text too short for {ngram_size}-gram analysis",
                )
            },
        )

    unique_ngrams = len(set(ngrams))
    repetition_ratio = 1.0 - (unique_ngrams / total)
    penalty = repetition_ratio * max_penalty
    score = max(0.0, 1.0 - penalty)
    # "Success" threshold: fewer than 20% of the n-grams are repeats.
    success = repetition_ratio < 0.2

    reason = f"Repetition ratio: {repetition_ratio:.2f}, Unique {ngram_size}-grams: {unique_ngrams}/{total}"
    metrics = {
        "repetition": MetricResult(score=score, is_score_valid=success, reason=reason),
        "unique_ngram_ratio": MetricResult(
            score=1.0 - repetition_ratio,
            is_score_valid=success,
            reason=f"Unique {ngram_size}-gram ratio: {1.0 - repetition_ratio:.2f}",
        ),
        "repetition_penalty": MetricResult(
            score=1.0 - penalty,
            is_score_valid=success,
            reason=f"Applied repetition penalty: {penalty:.2f}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics, is_score_valid=score > 0.0)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@reward_function
def diversity_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    ngram_sizes: List[int] = [1, 2, 3],
    weights: Optional[List[float]] = None,
    language: str = "en",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that measures lexical diversity in model responses.
    The model's response is assumed to be the last message in the `messages` list.

    This function computes diversity across multiple n-gram sizes and combines them
    into a weighted score to encourage varied vocabulary and phrasing. For each
    size the unique/total n-gram ratio is computed (1.0 when the text is too
    short for that size), and the final score is the weighted sum of those
    ratios. With non-negative weights the final score lies in [0, 1].

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
            Each entry may be a `Message` or a plain dict with "role"/"content" keys.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this reward.
        ngram_sizes: List of n-gram sizes to evaluate. Not modified.
        weights: Optional list of weights for each n-gram size (normalized if provided).
            Extra weights are ignored; missing ones share whatever is left of a
            total of 1.0 (never below 0). The caller's list is not modified.
        language: Language of the text (affects tokenization).
        **kwargs: Additional arguments.

    Returns:
        EvaluateResult with score based on lexical diversity.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"diversity": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    if isinstance(response, Message):
        if response.role != "assistant":
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "diversity": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant",
                    )
                },
            )
        text = response.content or ""
    elif isinstance(response, dict):
        if response.get("role") != "assistant":
            return EvaluateResult(
                score=0.0,
                reason="No assistant response found",
                metrics={
                    "diversity": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Message not from assistant",
                    )
                },
            )
        # "content" may be present but None; treat it like an empty response.
        text = response.get("content") or ""
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "diversity": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if not text.strip():
        return EvaluateResult(
            score=0.0,
            reason="Empty response",
            metrics={
                "diversity": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Empty response",
                )
            },
        )

    # Work on a private copy so the caller's `weights` list is never mutated.
    if weights is None:
        resolved_weights = [0.2, 0.3, 0.5][: len(ngram_sizes)]
    else:
        resolved_weights = list(weights)

    num_sizes = len(ngram_sizes)
    if len(resolved_weights) > num_sizes:
        resolved_weights = resolved_weights[:num_sizes]
    elif len(resolved_weights) < num_sizes:
        num_missing = num_sizes - len(resolved_weights)
        # Spread the remaining weight mass over the missing sizes; clamp at 0 so
        # supplied weights summing to more than 1 cannot yield negative weights.
        missing_weight = max(0.0, (1.0 - sum(resolved_weights)) / num_missing)
        resolved_weights.extend([missing_weight] * num_missing)

    total_weight = sum(resolved_weights)
    if total_weight != 1.0 and total_weight > 0:  # Avoid division by zero if total_weight is 0
        resolved_weights = [w / total_weight for w in resolved_weights]
    elif total_weight == 0 and resolved_weights:  # If all weights are zero, distribute equally
        resolved_weights = [1.0 / len(resolved_weights)] * len(resolved_weights)

    diversity_scores: Dict[str, float] = {}
    ratios: Dict[str, float] = {}

    for size, weight in zip(ngram_sizes, resolved_weights):
        ngrams, total = get_ngrams(text, size, language)
        key = f"ngram_{size}"

        if total < 1:
            # Too short for this n-gram size: treat as fully diverse, but still
            # apply the size's weight so the final score stays within [0, 1].
            ratio = 1.0
        else:
            ratio = len(set(ngrams)) / total

        ratios[key] = ratio
        diversity_scores[key] = ratio * weight

    final_score = sum(diversity_scores.values())
    success = final_score > 0.6

    size_metrics: Dict[str, MetricResult] = {
        size_key: MetricResult(
            score=ratio_val,
            is_score_valid=ratio_val > 0.7,
            reason=f"Diversity ratio for {size_key}: {ratio_val:.2f}",
        )
        for size_key, ratio_val in ratios.items()
    }

    metrics: Dict[str, MetricResult] = {
        "diversity": MetricResult(
            score=final_score,
            is_score_valid=success,
            reason=f"Overall weighted diversity score: {final_score:.2f}",
        ),
        **size_metrics,
    }

    return EvaluateResult(
        score=final_score,
        reason=f"Lexical diversity score: {final_score:.2f}",
        metrics=metrics,
    )
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reward functions for counting tags in text responses.
|
|
3
|
+
|
|
4
|
+
This module provides reward functions that evaluate if responses contain
|
|
5
|
+
specified XML/HTML-like tags in correct quantities.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from typing import Any, Dict, List, Set, Union
|
|
10
|
+
|
|
11
|
+
from ..models import EvaluateResult, Message, MetricResult
|
|
12
|
+
from ..typed_interface import reward_function
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@reward_function  # type: ignore[arg-type]
def tag_count_reward(
    messages: List[Message],
    *,  # Make subsequent parameters keyword-only
    required_tags: List[str],
    score_per_tag: float = 0.25,
    require_balanced: bool = True,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that checks for presence of specific tags in response.

    For each distinct tag in required_tags that is found, adds score_per_tag to
    the score (capped at 1.0). When require_balanced is True, a tag only counts
    as found if it has at least one opening and one closing tag in equal
    numbers. Each tag that appears but is unbalanced subtracts score_per_tag
    (the score is floored at 0.0).

    Tag names are matched literally (regex metacharacters are escaped). An
    opening tag must be followed by whitespace, "/" or ">", so that e.g. "think"
    does not also match "<thinking>".

    Args:
        messages: List of conversation messages; the last one is evaluated.
            Plain dicts with "role"/"content" keys are accepted as well.
        required_tags: List of tag names to check for (without < > brackets)
        score_per_tag: Score to award per correctly found tag (default: 0.25)
        require_balanced: If True, requires equal opening and closing tags
        **kwargs: Additional arguments

    Returns:
        EvaluateResult with score based on tags found
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"tag_count": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]
    # Accept plain dicts as well as Message objects, like the other reward functions.
    if isinstance(response, dict):
        role = response.get("role")
        content = response.get("content")
    else:
        role = getattr(response, "role", None)
        content = getattr(response, "content", None)

    if role != "assistant" or not content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found or response has no content",
            metrics={
                "tag_count": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )
    text: str = content

    tag_metrics: Dict[str, MetricResult] = {}
    found_tags: Set[str] = set()
    mismatched_tags: Set[str] = set()

    for tag in required_tags:
        escaped_tag = re.escape(tag)
        # The lookahead stops "<think" from also matching "<thinking>" while
        # still allowing attributes ("<think id=1>") and "<think/>".
        opening_pattern = rf"<{escaped_tag}(?=[\s/>])[^>]*>"
        closing_pattern = rf"</{escaped_tag}>"

        opening_count = len(re.findall(opening_pattern, text))
        closing_count = len(re.findall(closing_pattern, text))

        is_balanced = opening_count == closing_count
        appears = opening_count > 0 or closing_count > 0

        if require_balanced:
            is_found = opening_count > 0 and closing_count > 0 and is_balanced
        else:
            is_found = appears

        if is_found:
            found_tags.add(tag)

        if require_balanced and appears and not is_balanced:
            mismatched_tags.add(tag)

        tag_metrics[f"tag_{tag}"] = MetricResult(
            score=1.0 if is_found else 0.0,
            is_score_valid=is_found,
            reason=_get_tag_reason(tag, opening_count, closing_count, require_balanced),
        )

    total_score = min(1.0, len(found_tags) * score_per_tag)

    if require_balanced and mismatched_tags:
        penalty = len(mismatched_tags) * score_per_tag
        total_score = max(0.0, total_score - penalty)

    # Compare as sets so duplicate entries in required_tags cannot make success impossible.
    all_found = set(required_tags) <= found_tags
    success = all_found and (not require_balanced or not mismatched_tags)

    reason = _get_overall_reason(required_tags, found_tags, mismatched_tags, require_balanced)
    tag_metrics["overall"] = MetricResult(score=total_score, is_score_valid=success, reason=reason)

    return EvaluateResult(score=total_score, reason=reason, metrics=tag_metrics, is_score_valid=success)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _get_tag_reason(tag: str, opening_count: int, closing_count: int, require_balanced: bool) -> str:
|
|
118
|
+
"""Generate a descriptive reason for a tag's evaluation."""
|
|
119
|
+
if opening_count == 0 and closing_count == 0:
|
|
120
|
+
return f"Tag '{tag}' not found in response"
|
|
121
|
+
elif opening_count > 0 and closing_count == 0:
|
|
122
|
+
return f"Found {opening_count} opening <{tag}> tag(s) but no closing"
|
|
123
|
+
elif opening_count == 0 and closing_count > 0:
|
|
124
|
+
return f"Found {closing_count} closing </{tag}> tag(s) but no opening"
|
|
125
|
+
elif opening_count == closing_count:
|
|
126
|
+
return f"Found {opening_count} balanced '{tag}' tag(s)"
|
|
127
|
+
else:
|
|
128
|
+
if require_balanced:
|
|
129
|
+
return f"Unbalanced tags: {opening_count} opening vs " f"{closing_count} closing '{tag}' tags"
|
|
130
|
+
else:
|
|
131
|
+
return f"Found '{tag}' tags (unbalanced: {opening_count} opening, " f"{closing_count} closing)"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _get_overall_reason(
|
|
135
|
+
required_tags: List[str],
|
|
136
|
+
found_tags: Set[str],
|
|
137
|
+
mismatched_tags: Set[str],
|
|
138
|
+
require_balanced: bool,
|
|
139
|
+
) -> str:
|
|
140
|
+
"""Generate an overall reason for the evaluation."""
|
|
141
|
+
if not found_tags:
|
|
142
|
+
return "No required tags found in response"
|
|
143
|
+
|
|
144
|
+
missing_tags = set(required_tags) - found_tags
|
|
145
|
+
|
|
146
|
+
if not missing_tags and not mismatched_tags:
|
|
147
|
+
return f"All {len(required_tags)} required tags found and balanced"
|
|
148
|
+
|
|
149
|
+
reason_parts = []
|
|
150
|
+
|
|
151
|
+
if found_tags:
|
|
152
|
+
reason_parts.append(f"Found {len(found_tags)}/{len(required_tags)} required tags")
|
|
153
|
+
|
|
154
|
+
if missing_tags:
|
|
155
|
+
tags_str = ", ".join([f"'{tag}'" for tag in missing_tags])
|
|
156
|
+
reason_parts.append(f"Missing tags: {tags_str}")
|
|
157
|
+
|
|
158
|
+
if require_balanced and mismatched_tags:
|
|
159
|
+
tags_str = ", ".join([f"'{tag}'" for tag in mismatched_tags])
|
|
160
|
+
reason_parts.append(f"Unbalanced tags: {tags_str}")
|
|
161
|
+
|
|
162
|
+
return ". ".join(reason_parts)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Union
|
|
2
|
+
|
|
3
|
+
from eval_protocol.agent.models import StepData # Internal StepData model
|
|
4
|
+
|
|
5
|
+
# Assuming models are structured as planned
|
|
6
|
+
from eval_protocol.models import EvaluateResult # Extended EvaluateResult
|
|
7
|
+
|
|
8
|
+
# Placeholder for actual Message type if needed for type hinting complex observation_data
|
|
9
|
+
# from eval_protocol.models import Message
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RLDataAligner:
    """
    Bridges the output of a user reward function and the internal per-step data.

    A user's reward function returns an EvaluateResult that may carry per-step
    ``step_outputs`` (each with a ``step_index`` and a ``base_reward``). This
    class copies those per-step rewards onto the matching StepData objects so
    that a later RL advantage computation (e.g. GiGPO) can consume them.
    """

    def align_data_for_rl_processing(
        self,
        current_eval_result: EvaluateResult,
        current_step_data_list: List[StepData],
        rollout_id: str,  # For logging or if needed
    ) -> List[StepData]:
        """
        Copy per-step base rewards from an EvaluateResult onto a rollout's StepData list.

        Matching rule: a StepData is updated when
        ``step_info.get("assistant_turn_index")`` is not None and equals the
        ``step_index`` of one of ``current_eval_result.step_outputs``. The
        rollout worker (RLRolloutWorker) is expected to store that key; this is
        an assumption this method relies on and does not verify. StepData
        entries without a match are left untouched.

        The rollout-level ``current_eval_result.score`` is NOT handled here; the
        caller is responsible for using it (e.g. for GiGPO A_E).

        Args:
            current_eval_result: The EvaluateResult from the user's reward function for this rollout.
            current_step_data_list: The StepData objects for this rollout; updated in place.
            rollout_id: Identifier for the current rollout (currently unused).

        Returns:
            The same list object that was passed in, with matching
            ``base_reward`` fields populated.
        """
        step_outputs = current_eval_result.step_outputs
        if not step_outputs:
            # Nothing user-provided to align; base rewards keep their existing values.
            return current_step_data_list

        # If several outputs share a step_index, the last one wins.
        rewards_by_turn: Dict[Union[int, str], float] = {
            output.step_index: output.base_reward for output in step_outputs
        }

        for step in current_step_data_list:
            turn_index = step.step_info.get("assistant_turn_index")
            # Skip steps the worker did not tag, and steps the user gave no reward for.
            if turn_index is None or turn_index not in rewards_by_turn:
                continue
            step.base_reward = rewards_by_turn[turn_index]

        return current_step_data_list

    # TODO (Future): Consider a batch version if performance becomes an issue,
    # but the core logic per rollout remains the same.
    # def align_batch_data_for_rl_processing(...)
|