eval-protocol 0.0.3 (eval_protocol-0.0.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. development/__init__.py +1 -0
  2. development/normalize_sandbox_fusion.py +628 -0
  3. development/utils/__init__.py +1 -0
  4. development/utils/generate_api_key.py +31 -0
  5. development/utils/subprocess_manager.py +481 -0
  6. eval_protocol/__init__.py +86 -0
  7. eval_protocol/__main__.py +10 -0
  8. eval_protocol/_version.py +21 -0
  9. eval_protocol/adapters/__init__.py +1 -0
  10. eval_protocol/adapters/braintrust.py +8 -0
  11. eval_protocol/adapters/trl.py +8 -0
  12. eval_protocol/agent/__init__.py +29 -0
  13. eval_protocol/agent/models.py +69 -0
  14. eval_protocol/agent/orchestrator.py +893 -0
  15. eval_protocol/agent/resource_abc.py +89 -0
  16. eval_protocol/agent/resource_pool.py +184 -0
  17. eval_protocol/agent/resources/__init__.py +44 -0
  18. eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
  19. eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
  20. eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
  21. eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
  22. eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
  23. eval_protocol/agent/resources/docker_resource.py +479 -0
  24. eval_protocol/agent/resources/filesystem_resource.py +371 -0
  25. eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
  26. eval_protocol/agent/resources/http_rollout_resource.py +325 -0
  27. eval_protocol/agent/resources/python_state_resource.py +170 -0
  28. eval_protocol/agent/resources/sql_resource.py +271 -0
  29. eval_protocol/agent/task_manager.py +1064 -0
  30. eval_protocol/agent/tool_registry.py +111 -0
  31. eval_protocol/auth.py +156 -0
  32. eval_protocol/cli.py +425 -0
  33. eval_protocol/cli_commands/__init__.py +1 -0
  34. eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
  35. eval_protocol/cli_commands/common.py +242 -0
  36. eval_protocol/cli_commands/deploy.py +486 -0
  37. eval_protocol/cli_commands/deploy_mcp.py +287 -0
  38. eval_protocol/cli_commands/preview.py +186 -0
  39. eval_protocol/cli_commands/run_eval_cmd.py +202 -0
  40. eval_protocol/common_utils.py +36 -0
  41. eval_protocol/config.py +180 -0
  42. eval_protocol/datasets/__init__.py +1 -0
  43. eval_protocol/datasets/loader.py +521 -0
  44. eval_protocol/evaluation.py +1045 -0
  45. eval_protocol/execution/__init__.py +1 -0
  46. eval_protocol/execution/pipeline.py +920 -0
  47. eval_protocol/gcp_tools.py +484 -0
  48. eval_protocol/generation/cache.py +141 -0
  49. eval_protocol/generation/clients/base.py +67 -0
  50. eval_protocol/generation/clients.py +248 -0
  51. eval_protocol/generic_server.py +165 -0
  52. eval_protocol/integrations/__init__.py +12 -0
  53. eval_protocol/integrations/braintrust.py +51 -0
  54. eval_protocol/integrations/deepeval.py +106 -0
  55. eval_protocol/integrations/openeval.py +40 -0
  56. eval_protocol/integrations/trl.py +187 -0
  57. eval_protocol/mcp/__init__.py +48 -0
  58. eval_protocol/mcp/adapter.py +131 -0
  59. eval_protocol/mcp/client/__init__.py +12 -0
  60. eval_protocol/mcp/client/connection.py +499 -0
  61. eval_protocol/mcp/clients.py +195 -0
  62. eval_protocol/mcp/execution/__init__.py +23 -0
  63. eval_protocol/mcp/execution/base_policy.py +227 -0
  64. eval_protocol/mcp/execution/fireworks_policy.py +209 -0
  65. eval_protocol/mcp/execution/manager.py +506 -0
  66. eval_protocol/mcp/execution/policy.py +421 -0
  67. eval_protocol/mcp/grid_renderer.py +54 -0
  68. eval_protocol/mcp/mcpgym.py +637 -0
  69. eval_protocol/mcp/process_manager.py +177 -0
  70. eval_protocol/mcp/session/__init__.py +11 -0
  71. eval_protocol/mcp/session/manager.py +228 -0
  72. eval_protocol/mcp/simple_process_manager.py +291 -0
  73. eval_protocol/mcp/simulation_server.py +458 -0
  74. eval_protocol/mcp/types.py +80 -0
  75. eval_protocol/mcp_agent/__init__.py +1 -0
  76. eval_protocol/mcp_agent/config.py +147 -0
  77. eval_protocol/mcp_agent/intermediary_server.py +542 -0
  78. eval_protocol/mcp_agent/main.py +210 -0
  79. eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
  80. eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
  81. eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
  82. eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
  83. eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
  84. eval_protocol/mcp_agent/session.py +79 -0
  85. eval_protocol/mcp_env.py +304 -0
  86. eval_protocol/models.py +366 -0
  87. eval_protocol/packaging.py +219 -0
  88. eval_protocol/platform_api.py +360 -0
  89. eval_protocol/playback_policy.py +396 -0
  90. eval_protocol/resources.py +128 -0
  91. eval_protocol/reward_function.py +410 -0
  92. eval_protocol/rewards/__init__.py +94 -0
  93. eval_protocol/rewards/accuracy.py +454 -0
  94. eval_protocol/rewards/accuracy_length.py +173 -0
  95. eval_protocol/rewards/apps_coding_reward.py +331 -0
  96. eval_protocol/rewards/apps_execution_utils.py +149 -0
  97. eval_protocol/rewards/apps_testing_util.py +559 -0
  98. eval_protocol/rewards/bfcl_reward.py +313 -0
  99. eval_protocol/rewards/code_execution.py +1620 -0
  100. eval_protocol/rewards/code_execution_utils.py +72 -0
  101. eval_protocol/rewards/cpp_code.py +861 -0
  102. eval_protocol/rewards/deepcoder_reward.py +161 -0
  103. eval_protocol/rewards/format.py +129 -0
  104. eval_protocol/rewards/function_calling.py +541 -0
  105. eval_protocol/rewards/json_schema.py +422 -0
  106. eval_protocol/rewards/language_consistency.py +700 -0
  107. eval_protocol/rewards/lean_prover.py +479 -0
  108. eval_protocol/rewards/length.py +375 -0
  109. eval_protocol/rewards/list_comparison_math_reward.py +221 -0
  110. eval_protocol/rewards/math.py +762 -0
  111. eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
  112. eval_protocol/rewards/reasoning_steps.py +249 -0
  113. eval_protocol/rewards/repetition.py +342 -0
  114. eval_protocol/rewards/tag_count.py +162 -0
  115. eval_protocol/rl_processing.py +82 -0
  116. eval_protocol/server.py +271 -0
  117. eval_protocol/typed_interface.py +260 -0
  118. eval_protocol/utils/__init__.py +8 -0
  119. eval_protocol/utils/batch_evaluation.py +217 -0
  120. eval_protocol/utils/batch_transformation.py +205 -0
  121. eval_protocol/utils/dataset_helpers.py +112 -0
  122. eval_protocol/utils/module_loader.py +56 -0
  123. eval_protocol/utils/packaging_utils.py +108 -0
  124. eval_protocol/utils/static_policy.py +305 -0
  125. eval_protocol-0.0.3.dist-info/METADATA +635 -0
  126. eval_protocol-0.0.3.dist-info/RECORD +130 -0
  127. eval_protocol-0.0.3.dist-info/WHEEL +5 -0
  128. eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
  129. eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
  130. eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,521 @@
1
+ """
2
+ Hydra-based dataset loading and processing.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ import datasets
11
+ from datasets import Dataset, DatasetDict
12
+ from hydra import compose, initialize_config_dir
13
+ from hydra.utils import instantiate
14
+ from omegaconf import DictConfig, OmegaConf
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ import importlib # Added for dynamic function import
19
+
20
+ # Placeholder for Fireworks API client if needed in the future
21
+ # from ..fireworks_client import FireworksClient # Example
22
+
23
+ # --- Preprocessing Functions ---
24
+ # These can be moved to a separate processors.py if they grow numerous.
25
+
26
+
27
def transform_codeparrot_apps_sample(example: Dict[str, Any]) -> Dict[str, Any]:
    """
    Add a 'transformed_ground_truth' field to one codeparrot/apps sample.

    The new field is a JSON string encoding a dict with:
      - "fn_name": copied from the sample, only when it is truthy
        (standard-input based APPS problems have no fn_name);
      - "inputs" / "outputs": lists taken from the sample's ``input_output``
        JSON, or empty lists when that field is missing, unparsable, not a
        JSON object, or holds non-list values.

    The format is meant to be compatible with apps_coding_reward.

    Args:
        example: One dataset row. Reads the optional keys "fn_name",
            "input_output" (normally a JSON string; an already-decoded dict is
            also accepted) and "problem_id" (used only in log messages).

    Returns:
        The same ``example`` dict, mutated in place to include
        "transformed_ground_truth" (so it can be used with ``Dataset.map``).
    """
    problem_id = example.get("problem_id", "Unknown")
    gt_dict: Dict[str, Any] = {}
    # fn_name can be None or missing for some APPS problems (standard input based).
    if example.get("fn_name"):
        gt_dict["fn_name"] = example["fn_name"]

    inputs: List[Any] = []
    outputs: List[Any] = []
    input_output_str = example.get("input_output")
    if input_output_str:
        try:
            if isinstance(input_output_str, dict):
                parsed_io = input_output_str
            else:
                parsed_io = json.loads(input_output_str)
        except (json.JSONDecodeError, TypeError):
            # TypeError: input_output is neither str/bytes nor a dict.
            logger.warning(
                f"Failed to parse input_output JSON for problem_id {problem_id}. "
                f"Content: {input_output_str}"
            )
        else:
            # apps_testing_util.py expects 'inputs'/'outputs' to be lists. A JSON
            # value that is not an object (e.g. a bare list) carries neither key,
            # and calling .get() on it would raise AttributeError.
            is_mapping = isinstance(parsed_io, dict)
            raw_inputs = parsed_io.get("inputs", []) if is_mapping else None
            raw_outputs = parsed_io.get("outputs", []) if is_mapping else None
            if not isinstance(raw_inputs, list) or not isinstance(raw_outputs, list):
                logger.warning(
                    f"Parsed input_output for problem_id {problem_id} "
                    f"does not contain 'inputs'/'outputs' as lists. IO: {input_output_str}"
                )
            # Fall back to empty lists for wrong types to prevent downstream errors.
            inputs = raw_inputs if isinstance(raw_inputs, list) else []
            outputs = raw_outputs if isinstance(raw_outputs, list) else []
    else:
        logger.warning(f"Missing or empty input_output field for problem_id {problem_id}.")

    gt_dict["inputs"] = inputs
    gt_dict["outputs"] = outputs
    example["transformed_ground_truth"] = json.dumps(gt_dict)
    return example
70
+
71
+
72
+ # --- End Preprocessing Functions ---
73
+
74
+
75
def load_jsonl_file(file_path: str) -> List[Dict[str, Any]]:
    """
    Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines (such as a trailing empty line) are skipped,
    and a leading UTF-8 byte-order mark is tolerated.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If a non-blank line is not valid JSON. The underlying
            ``json.JSONDecodeError`` is chained as ``__cause__``.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"JSONL file not found: {file_path}")

    data = []
    # utf-8-sig decodes plain UTF-8 identically but strips a leading BOM,
    # which json.loads would otherwise reject on the first record.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(f"Error decoding JSON in file {file_path}: {e} on line: {line.strip()}") from e
    return data
87
+
88
+
89
def _truncate_to_max_samples(
    dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]],
    max_samples: Optional[int],
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
    """
    Limit ``dataset`` to at most ``max_samples`` rows (per split for a DatasetDict).

    ``None`` or a value ``<= 0`` means "no limit". A DatasetDict is updated in
    place and returned.
    """
    if max_samples is None or max_samples <= 0:
        return dataset
    if isinstance(dataset, Dataset):
        return dataset.select(range(max_samples)) if len(dataset) > max_samples else dataset
    if isinstance(dataset, DatasetDict):
        for split_name in list(dataset.keys()):
            if len(dataset[split_name]) > max_samples:
                dataset[split_name] = dataset[split_name].select(range(max_samples))
        return dataset
    if isinstance(dataset, list):  # Not produced by the current loaders; kept for safety.
        return dataset[:max_samples]
    return dataset


def _apply_preprocessing_steps(
    dataset: Union[Dataset, DatasetDict],
    steps: List[str],
) -> Union[Dataset, DatasetDict]:
    """
    Run each preprocessor, given as a dotted path "package.module.function", via ``map``.

    Steps run in the given order, over every split of a DatasetDict. Any
    failure (malformed path, import error, or an error raised inside ``map``)
    is logged and re-raised so the pipeline halts.
    """
    logger.info(f"Applying preprocessing steps: {steps}")
    for step_path in steps:
        try:
            module_path, func_name = step_path.rsplit(".", 1)
            preprocessor_func = getattr(importlib.import_module(module_path), func_name)
            if isinstance(dataset, Dataset):
                # Row-wise map; preprocessors are expected to add or modify columns.
                dataset = dataset.map(preprocessor_func)
            elif isinstance(dataset, DatasetDict):
                for split_name in list(dataset.keys()):
                    logger.info(f"Applying preprocessor {func_name} to split '{split_name}'")
                    dataset[split_name] = dataset[split_name].map(preprocessor_func)
            logger.info(f"Successfully applied preprocessor: {step_path}")
        except Exception as e:
            logger.error(
                f"Failed to apply preprocessing step {step_path}: {e}",
                exc_info=True,
            )
            raise  # Re-raise to halt execution if a preprocessor fails
    return dataset


def _rename_columns_checked(
    dataset: Dataset,
    column_mapping: Dict[str, str],
    split_name: Optional[str] = None,
) -> Dataset:
    """
    Rename columns of one Dataset using a ``{new_name: old_name}`` mapping.

    Entries are ignored when ``old_name`` is empty or not a column. They are
    skipped with a warning when ``new_name`` already names a different existing
    column, because renaming onto it would clash. Identity renames are dropped.
    ``split_name`` only changes the wording of log messages.
    """
    existing_columns = set(dataset.column_names)
    valid_mapping = {
        old_name: new_name
        for new_name, old_name in column_mapping.items()
        if old_name and old_name in existing_columns
    }
    if not valid_mapping:
        if split_name is None:
            logger.warning(
                "Column mapping provided but resulted in no valid columns to rename (original columns not found or new names empty)."
            )
        else:
            logger.warning(f"Column mapping for split '{split_name}' resulted in no valid columns to rename.")
        return dataset

    final_mapping: Dict[str, str] = {}
    for old_name, new_name in valid_mapping.items():
        if new_name == old_name:
            continue  # Renaming a column onto itself is a no-op.
        if new_name in existing_columns:
            if split_name is None:
                logger.warning(
                    f"Attempting to map column '{old_name}' to '{new_name}', but '{new_name}' already exists and is not '{old_name}'. This may lead to unexpected behavior or errors. Skipping this specific rename."
                )
            else:
                logger.warning(
                    f"For split '{split_name}', attempting to map column '{old_name}' to '{new_name}', but '{new_name}' already exists and is not '{old_name}'. Skipping this specific rename for the split."
                )
            continue
        final_mapping[old_name] = new_name

    if final_mapping:
        return dataset.rename_columns(final_mapping)
    if split_name is None:
        logger.info("Column mapping resulted in no columns to rename after validation.")
    else:
        logger.info(f"Column mapping for split '{split_name}' resulted in no columns to rename after validation.")
    return dataset


def load_and_process_dataset(
    source_type: str,
    path_or_name: str,
    split: Optional[str] = None,
    config_name: Optional[str] = None,
    data_files: Optional[Union[str, List[str], Dict[str, Union[str, List[str]]]]] = None,
    max_samples: Optional[int] = None,
    # column_mapping: Optional[Dict[str, str]] = None, # To be used for processing
    # preprocessing_steps: Optional[List[str]] = None, # To be implemented
    hf_extra_load_params: Optional[Dict[str, Any]] = None,
    **kwargs: Any,  # Catch-all for other params
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
    """
    Load a dataset, then truncate it, preprocess it and rename its columns.

    Processing order:
      1. load
      2. ``max_samples`` truncation
      3. preprocessing steps
      4. column mapping (last, so it can rename columns that preprocessors create)

    Args:
        source_type: "huggingface", "jsonl" or "fireworks". "fireworks" is not
            implemented yet.
        path_or_name: Hugging Face dataset name/ID, or path to a ``.jsonl`` file.
        split: Split to return (e.g. "train", "test"). When it names a key of a
            dict ``data_files``, all splits are loaded and this one is selected.
        config_name: Hugging Face dataset configuration, passed as ``name``.
        data_files: Local data file(s) for ``datasets.load_dataset``. A dict
            maps split names to files.
        max_samples: Keep at most this many rows (per split). ``None`` or
            ``<= 0`` keeps everything.
        hf_extra_load_params: Extra kwargs for ``datasets.load_dataset()``.
        kwargs: Recognised extras:
            - ``column_mapping``: ``{new_name: old_name}``
            - ``preprocessing_steps``: dotted paths to row functions
            - ``description``
            Known eval-protocol config keys (e.g. ``_target_``) are dropped.
            Everything else is forwarded to ``datasets.load_dataset()``.

    Returns:
        Typically a Dataset when a split is selected, otherwise a DatasetDict.

    Raises:
        ValueError: Unsupported ``source_type``, invalid jsonl arguments, or a
            requested jsonl split that was not loaded.
        NotImplementedError: For ``source_type="fireworks"``.
        Exception: Whatever a failing preprocessing step raised (re-raised).
    """
    loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]]

    # Kwargs for datasets.load_dataset; custom keys are separated out below.
    load_kwargs_for_hf = dict(hf_extra_load_params) if hf_extra_load_params else {}

    column_mapping_from_kwargs = kwargs.pop("column_mapping", None)
    preprocessing_steps_from_kwargs = kwargs.pop("preprocessing_steps", None)
    dataset_description = kwargs.pop("description", "No description provided.")

    # Config metadata fields that datasets.load_dataset must not receive.
    eval_protocol_specific_keys = (
        "dataset_name",
        "pretty_name",
        "final_columns",
        "column_transformations",
        "output_columns_creation",
        "preprocess_functions",
        "postprocess_functions",
        "_target_",
        "dataset_type",
    )
    for key in eval_protocol_specific_keys:
        if key in kwargs:
            logger.debug(f"Filtering out reward-kit specific config key: {key}")
            kwargs.pop(key, None)

    logger.info(f"Dataset description: {dataset_description}")

    load_kwargs_for_hf.update(kwargs)  # Remaining kwargs are genuine load_dataset params.

    # trust_remote_code is not passed; it is handled by the HF_DATASETS_TRUST_REMOTE_CODE=1 env var.
    if source_type == "huggingface":
        if config_name:
            load_kwargs_for_hf["name"] = config_name
        # When split names a key of a dict data_files, load every split and pick
        # the requested one afterwards; otherwise let load_dataset select it.
        if split and not (isinstance(data_files, dict) and split in data_files):
            load_kwargs_for_hf["split"] = split

        loaded_dataset = datasets.load_dataset(
            path_or_name,
            data_files=data_files,
            **load_kwargs_for_hf,
        )
        # Previously the requested split was never selected in the dict-data_files
        # case, so callers got a DatasetDict despite asking for one split.
        if split and isinstance(loaded_dataset, DatasetDict) and split in loaded_dataset:
            loaded_dataset = loaded_dataset[split]

    elif source_type == "jsonl":
        # Use Hugging Face's 'json' loader. path_or_name may be a single .jsonl
        # file; data_files covers multi-file / multi-split setups.
        effective_data_files = data_files
        if not effective_data_files and path_or_name:
            if not path_or_name.endswith(".jsonl"):
                raise ValueError(
                    f"For source_type 'jsonl' without 'data_files', 'path_or_name' must be a .jsonl file. Got: {path_or_name}"
                )
            # A single file is registered under the requested split, or 'train'.
            effective_data_files = {split if split else "train": path_or_name}

        if not effective_data_files:
            raise ValueError(
                "For source_type 'jsonl', either 'path_or_name' to a .jsonl file or 'data_files' must be provided."
            )

        # With dict data_files load the whole DatasetDict and select below;
        # with a single path/list let load_dataset apply 'split' directly.
        hf_split_param = None if (isinstance(effective_data_files, dict) and split) else split

        loaded_dataset = datasets.load_dataset(
            "json",
            data_files=effective_data_files,
            split=hf_split_param,
            **load_kwargs_for_hf,
        )

        if isinstance(loaded_dataset, DatasetDict):
            if split:
                if split not in loaded_dataset:
                    raise ValueError(
                        f"Split '{split}' not found in loaded jsonl DatasetDict. Available splits: {list(loaded_dataset.keys())}"
                    )
                loaded_dataset = loaded_dataset[split]
            else:
                logger.info(
                    f"Loaded multiple splits from JSONL: {list(loaded_dataset.keys())}. No specific split requested via 'split' arg."
                )

    elif source_type == "fireworks":
        # Placeholder: would download a JSONL via a Fireworks client, then load it
        # with the 'json' loader as in the jsonl branch.
        raise NotImplementedError(
            "Fireworks dataset loading (source_type='fireworks') is not yet implemented. "
            "If you have a JSONL file from Fireworks, use source_type='jsonl'."
        )
    else:
        raise ValueError(f"Unsupported source_type: '{source_type}'. Must be 'huggingface', 'jsonl', or 'fireworks'.")

    loaded_dataset = _truncate_to_max_samples(loaded_dataset, max_samples)

    if preprocessing_steps_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):
        loaded_dataset = _apply_preprocessing_steps(loaded_dataset, preprocessing_steps_from_kwargs)

    # Column mapping runs after preprocessing so it can rename columns that
    # preprocessors add (e.g. 'transformed_ground_truth').
    if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):
        logger.info(f"Applying column mapping (post-preprocessing): {column_mapping_from_kwargs}")
        if isinstance(loaded_dataset, Dataset):
            loaded_dataset = _rename_columns_checked(loaded_dataset, column_mapping_from_kwargs)
        else:
            for split_name in list(loaded_dataset.keys()):
                loaded_dataset[split_name] = _rename_columns_checked(
                    loaded_dataset[split_name], column_mapping_from_kwargs, split_name=split_name
                )

    return loaded_dataset
344
+
345
+
346
def apply_column_mapping(dataset: Dataset, column_mapping: Dict[str, str]) -> Dataset:
    """
    Rename columns of ``dataset`` according to ``column_mapping``.

    Args:
        dataset: Dataset whose columns should be renamed.
        column_mapping: Maps each desired (new) column name to the existing
            column it should replace. Entries whose existing name is ``None``
            or is not a column of ``dataset`` are ignored. If several new names
            point at the same existing column, the last one wins.

    Returns:
        A dataset with the applicable columns renamed, or ``dataset`` itself
        when no entry applies.
    """
    present = dataset.column_names
    # rename_columns wants {old_name: new_name}, so invert while filtering.
    renames = {
        existing: target
        for target, existing in column_mapping.items()
        if existing is not None and existing in present
    }
    return dataset.rename_columns(renames) if renames else dataset
367
+
368
+
369
def convert_to_evaluation_format(
    dataset: Dataset,
    system_prompt: Optional[str] = None,
    query_column: str = "query",
    ground_truth_column: str = "ground_truth",
) -> Dataset:
    """
    Convert a dataset to evaluation format with user_query and ground_truth_for_eval.

    Each row gains:
      - "user_query": the value of ``query_column`` ("" if absent);
      - "ground_truth_for_eval": the value of ``ground_truth_column`` ("" if absent);
      - "system_prompt": only when ``system_prompt`` is truthy, kept separate
        from the query rather than prepended;
      - "id": the existing "id" if present. Otherwise, when the query column
        exists, a short id derived deterministically from the query text.

    Existing columns are kept, because ``Dataset.map`` merges the returned dict
    into each row.

    Args:
        dataset: Input dataset.
        system_prompt: Optional system prompt stored alongside each query.
        query_column: Name of the query/question column.
        ground_truth_column: Name of the ground truth/answer column.

    Returns:
        Dataset in evaluation format.
    """
    import hashlib  # Only needed for the fallback id; kept local to this function.

    def transform_example(example):
        result = {
            "user_query": example.get(query_column, ""),
            "ground_truth_for_eval": example.get(ground_truth_column, ""),
        }
        if system_prompt:
            result["system_prompt"] = system_prompt

        if "id" in example:
            result["id"] = example["id"]
        elif query_column in example:
            # Built-in hash() of a str is salted per process (PYTHONHASHSEED),
            # so it produced different ids on every run and failed on unhashable
            # values. A content digest is stable across runs and machines.
            digest = hashlib.sha256(str(example[query_column]).encode("utf-8")).hexdigest()
            result["id"] = digest[:12]

        return result

    return dataset.map(transform_example)
410
+
411
+
412
def _compose_dataset_config_from_project_dir(dataset_name: str) -> DictConfig:
    """
    Compose "dataset/<dataset_name>" from the "../../conf" dir next to this module.

    Raises:
        FileNotFoundError: If that config directory does not exist.
    """
    config_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../conf"))
    if not os.path.exists(config_dir):
        raise FileNotFoundError(f"Config directory not found: {config_dir}")
    with initialize_config_dir(config_dir=config_dir, version_base="1.3"):
        return compose(config_name=f"dataset/{dataset_name}")


def _compose_base_dataset_config(dataset_name: str) -> DictConfig:
    """
    Compose the Hydra config "dataset/<dataset_name>".

    If Hydra is already initialized, the active context is tried first and the
    project conf directory is used as a fallback. Otherwise the project conf
    directory is used directly.

    Raises:
        ValueError: Wrapping any failure; the original error is chained.
    """
    try:
        from hydra.core.global_hydra import GlobalHydra

        if GlobalHydra.instance().is_initialized():
            try:
                return compose(config_name=f"dataset/{dataset_name}")
            except Exception:
                # NOTE(review): initialize_config_dir may refuse to run while
                # GlobalHydra is already initialized, in which case this fallback
                # also fails (surfacing as the ValueError below) — confirm.
                return _compose_dataset_config_from_project_dir(dataset_name)
        return _compose_dataset_config_from_project_dir(dataset_name)
    except Exception as e:
        raise ValueError(f"Failed to load base dataset config '{dataset_name}': {e}") from e


def load_derived_dataset(
    base_dataset: Union[str, DictConfig],
    system_prompt: Optional[str] = None,
    output_format: str = "evaluation_format",
    transformations: Optional[List[str]] = None,
    derived_column_mapping: Optional[Dict[str, str]] = None,
    derived_max_samples: Optional[int] = None,
    **kwargs: Any,
) -> Dataset:
    """
    Load a derived dataset that references a base dataset and applies transformations.

    Steps:
      1. Instantiate the base dataset from its Hydra config.
      2. Reduce it to a single split ("train" if present, else the first split).
      3. Apply ``derived_column_mapping``.
      4. Truncate to ``derived_max_samples``.
      5. Convert to ``output_format``.

    Args:
        base_dataset: Name of a config under the "dataset" group, or a config
            object that Hydra can ``instantiate``.
        system_prompt: Optional system prompt stored with each query
            ("evaluation_format" only).
        output_format: "evaluation_format" (convert) or "jsonl" (keep as-is).
            "conversation_format" is not implemented yet.
        transformations: Extra transformations. Not implemented yet; any
            non-empty value raises.
        derived_column_mapping: ``{new_name: old_name}`` renames for the derived dataset.
        derived_max_samples: Maximum number of rows to keep (``None``/``<= 0`` keeps all).
        kwargs: Accepted for config compatibility; unused.

    Returns:
        The transformed Dataset.

    Raises:
        ValueError: Bad ``base_dataset`` type, unsupported ``output_format``, a
            base config that cannot be composed, or an empty base DatasetDict.
        NotImplementedError: For "conversation_format" or non-empty ``transformations``.
    """
    # Validate cheap arguments before loading the (potentially large) base dataset.
    if not isinstance(base_dataset, (str, DictConfig)):
        raise ValueError(f"base_dataset must be a string or DictConfig, got {type(base_dataset)}")
    if output_format == "conversation_format":
        # TODO: Implement conversation format conversion if needed
        raise NotImplementedError("conversation_format not yet implemented")
    if output_format not in ("evaluation_format", "jsonl"):
        raise ValueError(f"Unsupported output_format: {output_format}")
    # TODO: Apply additional transformations if specified
    if transformations:
        raise NotImplementedError("Custom transformations not yet implemented")

    if isinstance(base_dataset, str):
        base_cfg = _compose_base_dataset_config(base_dataset)
        # compose() of a group config nests it under a 'dataset' key.
        base_dataset_cfg = base_cfg.dataset if "dataset" in base_cfg else base_cfg
        base_loaded_dataset = instantiate(base_dataset_cfg)
    else:
        base_loaded_dataset = instantiate(base_dataset)

    # Reduce a DatasetDict to one split: prefer 'train', else the first available.
    if isinstance(base_loaded_dataset, DatasetDict):
        if "train" in base_loaded_dataset:
            dataset = base_loaded_dataset["train"]
        elif len(base_loaded_dataset) > 0:
            dataset = next(iter(base_loaded_dataset.values()))
        else:
            raise ValueError("Base dataset is an empty DatasetDict; there is no split to derive from.")
    else:
        dataset = base_loaded_dataset

    if derived_column_mapping:
        dataset = apply_column_mapping(dataset, derived_column_mapping)

    if derived_max_samples is not None and derived_max_samples > 0 and len(dataset) > derived_max_samples:
        dataset = dataset.select(range(derived_max_samples))

    if output_format == "evaluation_format":
        dataset = convert_to_evaluation_format(
            dataset,
            system_prompt=system_prompt,
            query_column="query",
            ground_truth_column="ground_truth",
        )
    # output_format == "jsonl": keep as-is, already in a compatible format.

    return dataset