eval-protocol 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,264 @@
1
+ """
2
+ CLI command for running agent evaluations using the ForkableResource framework.
3
+ """
4
+
5
+ import asyncio
6
+
7
try:
    import yaml
except ImportError:
    import sys
    import types

    # PyYAML is optional. Without it, expose a placeholder module whose
    # loader and dumper both return None, so this module still imports.
    def dummy_safe_load(x):
        return None

    def dummy_dump(x, **kwargs):
        return None

    yaml = types.ModuleType("yaml")
    yaml.safe_load = dummy_safe_load  # type: ignore[assignment]
    yaml.dump = dummy_dump  # type: ignore[assignment]
24
+
25
+ import json # Fallback or for explicit JSON files
26
+ import logging # For logger instance
27
+ import os # For environment variables
28
+ from pathlib import Path
29
+
30
+ from pydantic import ValidationError
31
+
32
+ from eval_protocol.agent import Orchestrator
33
+ from eval_protocol.agent.task_manager import TaskManager
34
+ from eval_protocol.models import TaskDefinitionModel # Import the new Pydantic model
35
+
36
+ # setup_logging is already called in cli.py's main, but good for standalone use if any
37
+ # from .common import setup_logging
38
+
39
+
40
def agent_eval_command(args):
    """
    Run agent evaluation using the Orchestrator and ForkableResource framework.

    Loads a single task definition file, or every task definition in a
    directory, registers the tasks with a ``TaskManager``, executes them and
    logs a summary for each task.

    Args:
        args: Parsed CLI arguments. ``task_def`` (a file or directory path) is
            required. ``model``, ``parallel``, ``max_concurrency``, ``filter``
            and ``num_rollouts`` are optional and read with ``getattr``.

    Returns:
        int: 0 on success. 1 if ``task_def`` is missing or invalid, if no task
        could be registered, or if execution raised an exception.
    """
    logger = logging.getLogger("agent_eval")
    logger.info("Starting agent-eval command.")

    # Validate the arguments before creating a TaskManager, so invalid input
    # fails fast without allocating anything.
    if not args.task_def:
        logger.error("Error: --task-def (path to task definition YAML file or directory) is required.")
        return 1

    task_def_path = Path(args.task_def)
    is_file = task_def_path.is_file()
    if not is_file and not task_def_path.is_dir():
        logger.error(f"Task definition path not found or invalid: {task_def_path}")
        return 1

    task_manager = TaskManager()

    if is_file:
        task_def = task_manager._load_task_from_file(str(task_def_path))
        if not task_def:
            logger.error(f"Failed to load task definition from {task_def_path}")
            return 1
        registered_task_ids = [task_manager.register_task(task_def)]
    else:
        registered_task_ids = task_manager.register_tasks_from_directory(str(task_def_path))
        if not registered_task_ids:
            logger.error(f"No valid task definitions found in directory: {task_def_path}")
            return 1

    logger.info(f"Registered {len(registered_task_ids)} tasks: {registered_task_ids}")

    def log_task_result(task_id, result):
        """Log the outcome of one task, as returned by ``TaskManager.execute_tasks``."""
        if isinstance(result, dict) and "error" in result:
            logger.error(f"Task '{task_id}' failed: {result['error']}")
        elif isinstance(result, dict) and result.get("aggregated", False):
            # Aggregated results from multiple rollouts.
            logger.info(f"Task '{task_id}' batch results:")
            logger.info(
                f" - Rollouts: {result['successful_rollouts']}/{result['num_rollouts']} successful ({result.get('failed_rollouts', 0)} failed)"
            )
            logger.info(f" - Success rate: {result['success_rate']:.2%}")
            logger.info(f" - Average score: {result['avg_score']:.4f}")
            logger.info(f" - Standard deviation: {result.get('std_dev', 0.0):.4f}")
            logger.info(f" - Score range: {result['min_score']:.4f} - {result['max_score']:.4f}")
            if "aggregated_metrics" in result:
                logger.info(" - Aggregated metrics:")
                for metric_name, metric_data in result["aggregated_metrics"].items():
                    logger.info(
                        f" * {metric_name}: avg={metric_data['avg_score']:.4f}, range={metric_data['min_score']:.4f}-{metric_data['max_score']:.4f}"
                    )

            if result.get("timestamp"):
                # Rebuild the trajectory filename from the ISO-style timestamp.
                # This is meant to match TaskManager's naming scheme, which is
                # not visible from this file; only the filename is logged.
                timestamp = result["timestamp"].replace(":", "").replace("-", "").replace("T", "_").split(".")[0]
                trajectory_file = f"trajectory_{task_id}_{timestamp}.jsonl"
                logger.info(f" - Trajectory data saved to: {trajectory_file}")
        elif isinstance(result, dict) and "score" in result:
            logger.info(f"Task '{task_id}' score: {result['score']}")
        else:
            logger.info(f"Task '{task_id}' completed")

    async def main_flow():
        parallel = getattr(args, "parallel", False)
        max_concurrency = getattr(args, "max_concurrency", 3)
        filter_tasks = getattr(args, "filter", None)
        num_rollouts_override = getattr(args, "num_rollouts", None)

        # Everything after registration runs under try/finally. That way the
        # TaskManager is cleaned up even when the filter matches no task.
        try:
            tasks_to_run = registered_task_ids
            if filter_tasks:
                tasks_to_run = [tid for tid in registered_task_ids if tid in filter_tasks]
                if not tasks_to_run:
                    logger.warning(f"No tasks match the specified filter: {filter_tasks}")
                    return

            results = await task_manager.execute_tasks(
                task_ids=tasks_to_run,
                parallel=parallel,
                max_concurrency=max_concurrency,
                num_rollouts_override=num_rollouts_override,
            )

            logger.info(f"Execution completed for {len(results)} tasks")
            for task_id, result in results.items():
                log_task_result(task_id, result)
        finally:
            await task_manager.cleanup()

    model_override = getattr(args, "model", None)
    previous_model = os.environ.get("MODEL_AGENT")
    try:
        if model_override:
            os.environ["MODEL_AGENT"] = model_override
            logger.info(f"Model overridden to: {model_override}")

        asyncio.run(main_flow())
        logger.info("agent-eval command finished successfully.")
        return 0
    except Exception as e:
        logger.error(f"Error during agent-eval execution: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return 1
    finally:
        # Undo the MODEL_AGENT override so it does not outlive this command.
        # This matters when the command is invoked in-process (e.g. from tests).
        if model_override:
            if previous_model is None:
                os.environ.pop("MODEL_AGENT", None)
            else:
                os.environ["MODEL_AGENT"] = previous_model
148
+
149
+
150
def bfcl_eval_command(args):
    """
    Run BFCL agent evaluations using the refactored framework.

    Registers either the single task ``<task_dir>/<task_id>.yaml`` (when
    ``args.task_id`` is set) or every task definition found in ``task_dir``.
    It then executes them with a ``TaskManager`` and logs per-task scores.
    When ``args.output_dir`` is set, the results are also written to
    ``<output_dir>/bfcl_results.json``.

    Args:
        args: Parsed CLI arguments. Reads ``task_dir``, ``task_id``, ``model``,
            ``output_dir``, ``parallel`` and ``max_concurrency``.

    Returns:
        int: 0 on success, 1 on any failure.
    """
    logger = logging.getLogger("bfcl_eval")
    logger.info("Starting BFCL evaluation command.")

    # Check the directory before creating a TaskManager, so a bad path fails fast.
    task_dir = Path(args.task_dir)
    if not task_dir.is_dir():
        logger.error(f"Task directory not found: {task_dir}")
        return 1

    task_manager = TaskManager()

    logger.info(f"Registering BFCL tasks from {task_dir}")

    def dump_model(obj):
        """Dump a Pydantic-style model, preferring v2 ``model_dump()`` over the deprecated ``dict()``."""
        model_dump = getattr(obj, "model_dump", None)
        return model_dump() if callable(model_dump) else obj.dict()

    def make_serializable(results):
        """Convert ``execute_tasks`` results into JSON-friendly values (best effort)."""
        serializable_results = {}
        for task_id, result in results.items():
            if hasattr(result, "dict"):
                # Pydantic models
                serializable_results[task_id] = dump_model(result)
            elif isinstance(result, dict):
                # Dicts whose values may be models or arbitrary objects
                serializable_dict = {}
                for k, v in result.items():
                    if hasattr(v, "dict"):
                        serializable_dict[k] = dump_model(v)
                    elif hasattr(v, "__dict__"):
                        serializable_dict[k] = str(v)
                    else:
                        serializable_dict[k] = v
                serializable_results[task_id] = serializable_dict
            else:
                # Anything else is stored as its string form
                serializable_results[task_id] = str(result)
        return serializable_results

    previous_model = os.environ.get("MODEL_AGENT")
    model_overridden = False
    try:
        registered_task_ids = []

        if args.task_id:
            task_path = task_dir / f"{args.task_id}.yaml"
            if not task_path.exists():
                logger.error(f"Task file not found: {task_path}")
                return 1

            task_def = task_manager._load_task_from_file(str(task_path))
            if not task_def:
                logger.error(f"Failed to load task from {task_path}")
                return 1
            task_id = task_manager.register_task(task_def)
            registered_task_ids.append(task_id)
            logger.info(f"Registered task: {task_id}")
        else:
            registered_task_ids = task_manager.register_tasks_from_directory(str(task_dir))
            if not registered_task_ids:
                logger.error(f"No valid BFCL tasks found in directory: {task_dir}")
                return 1
            logger.info(f"Registered {len(registered_task_ids)} BFCL tasks")

        async def main_flow():
            # Everything runs under try/finally. That way the TaskManager is
            # cleaned up even if creating the output directory fails.
            try:
                if args.output_dir:
                    output_path = Path(args.output_dir)
                    output_path.mkdir(parents=True, exist_ok=True)
                    logger.info(f"Results will be saved to {output_path}")

                results = await task_manager.execute_tasks(
                    task_ids=registered_task_ids,
                    parallel=args.parallel,
                    max_concurrency=args.max_concurrency,
                )

                logger.info(f"BFCL evaluation completed for {len(results)} tasks")
                for task_id, result in results.items():
                    if isinstance(result, dict) and "error" in result:
                        logger.error(f"Task '{task_id}' failed: {result['error']}")
                    elif isinstance(result, dict) and "score" in result:
                        logger.info(f"Task '{task_id}' score: {result['score']}")

                        # More detailed results for BFCL
                        if "format_score" in result:
                            logger.info(f"Task '{task_id}' format score: {result['format_score']}")
                        if "state_match" in result:
                            logger.info(f"Task '{task_id}' state match: {result['state_match']}")
                    else:
                        logger.info(f"Task '{task_id}' completed with result: {result}")

                if args.output_dir:
                    results_file = Path(args.output_dir) / "bfcl_results.json"
                    # Serialize fully before opening the file, so a failure cannot
                    # leave a truncated results file behind. default=str
                    # stringifies any remaining values json cannot encode
                    # natively (e.g. datetimes), matching the string fallback in
                    # make_serializable.
                    payload = json.dumps(make_serializable(results), indent=2, default=str)
                    with open(results_file, "w", encoding="utf-8") as f:
                        f.write(payload)
                    logger.info(f"Results saved to {results_file}")
            finally:
                await task_manager.cleanup()

        if args.model:
            os.environ["MODEL_AGENT"] = args.model
            model_overridden = True
            logger.info(f"Model overridden to: {args.model}")

        asyncio.run(main_flow())
        logger.info("BFCL evaluation completed successfully.")
        return 0

    except Exception as e:
        logger.error(f"Error during BFCL evaluation: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return 1
    finally:
        # Undo the MODEL_AGENT override so it does not outlive this command.
        if model_overridden:
            if previous_model is None:
                os.environ.pop("MODEL_AGENT", None)
            else:
                os.environ["MODEL_AGENT"] = previous_model
@@ -0,0 +1,242 @@
1
+ """
2
+ Common utility functions for the Reward Kit CLI.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from typing import Any, Dict, Iterator, List, Optional
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def setup_logging(verbose=False, debug=False):
    """
    Configure logging for the CLI.

    Args:
        verbose: Set by ``--verbose``. The default level is already INFO, so
            this currently gives the same configuration as the default. The
            parameter is kept for CLI compatibility.
        debug: Set by ``--debug``. Logs at DEBUG level with the source path
            and line number in each record, and leaves the noisy third-party
            loggers at their own levels.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers. The format and root level are therefore only applied on the
        first call in a process, while the per-logger levels below are set on
        every call.
    """
    if debug:
        log_level = logging.DEBUG
        # Include the source location of each record when debugging.
        format_str = "[%(asctime)s][%(name)s][%(levelname)s] - %(pathname)s:%(lineno)d - %(message)s"
    else:
        # --verbose and the default are identical: INFO level, same format.
        log_level = logging.INFO
        format_str = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"

    logging.basicConfig(level=log_level, format=format_str, datefmt="%Y-%m-%d %H:%M:%S")

    # Quiet chatty libraries unless full debug output was requested.
    if not debug:
        noisy_loggers = ["httpx", "mcp", "urllib3", "asyncio", "hpack", "httpcore"]
        for logger_name in noisy_loggers:
            logging.getLogger(logger_name).setLevel(logging.WARNING)

    # Set the package logger to the chosen level. At INFO, eval_protocol DEBUG
    # records are hidden; at DEBUG, everything shows. Child loggers that set
    # their own level explicitly keep that level.
    logging.getLogger("eval_protocol").setLevel(log_level)
42
+
43
+
44
def check_environment():
    """
    Warn when FIREWORKS_API_KEY is not available for general commands.

    Returns:
        bool: True if FIREWORKS_API_KEY is set to a non-empty value. Otherwise
        False, after logging warnings that explain how to set it.
    """
    if os.environ.get("FIREWORKS_API_KEY"):
        return True

    for message in (
        "FIREWORKS_API_KEY environment variable is not set.",
        "This is required for API calls. Set this variable before running the command.",
        "Example: FIREWORKS_API_KEY=$DEV_FIREWORKS_API_KEY reward-kit [command]",
    ):
        logger.warning(message)
    return False
52
+
53
+
54
def check_agent_environment(test_mode=False):
    """
    Check the environment variables needed by agent evaluation commands.

    Only MODEL_AGENT is checked. A missing variable is reported at INFO level
    in test mode, and at WARNING level (with usage hints) otherwise.

    Args:
        test_mode: When True, missing variables are tolerated.

    Returns:
        bool: True when nothing is missing or ``test_mode`` is set, else False.
    """
    missing_vars = [name for name in ("MODEL_AGENT",) if not os.environ.get(name)]
    if not missing_vars:
        return True

    if test_mode:
        logger.info(f"Note: The following environment variables are not set: {', '.join(missing_vars)}")
        logger.info("Since you're running in test mode, these are not strictly required for all operations.")
        return True

    logger.warning(f"The following environment variables are not set: {', '.join(missing_vars)}")
    logger.warning(
        "These are typically required for full agent evaluation. Set these variables for full functionality."
    )
    logger.warning("Example: MODEL_AGENT=openai/gpt-4o-mini reward-kit agent-eval [args]")
    logger.warning("Alternatively, use --test-mode for certain validation tasks without requiring all API keys.")
    return False
75
+
76
+
77
+ # --- Sample Loading Helper Functions ---
78
+
79
+
80
+ def _validate_sample_messages(messages: Any, sample_index: int, line_number: int) -> bool:
81
+ """Helper to validate the 'messages' field of a sample."""
82
+ if not isinstance(messages, list):
83
+ logger.warning(f"Sample {sample_index} (line {line_number}): 'messages' field is not a list. Skipping sample.")
84
+ return False
85
+ if not messages:
86
+ logger.warning(f"Sample {sample_index} (line {line_number}): 'messages' list is empty. Skipping sample.")
87
+ return False
88
+ for i, msg in enumerate(messages):
89
+ if not isinstance(msg, dict):
90
+ logger.warning(
91
+ f"Sample {sample_index} (line {line_number}): message item {i} is not a dictionary. Skipping sample."
92
+ )
93
+ return False
94
+ role = msg.get("role")
95
+ content = msg.get("content")
96
+ if not isinstance(role, str) or not isinstance(content, str):
97
+ logger.warning(
98
+ f"Sample {sample_index} (line {line_number}): message item {i} missing 'role' or 'content' string fields. Skipping sample."
99
+ )
100
+ return False
101
+ return True
102
+
103
+
104
def load_samples_from_file(filepath: str, max_samples: int) -> Iterator[Dict[str, Any]]:
    """
    Loads samples from a JSONL file.

    Each non-blank line should be a JSON object with a 'messages' key holding
    a non-empty list of message dicts, each having string 'role' and
    'content' fields. Lines that fail these checks are logged and skipped.

    The file is decoded as UTF-8. A leading UTF-8 byte-order mark (written by
    some editors) is stripped, rather than causing the first line to be
    rejected as invalid JSON.

    Args:
        filepath: Path to the JSONL file.
        max_samples: Maximum number of valid samples to yield.

    Yields:
        Valid sample dictionaries, at most ``max_samples`` of them.

    Errors raised while opening or reading the file (including a missing
    file) are logged rather than propagated, so the generator just ends early.
    """
    count = 0
    line_number = 0  # stays 0 when the file is empty or cannot be opened
    try:
        # "utf-8-sig" decodes plain UTF-8 unchanged but also drops a leading BOM.
        # With "utf-8" the BOM survives as "\ufeff" and json.loads rejects line 1.
        with open(filepath, "r", encoding="utf-8-sig") as f:
            for line_number, line in enumerate(f, start=1):
                if count >= max_samples:
                    logger.info(f"Reached max_samples ({max_samples}). Stopping sample loading from {filepath}.")
                    break
                line_content = line.strip()
                if not line_content:
                    continue
                try:
                    sample = json.loads(line_content)
                except json.JSONDecodeError:
                    logger.warning(f"Line {line_number}: Invalid JSON. Skipping line: {line_content[:100]}...")
                    continue
                if not isinstance(sample, dict):
                    logger.warning(f"Line {line_number}: Content is not a JSON object. Skipping line.")
                    continue
                messages = sample.get("messages")
                if messages is None:
                    logger.warning(f"Sample (line {line_number}): 'messages' field is missing. Skipping sample.")
                    continue
                if not _validate_sample_messages(messages, count + 1, line_number):
                    continue
                yield sample
                count += 1
    except FileNotFoundError:
        logger.error(f"Sample file not found: {filepath}")
    except Exception as e:
        logger.error(f"Error reading or processing sample file {filepath}: {e}")
    if count == 0:
        logger.info(f"No valid samples loaded from {filepath} after processing {line_number} lines.")
145
+
146
+
147
def load_samples_from_huggingface(
    dataset_name: str,
    split: str,
    prompt_key: str,
    response_key: str,
    key_map: Optional[Dict[str, str]],
    max_samples: int,
) -> Iterator[Dict[str, Any]]:
    """
    Stream chat samples out of a HuggingFace dataset.

    The dataset is opened with ``datasets.load_dataset(..., streaming=True)``.
    For each record, ``record[prompt_key]`` becomes a user message and
    ``record[response_key]`` an assistant message. Both must be strings,
    otherwise the record is skipped with a warning. When ``key_map`` is
    given, each ``source -> target`` entry copies ``record[source]`` into the
    output sample under ``target``; missing source keys are logged and
    omitted.

    If ``datasets`` cannot be imported or the dataset cannot be loaded, the
    error is logged and nothing is yielded.

    Args:
        dataset_name: Dataset name passed to ``load_dataset``.
        split: Split name passed to ``load_dataset``.
        prompt_key: Record field holding the user prompt.
        response_key: Record field holding the assistant response.
        key_map: Optional mapping from record field to output sample key.
        max_samples: Maximum number of samples to yield.

    Yields:
        Dicts containing a 'messages' list plus any mapped ``key_map`` fields.
    """
    try:
        from datasets import (
            Dataset,
            DatasetDict,
            IterableDataset,
            IterableDatasetDict,
            load_dataset,
        )
    except ImportError:
        logger.error(
            "The 'datasets' library is required to load samples from HuggingFace. "
            "Please install it with 'pip install datasets'."
        )
        return

    try:
        stream = load_dataset(dataset_name, split=split, streaming=True)
    except Exception as e:
        logger.error(f"Error loading HuggingFace dataset '{dataset_name}' (split: {split}): {e}")
        return

    # streaming=True should produce an IterableDataset; the other dataset
    # container types are accepted as well.
    known_types = (DatasetDict, Dataset, IterableDatasetDict, IterableDataset)
    if not isinstance(stream, known_types):
        logger.error(f"Loaded HuggingFace dataset '{dataset_name}' is not a recognized Dataset type.")
        return

    def build_sample(record, processed_records, sample_index):
        """Turn one record into an output sample, or return None if it must be skipped."""
        if not isinstance(record, dict):
            logger.warning(f"HuggingFace dataset record {processed_records} is not a dictionary. Skipping.")
            return None

        prompt = record.get(prompt_key)
        response_content = record.get(response_key)
        if not isinstance(prompt, str):
            logger.warning(
                f"HuggingFace record {processed_records}: Prompt key '{prompt_key}' (value: {str(prompt)[:50]}...) did not yield a string. Skipping sample."
            )
            return None
        if not isinstance(response_content, str):
            logger.warning(
                f"HuggingFace record {processed_records}: Response key '{response_key}' (value: {str(response_content)[:50]}...) did not yield a string. Skipping sample."
            )
            return None

        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response_content},
        ]
        if not _validate_sample_messages(messages, sample_index, processed_records):
            return None

        sample: Dict[str, Any] = {"messages": messages}
        if key_map:
            for source_key_in_record, target_key_in_sample in key_map.items():
                if source_key_in_record not in record:
                    logger.warning(
                        f"HuggingFace record {processed_records}: Key '{source_key_in_record}' from key_map not found. It will be omitted."
                    )
                    continue
                sample[target_key_in_sample] = record[source_key_in_record]
        return sample

    count = 0
    processed_records = 0
    logger.info(f"Streaming samples from HuggingFace dataset '{dataset_name}' (split: {split}).")
    for record in stream:
        processed_records += 1
        if count >= max_samples:
            logger.info(f"Reached max_samples ({max_samples}). Stopping HuggingFace sample loading.")
            break
        sample = build_sample(record, processed_records, count + 1)
        if sample is not None:
            yield sample
            count += 1

    if count == 0:
        logger.info(
            f"No valid samples loaded from HuggingFace dataset '{dataset_name}' (split: {split}) after processing {processed_records} records."
        )