eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,217 @@
1
+ """
2
+ Utilities for running batch evaluation on transformed N-variant data.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ from typing import Any, Dict, List, Optional, Union
8
+
9
+ from ..models import EvaluateResult
10
+ from ..utils.module_loader import load_function as load_reward_function
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def run_batch_evaluation(
    batch_jsonl_path: str,
    reward_function_path: str,
    output_path: str,
    reward_function_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Run batch evaluation on transformed N-variant data.

    Every non-blank line of the input file is parsed as one JSON object that
    must provide ``request_id`` and a non-empty list ``rollouts_messages``.
    The batch reward function is called once per line with all rollouts, and
    one result entry per rollout is collected and written to ``output_path``
    as JSONL.

    Args:
        batch_jsonl_path: Path to the batch evaluation JSONL file
        reward_function_path: Path to the batch reward function (e.g., "module.function")
        output_path: Path to write the batch evaluation results
        reward_function_kwargs: Additional kwargs for the reward function. Fields found
            on a data line (other than the reserved ones) override these.

    Returns:
        List of batch evaluation results (the same entries written to ``output_path``)

    Raises:
        ValueError: If the reward function declares a mode other than "batch"
        FileNotFoundError: If ``batch_jsonl_path`` does not exist
    """
    if reward_function_kwargs is None:
        reward_function_kwargs = {}

    # Load the batch reward function
    reward_function = load_reward_function(reward_function_path)

    # Verify it's a batch mode function
    if not hasattr(reward_function, "_reward_function_mode"):
        logger.warning(f"Reward function {reward_function_path} doesn't have mode metadata. Assuming batch mode.")
    elif getattr(reward_function, "_reward_function_mode") != "batch":
        raise ValueError(
            f"Reward function {reward_function_path} is not configured for batch mode. Expected mode='batch'."
        )

    # Reserved input fields that are never forwarded as kwargs or copied as metadata.
    excluded_fields = {
        "request_id",
        "rollouts_messages",
        "num_variants",
        "response_ids",
    }

    results: List[Dict[str, Any]] = []

    try:
        with open(batch_jsonl_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records; skip them without logging an error.
                    continue

                try:
                    data = json.loads(stripped)

                    # Extract required fields
                    request_id = data.get("request_id")
                    rollouts_messages = data.get("rollouts_messages")

                    # Only None / "" count as missing so that an integer id of 0 stays valid.
                    if request_id is None or request_id == "":
                        logger.error(f"Line {line_num}: Missing request_id")
                        continue

                    if not rollouts_messages or not isinstance(rollouts_messages, list):
                        logger.error(f"Line {line_num}: Missing or invalid rollouts_messages")
                        continue

                    # Remaining fields are both forwarded as kwargs and echoed into each result.
                    metadata = {k: v for k, v in data.items() if k not in excluded_fields}
                    batch_kwargs = {**reward_function_kwargs, **metadata}

                    # Resolve response ids up front so the success and error paths agree.
                    # A missing, null, or wrongly sized list falls back to positional indices;
                    # otherwise zip() below would silently drop rollouts.
                    response_ids = data.get("response_ids")
                    if not isinstance(response_ids, list) or len(response_ids) != len(rollouts_messages):
                        if response_ids is not None:
                            logger.warning(
                                f"Line {line_num}: response_ids does not match rollouts_messages; "
                                "using rollout indices instead"
                            )
                        response_ids = list(range(len(rollouts_messages)))

                    # Call the batch reward function
                    try:
                        batch_results = reward_function(rollouts_messages=rollouts_messages, **batch_kwargs)

                        # Validate results
                        if not isinstance(batch_results, list):
                            raise ValueError(f"Batch reward function must return a list, got {type(batch_results)}")

                        if len(batch_results) != len(rollouts_messages):
                            raise ValueError(
                                f"Batch reward function returned {len(batch_results)} results "
                                f"but expected {len(rollouts_messages)} (one per rollout)"
                            )

                        # Create result entries
                        for i, (response_id, eval_result) in enumerate(zip(response_ids, batch_results)):
                            if not isinstance(eval_result, EvaluateResult):
                                logger.error(f"Result {i} is not an EvaluateResult: {type(eval_result)}")
                                continue

                            result_entry = {
                                "request_id": request_id,
                                "response_id": response_id,
                                "rollout_index": i,
                                "evaluation_score": eval_result.score,
                                "evaluation_reason": eval_result.reason,
                                "is_score_valid": eval_result.is_score_valid,
                                "evaluation_metrics": (
                                    {k: v.model_dump() for k, v in eval_result.metrics.items()}
                                    if eval_result.metrics
                                    else {}
                                ),
                                # Include original metadata
                                **metadata,
                            }

                            results.append(result_entry)

                    except Exception as e:
                        logger.error(f"Error calling batch reward function for request {request_id}: {e}")
                        # Create error entries for each expected result
                        for i, response_id in enumerate(response_ids):
                            error_entry = {
                                "request_id": request_id,
                                "response_id": response_id,
                                "rollout_index": i,
                                "error": f"Batch evaluation failed: {str(e)}",
                                "evaluation_score": 0.0,
                                "is_score_valid": False,
                            }
                            results.append(error_entry)

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid JSON on line {line_num}: {e}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing line {line_num}: {e}")
                    continue

    except FileNotFoundError:
        raise FileNotFoundError(f"Batch JSONL file not found: {batch_jsonl_path}")

    # Write results
    with open(output_path, "w", encoding="utf-8") as f:
        for result in results:
            f.write(json.dumps(result) + "\n")

    logger.info(f"Batch evaluation completed. {len(results)} results written to {output_path}")
    return results
153
+
154
+
155
def create_sample_batch_reward_function():
    """
    Create a sample batch reward function for testing.

    The returned function compares all variants of one request and scores them
    relative to each other (longer final assistant responses score higher, with
    a small bonus for earlier variants). It is meant for demonstration only.

    Returns:
        The decorated ``sample_batch_reward`` callable.
    """
    from ..models import EvaluateResult, Message, MetricResult
    from ..typed_interface import reward_function

    @reward_function(mode="batch")
    def sample_batch_reward(
        rollouts_messages: List[List[Message]],
        ground_truth_for_eval: Optional[str] = None,
        **kwargs: Any,
    ) -> List[EvaluateResult]:
        """
        Sample batch reward function that scores variants relative to each other.

        This function demonstrates how to process multiple rollouts (variants) together
        and return comparative scores. One EvaluateResult is returned per rollout,
        in input order.
        """
        # Extract the last assistant response from each rollout ("" when there is none)
        assistant_responses = []
        for rollout in rollouts_messages:
            assistant_msg = None
            for msg in reversed(rollout):
                if msg.role == "assistant":
                    assistant_msg = msg.content
                    break
            assistant_responses.append(assistant_msg or "")

        # Simple scoring: longer responses get higher scores (just for demonstration)
        response_lengths = [len(response) for response in assistant_responses]
        # Fall back to 1 when there are no rollouts OR every response is empty,
        # so the normalization below never divides by zero.
        max_length = max(response_lengths, default=0) or 1
        num_variants = len(assistant_responses)

        results = []
        for i, length in enumerate(response_lengths):
            # Normalize score based on length (0.1 to 1.0)
            base_score = 0.1 + 0.9 * (length / max_length)

            # Add some variation based on position (earlier variants get slight bonus)
            position_bonus = 0.1 * (1 - i / num_variants)
            final_score = min(1.0, base_score + position_bonus)

            results.append(
                EvaluateResult(
                    score=final_score,
                    reason=f"Variant {i}: Length={length}, Base={base_score:.2f}, Position bonus={position_bonus:.2f}",
                    is_score_valid=True,
                    metrics={
                        "response_length": MetricResult(
                            score=min(1.0, length / 100.0),  # Normalize length score
                            reason=f"Response contains {length} characters",
                            is_score_valid=True,
                        )
                    },
                )
            )

        return results

    return sample_batch_reward
@@ -0,0 +1,205 @@
1
+ """
2
+ Utilities for transforming N-variant generation results into batch evaluation format.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ from collections import defaultdict
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def transform_n_variant_jsonl_to_batch_format(
    input_file_path: str,
    output_file_path: Optional[str] = None,
    request_id_field: str = "request_id",
    response_id_field: str = "response_id",
    messages_field: str = "full_conversation_history",
    fallback_messages_fields: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Transform N-variant generation JSONL output into batch evaluation format.

    This function groups variants by request_id and creates rollouts_messages
    containing all variant conversations for each original request.

    Lines that are blank are ignored. Lines that carry an "error" key, are not
    valid JSON objects, or lack a required field are logged and skipped; they
    do not abort the transformation.

    Args:
        input_file_path: Path to the N-variant generation JSONL file
        output_file_path: Optional path to write the transformed data (if None, returns data only)
        request_id_field: Field name containing the original request ID (default: "request_id")
        response_id_field: Field name containing the variant response ID (default: "response_id")
        messages_field: Primary field containing conversation messages (default: "full_conversation_history")
        fallback_messages_fields: Fallback fields to construct messages if primary field is missing

    Returns:
        List of batch evaluation entries, each containing rollouts_messages and other metadata

    Raises:
        FileNotFoundError: If input file doesn't exist
    """
    if fallback_messages_fields is None:
        fallback_messages_fields = ["user_query", "system_prompt", "assistant_response"]

    # Group variants by request_id
    grouped_variants: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)

    try:
        with open(input_file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records; skip them without logging an error.
                    continue

                try:
                    data = json.loads(stripped)

                    if not isinstance(data, dict):
                        raise ValueError(f"Line {line_num}: Expected a JSON object, got {type(data).__name__}")

                    # Skip lines with errors
                    if "error" in data:
                        logger.warning(f"Skipping line {line_num} due to error: {data.get('error')}")
                        continue

                    # Validate required fields
                    if request_id_field not in data:
                        raise ValueError(f"Line {line_num}: Missing required field '{request_id_field}'")

                    if response_id_field not in data:
                        raise ValueError(f"Line {line_num}: Missing required field '{response_id_field}'")

                    request_id = data[request_id_field]
                    response_id = data[response_id_field]

                    # Extract messages
                    messages = _extract_messages_from_data(data, messages_field, fallback_messages_fields, line_num)

                    # Store variant data
                    grouped_variants[request_id].append(
                        {
                            "response_id": response_id,
                            "messages": messages,
                            "original_data": data,  # Keep original data for metadata extraction
                        }
                    )

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid JSON on line {line_num}: {e}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing line {line_num}: {e}")
                    continue

    except FileNotFoundError:
        raise FileNotFoundError(f"Input file not found: {input_file_path}")

    # Variant-specific fields that must not be copied into the shared batch entry.
    excluded_fields = {
        "id",
        request_id_field,
        response_id_field,
        messages_field,
        "full_conversation_history",
        "assistant_response",
        "evaluation_score",
        "evaluation_reason",
        "evaluation_metrics",
        "executed_tool_calls",
        "discovered_tools",
        "final_mcp_state_captured",
    }

    # Transform grouped variants into batch format
    batch_entries = []

    for request_id, variants in grouped_variants.items():
        # Sort variants by response_id to ensure consistent ordering
        try:
            variants.sort(key=lambda variant: variant["response_id"])
        except TypeError:
            # Ids of incomparable types (e.g. int mixed with str or None) cannot be
            # ordered directly; fall back to a deterministic type-then-text ordering.
            variants.sort(
                key=lambda variant: (type(variant["response_id"]).__name__, str(variant["response_id"]))
            )

        # Extract common metadata from the first variant (assuming it's consistent across variants)
        first_variant = variants[0]["original_data"]

        # Create batch entry
        batch_entry = {
            "request_id": request_id,
            # rollouts_messages is a list of message lists, one per variant
            "rollouts_messages": [variant["messages"] for variant in variants],
            "num_variants": len(variants),
            "response_ids": [variant["response_id"] for variant in variants],
        }

        # Add common fields as kwargs (excluding variant-specific fields)
        for key, value in first_variant.items():
            if key not in excluded_fields:
                batch_entry[key] = value

        batch_entries.append(batch_entry)

    # Write output file if specified
    if output_file_path:
        with open(output_file_path, "w", encoding="utf-8") as f:
            for entry in batch_entries:
                f.write(json.dumps(entry) + "\n")
        logger.info(f"Transformed {len(batch_entries)} batch entries written to {output_file_path}")

    return batch_entries
142
+
143
+
144
+ def _extract_messages_from_data(
145
+ data: Dict[str, Any], primary_field: str, fallback_fields: List[str], line_num: int
146
+ ) -> List[Dict[str, Any]]:
147
+ """
148
+ Extract conversation messages from variant data.
149
+
150
+ Args:
151
+ data: Variant data dictionary
152
+ primary_field: Primary field containing messages
153
+ fallback_fields: Fallback fields to construct messages
154
+ line_num: Line number for error reporting
155
+
156
+ Returns:
157
+ List of message dictionaries
158
+ """
159
+ # Try primary field first
160
+ if primary_field in data and data[primary_field]:
161
+ messages = data[primary_field]
162
+ if isinstance(messages, list):
163
+ return messages
164
+ else:
165
+ logger.warning(f"Line {line_num}: {primary_field} is not a list, trying fallback")
166
+
167
+ # Try to construct messages from fallback fields
168
+ messages = []
169
+
170
+ # Add system message if available
171
+ if "system_prompt" in data and data["system_prompt"]:
172
+ messages.append({"role": "system", "content": data["system_prompt"]})
173
+
174
+ # Add user message if available
175
+ if "user_query" in data and data["user_query"]:
176
+ messages.append({"role": "user", "content": data["user_query"]})
177
+
178
+ # Add assistant message if available
179
+ if "assistant_response" in data and data["assistant_response"]:
180
+ messages.append({"role": "assistant", "content": data["assistant_response"]})
181
+
182
+ if not messages:
183
+ raise ValueError(f"Line {line_num}: Could not extract messages from any available fields")
184
+
185
+ return messages
186
+
187
+
188
def create_batch_evaluation_dataset(n_variant_jsonl_path: str, output_jsonl_path: str, **transform_kwargs) -> str:
    """
    Build a batch evaluation dataset file from N-variant generation output.

    Thin wrapper around transform_n_variant_jsonl_to_batch_format that always
    writes the transformed entries to disk and hands back the output location.

    Args:
        n_variant_jsonl_path: Path to N-variant generation JSONL file
        output_jsonl_path: Path for the batch evaluation JSONL file
        **transform_kwargs: Additional arguments for transform_n_variant_jsonl_to_batch_format

    Returns:
        Path to the created batch evaluation dataset
    """
    path_arguments = {
        "input_file_path": n_variant_jsonl_path,
        "output_file_path": output_jsonl_path,
    }
    # The transformed entries themselves are not needed here; only the file matters.
    transform_n_variant_jsonl_to_batch_format(**path_arguments, **transform_kwargs)
    return output_jsonl_path
@@ -0,0 +1,112 @@
1
+ import json
2
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
3
+
4
# Static type checkers always see the real HuggingFace Dataset type.
if TYPE_CHECKING:
    from datasets import Dataset

# At runtime the 'datasets' dependency is optional: record whether it is
# importable and, when it is not, bind a stand-in so the name still exists.
try:
    from datasets import Dataset
except ImportError:
    HAS_DATASETS_LIB = False
    if not TYPE_CHECKING:

        class Dataset:
            """Placeholder for HuggingFace Dataset for when the library is not installed."""

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Accept and ignore any arguments; the placeholder holds no data.
                pass

else:
    HAS_DATASETS_LIB = True
20
+
21
+
22
def load_jsonl_to_hf_dataset(
    dataset_path: str,
    transform_fn: Optional[Callable[[Dict[str, Any]], Optional[Dict[str, Any]]]] = None,
    prompt_column: str = "prompt",
    required_columns: Optional[List[str]] = None,
    dataset_filter_fn: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> Optional["Dataset"]:
    """
    Loads a JSONL file into a HuggingFace Dataset, optionally applying a
    transformation to each sample and ensuring required columns are present.

    Blank lines are ignored. Lines that fail to parse, or for which
    transform_fn / dataset_filter_fn raises, are reported on stdout and skipped.

    Args:
        dataset_path: Path to the JSONL file.
        transform_fn: An optional function to apply to each raw dictionary (sample)
                      from the JSONL file. It should take a dict and return a dict
                      (or None to skip the sample).
        prompt_column: The name of the column expected to contain the prompt for TRL.
                       Defaults to "prompt".
        required_columns: A list of column names that must be present in the
                          final dataset (after transformation).
        dataset_filter_fn: An optional function to filter samples from the dataset
                           after transformation. It should take a dict and return
                           True to keep the sample, False to discard.

    Returns:
        A HuggingFace Dataset object (empty when no sample survived processing),
        or None if the datasets library is not installed or if an error occurs
        (missing file, missing required column, or any unexpected failure).
    """
    if not HAS_DATASETS_LIB:
        print("The 'datasets' library is not installed. Please install it with 'pip install datasets'.")
        return None

    processed_samples: List[Dict[str, Any]] = []
    try:
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line_number, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not samples; skip them without a decode warning.
                    continue

                try:
                    raw_sample = json.loads(stripped)

                    transformed_sample: Optional[Dict[str, Any]]
                    if transform_fn:
                        transformed_sample = transform_fn(raw_sample)
                    else:
                        transformed_sample = raw_sample

                    # transform_fn signals "drop this sample" by returning None.
                    if transformed_sample is None:
                        continue

                    if dataset_filter_fn and not dataset_filter_fn(transformed_sample):
                        continue

                    processed_samples.append(transformed_sample)

                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping line {line_number} in {dataset_path} due to JSON decode error: {e}")
                except Exception as e:
                    print(
                        f"Warning: Skipping line {line_number} in {dataset_path} due to error in transform_fn or filter_fn: {e}"
                    )

        if not processed_samples:
            print(f"Warning: No samples were processed from {dataset_path}.")
            return Dataset.from_list([])

        hf_dataset = Dataset.from_list(processed_samples)

        if prompt_column not in hf_dataset.column_names:
            raise ValueError(
                f"Dataset from {dataset_path} must contain a '{prompt_column}' column after transformation."
            )

        # prompt_column was verified above; check the remaining columns in the
        # caller's order so the first missing one is reported deterministically.
        for col in required_columns or []:
            if col not in hf_dataset.column_names:
                raise ValueError(
                    f"Dataset from {dataset_path} must contain a '{col}' column after transformation for the reward function/TRL."
                )

        return hf_dataset

    except FileNotFoundError:
        print(f"Error: Dataset file not found at {dataset_path}")
        return None
    except ValueError as ve:
        # Includes the missing-column errors raised above.
        print(f"Error: {ve}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred while loading dataset {dataset_path}: {e}")
        return None
@@ -0,0 +1,56 @@
1
+ """
2
+ Utility for dynamically loading modules and functions.
3
+ """
4
+
5
+ import importlib
6
+ import logging
7
+ from typing import Any, Callable
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def load_function(import_path: str) -> Callable[..., Any]:
    """
    Dynamically loads a function given its full import path.
    Example: "my_package.my_module.my_function"

    Args:
        import_path: Dotted path whose last component is the attribute to load
            and whose prefix is the module to import.

    Returns:
        The callable found at ``import_path``.

    Raises:
        ValueError: If ``import_path`` has no "." separating module and function.
        ImportError: If the module cannot be imported.
        AttributeError: If the attribute is missing or is not callable.
    """
    try:
        parts = import_path.rsplit(".", 1)
        if len(parts) != 2:
            # Without a dot there is no module part; fail with a clear message
            # instead of an opaque tuple-unpacking error.
            raise ValueError(
                f"Invalid import path '{import_path}': expected the form 'module.function'."
            )
        module_path, function_name = parts
        module = importlib.import_module(module_path)
        func = getattr(module, function_name)
        if not callable(func):
            raise AttributeError(f"'{function_name}' in module '{module_path}' is not callable.")
        logger.info(f"Successfully loaded function '{function_name}' from '{module_path}'.")
        return func
    except ImportError as e:
        logger.error(f"Failed to import module from path '{import_path}': {e}")
        raise
    except AttributeError as e:
        logger.error(f"Failed to find or access function in path '{import_path}': {e}")
        raise
    except Exception as e:
        logger.error(f"An unexpected error occurred while loading function from '{import_path}': {e}")
        raise
34
+
35
+
36
+ # Example usage:
37
+ # if __name__ == '__main__':
38
+ # try:
39
+ # # Assuming you have a eval_protocol.rewards.math module with math_reward function
40
+ # math_reward_func = load_function("eval_protocol.rewards.math.math_reward")
41
+ # print(f"Loaded: {math_reward_func}")
42
+ # # You could then call it, e.g., if it took simple args: math_reward_func(arg1="test")
43
+ # except Exception as e:
44
+ # print(f"Test loading failed: {e}")
45
+
46
+ # try:
47
+ # # Test with a non-existent function
48
+ # load_function("eval_protocol.rewards.math.non_existent_function")
49
+ # except Exception as e:
50
+ # print(f"Test loading non-existent function failed as expected: {e}")
51
+
52
+ # try:
53
+ # # Test with a non-existent module
54
+ # load_function("non_existent_package.non_existent_module.some_function")
55
+ # except Exception as e:
56
+ # print(f"Test loading non-existent module failed as expected: {e}")