eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130):
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,342 @@
1
+ """
2
+ Reward functions for evaluating repetition in model responses.
3
+
4
+ This module provides reward functions that penalize repetitive text in model responses,
5
+ encouraging more diverse and information-rich outputs.
6
+ """
7
+
8
+ import re
9
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
10
+
11
+ from ..models import EvaluateResult, Message, MetricResult
12
+ from ..typed_interface import reward_function
13
+
14
+
15
def get_ngrams(text: str, n: int, language: str = "en") -> Tuple[List[Tuple[str, ...]], int]:
    """
    Extract n-grams from text based on language.

    Tokenization depends on ``language``:

    * ``"zh"``: segmented with ``jieba`` when it can be imported, otherwise
      split into individual characters.
    * anything else (including ``"en"``): lower-cased and split on whitespace.

    Args:
        text: The text to extract n-grams from
        n: Size of the n-grams. Values below 1 produce no n-grams.
        language: Language of the text (affects tokenization)

    Returns:
        Tuple of (list of n-grams, total n-gram count)
    """
    # A non-positive size has no meaningful n-grams; without this guard the
    # sliding window below would emit empty or wrapped-around tuples.
    if n < 1:
        return [], 0

    if language == "zh":
        try:
            import jieba

            words = list(jieba.cut(text))
        except ImportError:
            # jieba is optional: fall back to character-level tokens.
            words = list(text)
    else:
        words = text.lower().split()

    # Slide a window of width n over the tokens; the range is empty when the
    # text has fewer than n tokens.
    ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
    return ngrams, len(ngrams)
44
+
45
+
46
@reward_function
def repetition_penalty_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    ngram_size: int = 3,
    max_penalty: float = 0.5,
    language: str = "en",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that penalizes repetitive text in model responses.
    The model's response is assumed to be the last message in the `messages` list.

    The repetition ratio is ``1 - unique_ngrams / total_ngrams``. The score is
    ``max(0, 1 - repetition_ratio * max_penalty)``, and the per-metric
    ``is_score_valid`` flags are set when the repetition ratio is below 0.2.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this reward.
        ngram_size: Size of n-grams to check for repetition.
        max_penalty: Maximum penalty to apply for repetitive text.
        language: Language of the text (affects tokenization).
        **kwargs: Additional arguments.

    Returns:
        EvaluateResult with score penalizing repetition. Empty responses and
        responses too short to form a single n-gram score 1.0; a missing,
        non-assistant or unrecognised last message scores 0.0.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"repetition": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    # Accept both Message objects and plain dicts for the last message.
    if isinstance(response, Message):
        role = response.role
        content = response.content
    elif isinstance(response, dict):
        role = response.get("role")
        content = response.get("content")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "repetition": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if role != "assistant":
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found",
            metrics={
                "repetition": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant",
                )
            },
        )

    # `content` may be present but null (e.g. {"role": "assistant", "content": None});
    # normalise to an empty string so the .strip() below cannot fail on None.
    text = content or ""

    if not text.strip():
        return EvaluateResult(
            score=1.0,
            reason="Empty response, no repetition to penalize",
            metrics={
                "repetition": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="Empty response",
                ),
                "unique_ngram_ratio": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="Empty response",
                ),
                "repetition_penalty": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason="No penalty applied to empty response",
                ),
            },
        )

    ngrams, total = get_ngrams(text, ngram_size, language)

    if total < 1:
        # Fewer tokens than ngram_size: nothing can repeat.
        return EvaluateResult(
            score=1.0,
            reason=f"Text too short for {ngram_size}-gram analysis",
            metrics={
                "repetition": MetricResult(
                    score=1.0,
                    is_score_valid=True,
                    reason=f"Text too short for {ngram_size}-gram analysis",
                )
            },
        )

    unique_ngrams = len(set(ngrams))
    repetition_ratio = 1.0 - (unique_ngrams / total)
    penalty = repetition_ratio * max_penalty
    score = max(0.0, 1.0 - penalty)
    success = repetition_ratio < 0.2

    reason = f"Repetition ratio: {repetition_ratio:.2f}, Unique {ngram_size}-grams: {unique_ngrams}/{total}"
    metrics = {
        "repetition": MetricResult(score=score, is_score_valid=success, reason=reason),
        "unique_ngram_ratio": MetricResult(
            score=1.0 - repetition_ratio,
            is_score_valid=success,
            reason=f"Unique {ngram_size}-gram ratio: {1.0 - repetition_ratio:.2f}",
        ),
        "repetition_penalty": MetricResult(
            score=1.0 - penalty,
            is_score_valid=success,
            reason=f"Applied repetition penalty: {penalty:.2f}",
        ),
    }

    return EvaluateResult(score=score, reason=reason, metrics=metrics, is_score_valid=score > 0.0)
185
+
186
+
187
@reward_function
def diversity_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[
        Union[List[Message], List[Dict[str, Any]]]
    ] = None,  # Not used by this function but part of standard signature
    ngram_sizes: List[int] = [1, 2, 3],  # never mutated inside this function
    weights: Optional[List[float]] = None,
    language: str = "en",
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that measures lexical diversity in model responses.
    The model's response is assumed to be the last message in the `messages` list.

    For each n-gram size the ratio ``unique_ngrams / total_ngrams`` is computed
    (1.0 when the text is too short to form any n-gram of that size). The final
    score is the weighted sum of those ratios.

    Args:
        messages: List of conversation messages, where `messages[-1]` is the model's response.
        ground_truth: Optional. Expected assistant response trajectory. Not directly used by this reward.
        ngram_sizes: List of n-gram sizes to evaluate.
        weights: Optional list of weights for each n-gram size (normalized if provided).
            Defaults to ``[0.2, 0.3, 0.5]`` truncated to ``len(ngram_sizes)``. Extra weights
            are dropped; missing weights share the remainder up to 1.0. The caller's list
            is not modified.
        language: Language of the text (affects tokenization).
        **kwargs: Additional arguments.

    Returns:
        EvaluateResult with score based on lexical diversity. A missing, non-assistant,
        unrecognised or empty last message scores 0.0.
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"diversity": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    # Accept both Message objects and plain dicts for the last message.
    if isinstance(response, Message):
        role = response.role
        content = response.content
    elif isinstance(response, dict):
        role = response.get("role")
        content = response.get("content")
    else:
        return EvaluateResult(
            score=0.0,
            reason="Last message is of unexpected type.",
            metrics={
                "diversity": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid message type in messages.",
                )
            },
        )

    if role != "assistant":
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found",
            metrics={
                "diversity": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant",
                )
            },
        )

    # `content` may be present but null; normalise so .strip() cannot fail on None.
    text = content or ""

    if not text.strip():
        return EvaluateResult(
            score=0.0,
            reason="Empty response",
            metrics={
                "diversity": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Empty response",
                )
            },
        )

    if weights is None:
        weights = [0.2, 0.3, 0.5][: len(ngram_sizes)]
    else:
        # Work on a copy: the padding below must not modify the caller's list.
        weights = list(weights)

    if len(weights) > len(ngram_sizes):
        weights = weights[: len(ngram_sizes)]
    elif len(weights) < len(ngram_sizes):
        # Spread whatever is left of 1.0 evenly over the sizes without a weight.
        missing_count = len(ngram_sizes) - len(weights)
        missing_weight = (1.0 - sum(weights)) / missing_count
        weights.extend([missing_weight] * missing_count)

    total_weight = sum(weights)
    if total_weight != 1.0 and total_weight > 0:  # Avoid division by zero if total_weight is 0
        weights = [w / total_weight for w in weights]
    elif total_weight == 0 and len(weights) > 0:  # If all weights are zero, distribute equally
        weights = [1.0 / len(weights)] * len(weights)

    diversity_scores = {}
    ratios = {}

    for size, weight in zip(ngram_sizes, weights):
        ngrams, total = get_ngrams(text, size, language)

        # Text too short for this size counts as fully diverse (ratio 1.0).
        ratio = len(set(ngrams)) / total if total >= 1 else 1.0

        size_key = f"ngram_{size}"
        ratios[size_key] = ratio
        # Always apply the weight, including for the too-short case, so the
        # weighted sum stays on the same scale as the individual ratios.
        diversity_scores[size_key] = ratio * weight

    final_score = sum(diversity_scores.values())
    success = final_score > 0.6

    size_metrics: Dict[str, MetricResult] = {
        size_key: MetricResult(
            score=ratio_val,
            is_score_valid=ratio_val > 0.7,
            reason=f"Diversity ratio for {size_key}: {ratio_val:.2f}",
        )
        for size_key, ratio_val in ratios.items()
    }

    metrics: Dict[str, MetricResult] = {
        "diversity": MetricResult(
            score=final_score,
            is_score_valid=success,
            reason=f"Overall weighted diversity score: {final_score:.2f}",
        ),
        **size_metrics,
    }

    return EvaluateResult(
        score=final_score,
        reason=f"Lexical diversity score: {final_score:.2f}",
        metrics=metrics,
    )
@@ -0,0 +1,162 @@
1
+ """
2
+ Reward functions for counting tags in text responses.
3
+
4
+ This module provides reward functions that evaluate if responses contain
5
+ specified XML/HTML-like tags in correct quantities.
6
+ """
7
+
8
+ import re
9
+ from typing import Any, Dict, List, Set, Union
10
+
11
+ from ..models import EvaluateResult, Message, MetricResult
12
+ from ..typed_interface import reward_function
13
+
14
+
15
@reward_function  # type: ignore[arg-type]
def tag_count_reward(
    messages: List[Message],
    *,  # Make subsequent parameters keyword-only
    required_tags: List[str],
    score_per_tag: float = 0.25,
    require_balanced: bool = True,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Reward function that checks for presence of specific tags in response.

    For each tag found in required_tags, adds score_per_tag to the score
    (capped at 1.0). When require_balanced is True, a tag only counts as found
    if it has at least one opening and one closing tag in equal numbers, and
    each tag with unequal counts subtracts score_per_tag (floored at 0.0).

    Tag names are matched literally: regex metacharacters in a name are
    escaped, and an opening tag must be followed by whitespace, ``/`` or ``>``
    so that a longer tag name sharing the same prefix does not count.

    Args:
        messages: List of conversation messages
        required_tags: List of tag names to check for (without < > brackets)
        score_per_tag: Score to award per correctly found tag (default: 0.25)
        require_balanced: If True, requires equal opening and closing tags
        **kwargs: Additional arguments

    Returns:
        EvaluateResult with score based on tags found
    """
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            metrics={"tag_count": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided")},
        )

    response = messages[-1]

    if response.role != "assistant" or not response.content:
        return EvaluateResult(
            score=0.0,
            reason="No assistant response found or response has no content",
            metrics={
                "tag_count": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Message not from assistant or has no content",
                )
            },
        )
    text: str = response.content

    tag_metrics = {}
    found_tags: Set[str] = set()
    mismatched_tags: Set[str] = set()

    for tag in required_tags:
        escaped_tag = re.escape(tag)
        # The lookahead ends the tag name at whitespace, "/" or ">", so the
        # pattern for one tag cannot match a longer name with the same prefix.
        opening_pattern = rf"<{escaped_tag}(?=[\s/>])[^>]*>"
        closing_pattern = rf"</{escaped_tag}>"

        opening_count = len(re.findall(opening_pattern, text))
        closing_count = len(re.findall(closing_pattern, text))

        is_balanced = opening_count == closing_count

        if require_balanced:
            is_found = opening_count > 0 and closing_count > 0 and is_balanced
        else:
            is_found = opening_count > 0 or closing_count > 0

        if is_found:
            found_tags.add(tag)

        # Unequal counts imply at least one of the two counts is non-zero.
        if require_balanced and not is_balanced:
            mismatched_tags.add(tag)

        tag_metrics[f"tag_{tag}"] = MetricResult(
            score=1.0 if is_found else 0.0,
            is_score_valid=is_found,
            reason=_get_tag_reason(tag, opening_count, closing_count, require_balanced),
        )

    total_score = min(1.0, len(found_tags) * score_per_tag)

    if require_balanced and mismatched_tags:
        penalty = len(mismatched_tags) * score_per_tag
        total_score = max(0.0, total_score - penalty)

    success = len(found_tags) == len(required_tags) and (not require_balanced or not mismatched_tags)

    reason = _get_overall_reason(required_tags, found_tags, mismatched_tags, require_balanced)
    tag_metrics["overall"] = MetricResult(score=total_score, is_score_valid=success, reason=reason)

    return EvaluateResult(score=total_score, reason=reason, metrics=tag_metrics, is_score_valid=success)
115
+
116
+
117
+ def _get_tag_reason(tag: str, opening_count: int, closing_count: int, require_balanced: bool) -> str:
118
+ """Generate a descriptive reason for a tag's evaluation."""
119
+ if opening_count == 0 and closing_count == 0:
120
+ return f"Tag '{tag}' not found in response"
121
+ elif opening_count > 0 and closing_count == 0:
122
+ return f"Found {opening_count} opening <{tag}> tag(s) but no closing"
123
+ elif opening_count == 0 and closing_count > 0:
124
+ return f"Found {closing_count} closing </{tag}> tag(s) but no opening"
125
+ elif opening_count == closing_count:
126
+ return f"Found {opening_count} balanced '{tag}' tag(s)"
127
+ else:
128
+ if require_balanced:
129
+ return f"Unbalanced tags: {opening_count} opening vs " f"{closing_count} closing '{tag}' tags"
130
+ else:
131
+ return f"Found '{tag}' tags (unbalanced: {opening_count} opening, " f"{closing_count} closing)"
132
+
133
+
134
+ def _get_overall_reason(
135
+ required_tags: List[str],
136
+ found_tags: Set[str],
137
+ mismatched_tags: Set[str],
138
+ require_balanced: bool,
139
+ ) -> str:
140
+ """Generate an overall reason for the evaluation."""
141
+ if not found_tags:
142
+ return "No required tags found in response"
143
+
144
+ missing_tags = set(required_tags) - found_tags
145
+
146
+ if not missing_tags and not mismatched_tags:
147
+ return f"All {len(required_tags)} required tags found and balanced"
148
+
149
+ reason_parts = []
150
+
151
+ if found_tags:
152
+ reason_parts.append(f"Found {len(found_tags)}/{len(required_tags)} required tags")
153
+
154
+ if missing_tags:
155
+ tags_str = ", ".join([f"'{tag}'" for tag in missing_tags])
156
+ reason_parts.append(f"Missing tags: {tags_str}")
157
+
158
+ if require_balanced and mismatched_tags:
159
+ tags_str = ", ".join([f"'{tag}'" for tag in mismatched_tags])
160
+ reason_parts.append(f"Unbalanced tags: {tags_str}")
161
+
162
+ return ". ".join(reason_parts)
@@ -0,0 +1,82 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ from eval_protocol.agent.models import StepData # Internal StepData model
4
+
5
+ # Assuming models are structured as planned
6
+ from eval_protocol.models import EvaluateResult # Extended EvaluateResult
7
+
8
+ # Placeholder for actual Message type if needed for type hinting complex observation_data
9
+ # from eval_protocol.models import Message
10
+
11
+
12
class RLDataAligner:
    """
    Aligns the output of a user reward function (an EvaluateResult carrying
    per-step base rewards) with the system's internal StepData records, so the
    data is ready for GiGPO (or other RL algorithm) advantage calculations.
    """

    def align_data_for_rl_processing(
        self,
        current_eval_result: EvaluateResult,
        current_step_data_list: List[StepData],
        rollout_id: str,  # For logging or if needed
    ) -> List[StepData]:
        """
        Copy per-step base rewards from an EvaluateResult onto the StepData
        objects of a single rollout.

        Each entry of `current_eval_result.step_outputs` maps its `step_index`
        to a `base_reward`. A StepData receives that reward when its
        `step_info["assistant_turn_index"]` equals a known `step_index`.
        StepData without that key, or with an index the user gave no reward
        for, is left untouched. The rollout-level score is not handled here;
        the caller is expected to deal with it.

        Args:
            current_eval_result: The EvaluateResult from the user's reward function for this rollout.
            current_step_data_list: The StepData objects collected for this rollout; updated in place.
            rollout_id: Identifier for the current rollout (currently unused).

        Returns:
            The same `current_step_data_list`, with `base_reward` populated where a match exists.
        """
        step_outputs = current_eval_result.step_outputs
        if not step_outputs:
            # No per-step rewards supplied: nothing to align.
            return current_step_data_list

        # step_index -> base_reward; a later duplicate index overrides an earlier one.
        reward_by_step_index: Dict[Union[int, str], float] = {}
        for step_output in step_outputs:
            reward_by_step_index[step_output.step_index] = step_output.base_reward

        for step_data in current_step_data_list:
            turn_index = step_data.step_info.get("assistant_turn_index")
            if turn_index is None:
                # The StepData carries no mapping key; leave it as is.
                continue
            if turn_index in reward_by_step_index:
                step_data.base_reward = reward_by_step_index[turn_index]

        return current_step_data_list

    # TODO (Future): Consider a batch version if performance becomes an issue,
    # but the core logic per rollout remains the same.
    # def align_batch_data_for_rl_processing(...)