eval-protocol 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- development/__init__.py +1 -0
- development/normalize_sandbox_fusion.py +628 -0
- development/utils/__init__.py +1 -0
- development/utils/generate_api_key.py +31 -0
- development/utils/subprocess_manager.py +481 -0
- eval_protocol/__init__.py +86 -0
- eval_protocol/__main__.py +10 -0
- eval_protocol/_version.py +21 -0
- eval_protocol/adapters/__init__.py +1 -0
- eval_protocol/adapters/braintrust.py +8 -0
- eval_protocol/adapters/trl.py +8 -0
- eval_protocol/agent/__init__.py +29 -0
- eval_protocol/agent/models.py +69 -0
- eval_protocol/agent/orchestrator.py +893 -0
- eval_protocol/agent/resource_abc.py +89 -0
- eval_protocol/agent/resource_pool.py +184 -0
- eval_protocol/agent/resources/__init__.py +44 -0
- eval_protocol/agent/resources/bfcl_envs/__init__.py +1 -0
- eval_protocol/agent/resources/bfcl_envs/gorilla_file_system.py +342 -0
- eval_protocol/agent/resources/bfcl_envs/math_api.py +40 -0
- eval_protocol/agent/resources/bfcl_envs/posting_api.py +157 -0
- eval_protocol/agent/resources/bfcl_sim_api_resource.py +314 -0
- eval_protocol/agent/resources/docker_resource.py +479 -0
- eval_protocol/agent/resources/filesystem_resource.py +371 -0
- eval_protocol/agent/resources/http_rollout_protocol.py +85 -0
- eval_protocol/agent/resources/http_rollout_resource.py +325 -0
- eval_protocol/agent/resources/python_state_resource.py +170 -0
- eval_protocol/agent/resources/sql_resource.py +271 -0
- eval_protocol/agent/task_manager.py +1064 -0
- eval_protocol/agent/tool_registry.py +111 -0
- eval_protocol/auth.py +156 -0
- eval_protocol/cli.py +425 -0
- eval_protocol/cli_commands/__init__.py +1 -0
- eval_protocol/cli_commands/agent_eval_cmd.py +264 -0
- eval_protocol/cli_commands/common.py +242 -0
- eval_protocol/cli_commands/deploy.py +486 -0
- eval_protocol/cli_commands/deploy_mcp.py +287 -0
- eval_protocol/cli_commands/preview.py +186 -0
- eval_protocol/cli_commands/run_eval_cmd.py +202 -0
- eval_protocol/common_utils.py +36 -0
- eval_protocol/config.py +180 -0
- eval_protocol/datasets/__init__.py +1 -0
- eval_protocol/datasets/loader.py +521 -0
- eval_protocol/evaluation.py +1045 -0
- eval_protocol/execution/__init__.py +1 -0
- eval_protocol/execution/pipeline.py +920 -0
- eval_protocol/gcp_tools.py +484 -0
- eval_protocol/generation/cache.py +141 -0
- eval_protocol/generation/clients/base.py +67 -0
- eval_protocol/generation/clients.py +248 -0
- eval_protocol/generic_server.py +165 -0
- eval_protocol/integrations/__init__.py +12 -0
- eval_protocol/integrations/braintrust.py +51 -0
- eval_protocol/integrations/deepeval.py +106 -0
- eval_protocol/integrations/openeval.py +40 -0
- eval_protocol/integrations/trl.py +187 -0
- eval_protocol/mcp/__init__.py +48 -0
- eval_protocol/mcp/adapter.py +131 -0
- eval_protocol/mcp/client/__init__.py +12 -0
- eval_protocol/mcp/client/connection.py +499 -0
- eval_protocol/mcp/clients.py +195 -0
- eval_protocol/mcp/execution/__init__.py +23 -0
- eval_protocol/mcp/execution/base_policy.py +227 -0
- eval_protocol/mcp/execution/fireworks_policy.py +209 -0
- eval_protocol/mcp/execution/manager.py +506 -0
- eval_protocol/mcp/execution/policy.py +421 -0
- eval_protocol/mcp/grid_renderer.py +54 -0
- eval_protocol/mcp/mcpgym.py +637 -0
- eval_protocol/mcp/process_manager.py +177 -0
- eval_protocol/mcp/session/__init__.py +11 -0
- eval_protocol/mcp/session/manager.py +228 -0
- eval_protocol/mcp/simple_process_manager.py +291 -0
- eval_protocol/mcp/simulation_server.py +458 -0
- eval_protocol/mcp/types.py +80 -0
- eval_protocol/mcp_agent/__init__.py +1 -0
- eval_protocol/mcp_agent/config.py +147 -0
- eval_protocol/mcp_agent/intermediary_server.py +542 -0
- eval_protocol/mcp_agent/main.py +210 -0
- eval_protocol/mcp_agent/orchestration/__init__.py +1 -0
- eval_protocol/mcp_agent/orchestration/base_client.py +132 -0
- eval_protocol/mcp_agent/orchestration/local_docker_client.py +702 -0
- eval_protocol/mcp_agent/orchestration/remote_http_client.py +304 -0
- eval_protocol/mcp_agent/orchestration/stdio_mcp_client_helper.py +3 -0
- eval_protocol/mcp_agent/session.py +79 -0
- eval_protocol/mcp_env.py +304 -0
- eval_protocol/models.py +366 -0
- eval_protocol/packaging.py +219 -0
- eval_protocol/platform_api.py +360 -0
- eval_protocol/playback_policy.py +396 -0
- eval_protocol/resources.py +128 -0
- eval_protocol/reward_function.py +410 -0
- eval_protocol/rewards/__init__.py +94 -0
- eval_protocol/rewards/accuracy.py +454 -0
- eval_protocol/rewards/accuracy_length.py +173 -0
- eval_protocol/rewards/apps_coding_reward.py +331 -0
- eval_protocol/rewards/apps_execution_utils.py +149 -0
- eval_protocol/rewards/apps_testing_util.py +559 -0
- eval_protocol/rewards/bfcl_reward.py +313 -0
- eval_protocol/rewards/code_execution.py +1620 -0
- eval_protocol/rewards/code_execution_utils.py +72 -0
- eval_protocol/rewards/cpp_code.py +861 -0
- eval_protocol/rewards/deepcoder_reward.py +161 -0
- eval_protocol/rewards/format.py +129 -0
- eval_protocol/rewards/function_calling.py +541 -0
- eval_protocol/rewards/json_schema.py +422 -0
- eval_protocol/rewards/language_consistency.py +700 -0
- eval_protocol/rewards/lean_prover.py +479 -0
- eval_protocol/rewards/length.py +375 -0
- eval_protocol/rewards/list_comparison_math_reward.py +221 -0
- eval_protocol/rewards/math.py +762 -0
- eval_protocol/rewards/multiple_choice_math_reward.py +232 -0
- eval_protocol/rewards/reasoning_steps.py +249 -0
- eval_protocol/rewards/repetition.py +342 -0
- eval_protocol/rewards/tag_count.py +162 -0
- eval_protocol/rl_processing.py +82 -0
- eval_protocol/server.py +271 -0
- eval_protocol/typed_interface.py +260 -0
- eval_protocol/utils/__init__.py +8 -0
- eval_protocol/utils/batch_evaluation.py +217 -0
- eval_protocol/utils/batch_transformation.py +205 -0
- eval_protocol/utils/dataset_helpers.py +112 -0
- eval_protocol/utils/module_loader.py +56 -0
- eval_protocol/utils/packaging_utils.py +108 -0
- eval_protocol/utils/static_policy.py +305 -0
- eval_protocol-0.0.3.dist-info/METADATA +635 -0
- eval_protocol-0.0.3.dist-info/RECORD +130 -0
- eval_protocol-0.0.3.dist-info/WHEEL +5 -0
- eval_protocol-0.0.3.dist-info/entry_points.txt +4 -0
- eval_protocol-0.0.3.dist-info/licenses/LICENSE +201 -0
- eval_protocol-0.0.3.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hydra-based dataset loading and processing.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
|
+
|
|
10
|
+
import datasets
|
|
11
|
+
from datasets import Dataset, DatasetDict
|
|
12
|
+
from hydra import compose, initialize_config_dir
|
|
13
|
+
from hydra.utils import instantiate
|
|
14
|
+
from omegaconf import DictConfig, OmegaConf
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
import importlib # Added for dynamic function import
|
|
19
|
+
|
|
20
|
+
# Placeholder for Fireworks API client if needed in the future
|
|
21
|
+
# from ..fireworks_client import FireworksClient # Example
|
|
22
|
+
|
|
23
|
+
# --- Preprocessing Functions ---
|
|
24
|
+
# These can be moved to a separate processors.py if they grow numerous.
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def transform_codeparrot_apps_sample(example: Dict[str, Any]) -> Dict[str, Any]:
    """
    Add a ``transformed_ground_truth`` field to one codeparrot/apps sample.

    The new field is a JSON string of the form
    ``{"fn_name": ... (only if present), "inputs": [...], "outputs": [...]}``,
    built from the sample's ``fn_name`` and ``input_output`` fields, so that it is
    compatible with apps_coding_reward.

    Args:
        example: A single dataset row. It is mutated in place.

    Returns:
        The same ``example`` dict with ``transformed_ground_truth`` set.
        A missing, empty, unparsable or wrongly-shaped ``input_output`` never raises.
        It logs a warning and falls back to empty ``inputs``/``outputs`` lists.
    """
    problem_id = example.get("problem_id", "Unknown")
    gt_dict: Dict[str, Any] = {}

    # fn_name can be None or missing for some APPS problems (standard input based).
    if example.get("fn_name"):
        gt_dict["fn_name"] = example["fn_name"]

    inputs: Any = []
    outputs: Any = []
    raw_io = example.get("input_output")
    if not raw_io:
        logger.warning(f"Missing or empty input_output field for problem_id {problem_id}.")
    else:
        try:
            # Accept an already-decoded mapping as well as the usual JSON string.
            parsed_io = raw_io if isinstance(raw_io, dict) else json.loads(raw_io)
        except (json.JSONDecodeError, TypeError):
            logger.warning(
                f"Failed to parse input_output JSON for problem_id {problem_id}. "
                f"Content: {raw_io}"
            )
        else:
            # Valid JSON need not be an object (e.g. "[1, 2]"); treat that like missing keys
            # instead of crashing on .get().
            io_map = parsed_io if isinstance(parsed_io, dict) else {}
            inputs = io_map.get("inputs", [])
            outputs = io_map.get("outputs", [])
            if not isinstance(parsed_io, dict) or not isinstance(inputs, list) or not isinstance(outputs, list):
                logger.warning(
                    f"Parsed input_output for problem_id {problem_id} "
                    f"does not contain 'inputs'/'outputs' as lists. IO: {raw_io}"
                )
                # Fall back to empty lists for wrong types; apps_testing_util expects lists.
                inputs = inputs if isinstance(inputs, list) else []
                outputs = outputs if isinstance(outputs, list) else []

    gt_dict["inputs"] = inputs
    gt_dict["outputs"] = outputs
    example["transformed_ground_truth"] = json.dumps(gt_dict)
    return example
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# --- End Preprocessing Functions ---
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def load_jsonl_file(file_path: str) -> List[Dict[str, Any]]:
    """
    Load a JSONL file into a list of parsed records.

    Blank and whitespace-only lines, such as a trailing empty line at end of file,
    are skipped rather than treated as malformed JSON.

    Args:
        file_path: Path to the JSONL file, read as UTF-8.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If a non-blank line is not valid JSON. The message includes the
            1-based line number and the offending line, and the decode error is chained.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"JSONL file not found: {file_path}")

    records: List[Dict[str, Any]] = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e} on line {line_no}: {stripped}"
                ) from e
    return records
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _select_requested_split(
    loaded: Union[Dataset, DatasetDict, List[Dict[str, Any]]],
    split: Optional[str],
    source_label: str,
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
    """Narrow a DatasetDict to ``split`` when a split was requested; otherwise return ``loaded`` unchanged."""
    if not split or not isinstance(loaded, DatasetDict):
        return loaded
    if split not in loaded:
        raise ValueError(
            f"Split '{split}' not found in loaded {source_label} DatasetDict. Available splits: {list(loaded.keys())}"
        )
    return loaded[split]


def _rename_columns_checked(
    dataset: Dataset,
    column_mapping: Dict[str, str],
    split_name: Optional[str] = None,
) -> Dataset:
    """Rename columns of one Dataset per a ``{new_name: old_name}`` mapping, skipping renames that would clash."""
    columns = list(dataset.column_names)

    # Invert to {old_name: new_name}. Keep only entries whose source column exists and whose
    # target is a non-empty name; datasets' rename_columns rejects empty new names.
    valid_mapping = {
        old_name: new_name
        for new_name, old_name in column_mapping.items()
        if old_name and new_name and old_name in columns
    }
    if not valid_mapping:
        if split_name is None:
            logger.warning(
                "Column mapping provided but resulted in no valid columns to rename (original columns not found or new names empty)."
            )
        else:
            logger.warning(f"Column mapping for split '{split_name}' resulted in no valid columns to rename.")
        return dataset

    # Renaming a column to its own name is a no-op.
    final_mapping = {old_name: new_name for old_name, new_name in valid_mapping.items() if old_name != new_name}

    # A rename clashes when its target is held by a column that keeps its name.
    # Targets of columns that are themselves renamed away are fine (chains such as b->a, a->c).
    # Dropping a rename keeps its source column in place, which can expose another clash, so repeat.
    while True:
        kept_columns = {c for c in columns if c not in final_mapping}
        clashes = [old_name for old_name, new_name in final_mapping.items() if new_name in kept_columns]
        if not clashes:
            break
        for old_name in clashes:
            new_name = final_mapping.pop(old_name)
            if split_name is None:
                logger.warning(
                    f"Attempting to map column '{old_name}' to '{new_name}', but '{new_name}' already exists and is not '{old_name}'. This may lead to unexpected behavior or errors. Skipping this specific rename."
                )
            else:
                logger.warning(
                    f"For split '{split_name}', attempting to map column '{old_name}' to '{new_name}', but '{new_name}' already exists and is not '{old_name}'. Skipping this specific rename for the split."
                )

    if not final_mapping:
        if split_name is None:
            logger.info("Column mapping resulted in no columns to rename after validation.")
        else:
            logger.info(f"Column mapping for split '{split_name}' resulted in no columns to rename after validation.")
        return dataset

    return dataset.rename_columns(final_mapping)


def load_and_process_dataset(
    source_type: str,
    path_or_name: str,
    split: Optional[str] = None,
    config_name: Optional[str] = None,
    data_files: Optional[Union[str, List[str], Dict[str, Union[str, List[str]]]]] = None,
    max_samples: Optional[int] = None,
    hf_extra_load_params: Optional[Dict[str, Any]] = None,
    **kwargs: Any,  # Catch-all for other params
) -> Union[Dataset, DatasetDict, List[Dict[str, Any]]]:
    """
    Load a dataset from the specified source and apply optional post-processing.

    Processing order: load, then select the requested split, then truncate to
    ``max_samples``, then run preprocessing steps, then apply the column mapping.

    Args:
        source_type: Type of dataset source ("huggingface", "jsonl", "fireworks").
            "fireworks" is not implemented yet.
        path_or_name: Path to a .jsonl file, or a Hugging Face dataset name/ID.
        split: Dataset split (e.g. "train", "test"). If loading produces a DatasetDict
            that contains this split, only that split is returned.
        config_name: Hugging Face dataset configuration, passed as ``name``. Used only
            for ``source_type="huggingface"``.
        data_files: Local data file(s) for ``datasets.load_dataset``. Accepts a path, a
            list of paths, or a mapping of split name to path(s).
        max_samples: If positive, keep at most this many rows (per split for a DatasetDict).
        hf_extra_load_params: Extra keyword arguments for ``datasets.load_dataset()``.
        kwargs: Loader options are consumed here and not passed on:
            - ``column_mapping``: ``{new_name: old_name}``.
            - ``preprocessing_steps``: dotted ``"pkg.module.func"`` paths of per-example
              functions, applied with ``.map``.
            - ``description``: logged only.
            Known eval-protocol metadata keys are discarded. Everything else is
            forwarded to ``datasets.load_dataset()``.

    Returns:
        A Hugging Face Dataset. A DatasetDict is returned when no split was requested
        and the loader produced several splits.

    Raises:
        ValueError: Unsupported ``source_type``, invalid jsonl inputs, or a requested
            split missing from the loaded DatasetDict.
        NotImplementedError: For ``source_type="fireworks"``.
        Exception: Any error while importing or applying a preprocessing step is
            logged and re-raised.
    """
    loaded_dataset: Union[Dataset, DatasetDict, List[Dict[str, Any]]]

    # Copy so the caller's hf_extra_load_params is never mutated.
    load_kwargs_for_hf = hf_extra_load_params.copy() if hf_extra_load_params else {}

    # Pop loader-level options from kwargs before the remainder is forwarded to HF.
    column_mapping_from_kwargs = kwargs.pop("column_mapping", None)
    preprocessing_steps_from_kwargs = kwargs.pop("preprocessing_steps", None)
    dataset_description = kwargs.pop("description", "No description provided.")

    # Drop eval-protocol metadata fields that datasets.load_dataset does not understand.
    eval_protocol_specific_keys = [
        "dataset_name",
        "pretty_name",
        "final_columns",
        "column_transformations",
        "output_columns_creation",
        "preprocess_functions",
        "postprocess_functions",
        "_target_",
        "dataset_type",
    ]
    for key in eval_protocol_specific_keys:
        if key in kwargs:
            logger.debug(f"Filtering out reward-kit specific config key: {key}")
            kwargs.pop(key, None)

    logger.info(f"Dataset description: {dataset_description}")

    load_kwargs_for_hf.update(kwargs)  # Remaining kwargs are genuine load_dataset params.

    # trust_remote_code is not passed explicitly; it is expected to come from the
    # HF_DATASETS_TRUST_REMOTE_CODE environment variable.
    if source_type == "huggingface":
        if config_name:
            load_kwargs_for_hf["name"] = config_name
        # If data_files maps splits to files and already contains the requested split,
        # load everything and select the split afterwards; otherwise let HF select it.
        if split and not (isinstance(data_files, dict) and split in data_files):
            load_kwargs_for_hf["split"] = split
        loaded_dataset = datasets.load_dataset(
            path_or_name,
            data_files=data_files,
            **load_kwargs_for_hf,
        )
        # Honour the requested split even when it was not passed to load_dataset, as the
        # jsonl path does (previously a full DatasetDict leaked through here).
        loaded_dataset = _select_requested_split(loaded_dataset, split, "huggingface")

    elif source_type == "jsonl":
        # Use HF's 'json' loader. path_or_name may be a single .jsonl file; data_files
        # covers multi-file / multi-split setups.
        effective_data_files = data_files
        if not effective_data_files and path_or_name:
            if not path_or_name.endswith(".jsonl"):
                raise ValueError(
                    f"For source_type 'jsonl' without 'data_files', 'path_or_name' must be a .jsonl file. Got: {path_or_name}"
                )
            # A single file is exposed under the requested split name, or 'train' by default.
            effective_data_files = {split if split else "train": path_or_name}

        if not effective_data_files:
            raise ValueError(
                "For source_type 'jsonl', either 'path_or_name' to a .jsonl file or 'data_files' must be provided."
            )

        # With a split->files dict HF returns a DatasetDict and the split is selected below;
        # with a single path/list the split is passed straight to load_dataset.
        hf_split_param = None if (isinstance(effective_data_files, dict) and split) else split
        # Set via the kwargs dict so a 'split' key inside hf_extra_load_params cannot
        # collide with an explicit split= keyword argument.
        load_kwargs_for_hf["split"] = hf_split_param
        loaded_dataset = datasets.load_dataset(
            "json",
            data_files=effective_data_files,
            **load_kwargs_for_hf,
        )

        if split:
            loaded_dataset = _select_requested_split(loaded_dataset, split, "jsonl")
        elif isinstance(loaded_dataset, DatasetDict):
            logger.info(
                f"Loaded multiple splits from JSONL: {list(loaded_dataset.keys())}. No specific split requested via 'split' arg."
            )

    elif source_type == "fireworks":
        # Placeholder: would download a JSONL from Fireworks by dataset ID, then load it.
        raise NotImplementedError(
            "Fireworks dataset loading (source_type='fireworks') is not yet implemented. "
            "If you have a JSONL file from Fireworks, use source_type='jsonl'."
        )
    else:
        raise ValueError(f"Unsupported source_type: '{source_type}'. Must be 'huggingface', 'jsonl', or 'fireworks'.")

    if max_samples is not None and max_samples > 0:
        if isinstance(loaded_dataset, Dataset):
            if len(loaded_dataset) > max_samples:
                loaded_dataset = loaded_dataset.select(range(max_samples))
        elif isinstance(loaded_dataset, DatasetDict):
            for s_name in list(loaded_dataset.keys()):
                if len(loaded_dataset[s_name]) > max_samples:
                    loaded_dataset[s_name] = loaded_dataset[s_name].select(range(max_samples))
        elif isinstance(loaded_dataset, list):
            # Defensive only: the loaders above always return HF objects.
            loaded_dataset = loaded_dataset[:max_samples]

    # Preprocessing runs before column mapping, so preprocessors see the original column
    # names and any columns they add (e.g. 'transformed_ground_truth') can then be mapped.
    if preprocessing_steps_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):
        logger.info(f"Applying preprocessing steps: {preprocessing_steps_from_kwargs}")
        for step_path in preprocessing_steps_from_kwargs:
            try:
                module_path, func_name = step_path.rsplit(".", 1)
                preprocessor_func = getattr(importlib.import_module(module_path), func_name)
                if isinstance(loaded_dataset, Dataset):
                    loaded_dataset = loaded_dataset.map(preprocessor_func)
                else:
                    for s_name in list(loaded_dataset.keys()):
                        logger.info(f"Applying preprocessor {func_name} to split '{s_name}'")
                        loaded_dataset[s_name] = loaded_dataset[s_name].map(preprocessor_func)
                logger.info(f"Successfully applied preprocessor: {step_path}")
            except Exception as e:
                logger.error(
                    f"Failed to apply preprocessing step {step_path}: {e}",
                    exc_info=True,
                )
                raise  # A failing preprocessor halts loading.

    if column_mapping_from_kwargs and isinstance(loaded_dataset, (Dataset, DatasetDict)):
        logger.info(f"Applying column mapping (post-preprocessing): {column_mapping_from_kwargs}")
        if isinstance(loaded_dataset, Dataset):
            loaded_dataset = _rename_columns_checked(loaded_dataset, column_mapping_from_kwargs)
        else:
            for s_name in list(loaded_dataset.keys()):
                loaded_dataset[s_name] = _rename_columns_checked(
                    loaded_dataset[s_name], column_mapping_from_kwargs, split_name=s_name
                )

    return loaded_dataset
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def apply_column_mapping(dataset: Dataset, column_mapping: Dict[str, str]) -> Dataset:
    """
    Rename dataset columns according to a ``{new_name: existing_name}`` mapping.

    An entry is ignored when:
    - its existing name is None or not a column of ``dataset``;
    - its new name is empty;
    - the column already has the requested name.

    A rename whose target is already taken by a column that keeps its name is
    skipped with a warning. Without this check, ``rename_columns`` would produce
    duplicate column names. Chained renames (e.g. b->a together with a->c) are allowed.

    Args:
        dataset: The dataset whose columns are renamed.
        column_mapping: Dict mapping new names to existing column names.

    Returns:
        Dataset with renamed columns. This is the input object itself if nothing
        was renamed.
    """
    columns = list(dataset.column_names)

    # Invert to {old_name: new_name}, dropping null/missing sources, empty targets and no-op renames.
    rename_mapping: Dict[str, str] = {}
    for new_name, old_name in column_mapping.items():
        if old_name is None or not new_name or old_name not in columns or old_name == new_name:
            continue
        rename_mapping[old_name] = new_name

    # Drop renames whose target is held by a column that keeps its name. Dropping one keeps
    # its source column in place, which can expose another clash, so repeat until stable.
    while True:
        kept_columns = {c for c in columns if c not in rename_mapping}
        clashes = [old_name for old_name, new_name in rename_mapping.items() if new_name in kept_columns]
        if not clashes:
            break
        for old_name in clashes:
            new_name = rename_mapping.pop(old_name)
            logger.warning(
                f"Cannot rename column '{old_name}' to '{new_name}': '{new_name}' already exists. Skipping this rename."
            )

    if rename_mapping:
        dataset = dataset.rename_columns(rename_mapping)

    return dataset
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def convert_to_evaluation_format(
    dataset: Dataset,
    system_prompt: Optional[str] = None,
    query_column: str = "query",
    ground_truth_column: str = "ground_truth",
) -> Dataset:
    """
    Convert a dataset to evaluation format with ``user_query`` and ``ground_truth_for_eval``.

    Each row gains:
    - ``user_query``: the value of ``query_column``, or "" if that column is missing;
    - ``ground_truth_for_eval``: the value of ``ground_truth_column``, or "" if missing;
    - ``system_prompt``: only when ``system_prompt`` is truthy;
    - ``id``: the existing ``id`` value if present. Otherwise, when the query column
      exists, a stable id derived from the query text.

    Args:
        dataset: Input dataset.
        system_prompt: Optional system prompt stored alongside each query (not prepended).
        query_column: Name of the query/question column.
        ground_truth_column: Name of the ground truth/answer column.

    Returns:
        The result of ``dataset.map`` with the fields above added.
    """
    import hashlib

    def transform_example(example):
        # Keep the user query separate from the system prompt.
        user_query = example.get(query_column, "")
        ground_truth = example.get(ground_truth_column, "")

        result = {"user_query": user_query, "ground_truth_for_eval": ground_truth}
        if system_prompt:
            result["system_prompt"] = system_prompt

        if "id" in example:
            result["id"] = example["id"]
        elif query_column in example:
            # Derive a deterministic id from the query text. The built-in hash() is salted
            # per process (PYTHONHASHSEED), so it gave different ids on every run and in
            # every map worker. It also failed on unhashable query values.
            digest = hashlib.sha256(str(example[query_column]).encode("utf-8")).hexdigest()
            result["id"] = digest[:16]

        return result

    return dataset.map(transform_example)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def _compose_base_dataset_config(base_dataset: str) -> DictConfig:
    """Compose the Hydra config ``dataset/<base_dataset>``, falling back to the project ``conf`` directory."""
    from hydra.core.global_hydra import GlobalHydra

    config_name = f"dataset/{base_dataset}"
    config_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../conf"))

    def _compose_from_project_conf() -> DictConfig:
        # Compose from the conf directory two levels above this module.
        if not os.path.exists(config_dir):
            raise FileNotFoundError(f"Config directory not found: {config_dir}")
        with initialize_config_dir(config_dir=config_dir, version_base="1.3"):
            return compose(config_name=config_name)

    if GlobalHydra.instance().is_initialized():
        # Prefer the config search path of the already-running Hydra app.
        try:
            return compose(config_name=config_name)
        except Exception:
            # NOTE(review): initialize_config_dir may refuse to run while GlobalHydra is
            # already initialized; confirm this fallback actually succeeds in that state.
            return _compose_from_project_conf()
    return _compose_from_project_conf()


def load_derived_dataset(
    base_dataset: Union[str, DictConfig],
    system_prompt: Optional[str] = None,
    output_format: str = "evaluation_format",
    transformations: Optional[List[str]] = None,
    derived_column_mapping: Optional[Dict[str, str]] = None,
    derived_max_samples: Optional[int] = None,
    **kwargs: Any,
) -> Dataset:
    """
    Load a derived dataset that references a base dataset and applies transformations.

    Steps:
    1. Load the base dataset via Hydra ``instantiate``.
    2. If the result is a DatasetDict, pick its 'train' split, or else its first split.
    3. Apply ``derived_column_mapping`` and ``derived_max_samples``.
    4. Convert to ``output_format``.

    Args:
        base_dataset: Either a dataset config name, resolved as Hydra config
            ``dataset/<name>``, or a DictConfig to instantiate directly.
        system_prompt: Optional system prompt to add to queries (evaluation_format only).
        output_format: "evaluation_format", or "jsonl" (kept as-is).
            "conversation_format" is not implemented yet.
        transformations: Additional transformations. Not implemented; must be empty.
        derived_column_mapping: Column mapping ``{new_name: old_name}`` for the derived dataset.
        derived_max_samples: If positive, keep at most this many rows.
        kwargs: Additional arguments. Currently ignored (logged at debug level).

    Returns:
        Transformed dataset.

    Raises:
        ValueError: Unsupported ``output_format``, a base config that cannot be
            composed (original error chained), an unsupported ``base_dataset`` type,
            or a base DatasetDict with no splits.
        NotImplementedError: ``output_format="conversation_format"`` or non-empty
            ``transformations``. These are checked before the base dataset is loaded.
    """
    # Validate cheap options first so unsupported requests fail before any expensive load.
    if output_format not in ("evaluation_format", "conversation_format", "jsonl"):
        raise ValueError(f"Unsupported output_format: {output_format}")
    if output_format == "conversation_format":
        # TODO: Implement conversation format conversion if needed
        raise NotImplementedError("conversation_format not yet implemented")
    if transformations:
        # TODO: Apply additional transformations if specified
        raise NotImplementedError("Custom transformations not yet implemented")
    if kwargs:
        logger.debug(f"Ignoring unused arguments for derived dataset: {sorted(kwargs)}")

    if isinstance(base_dataset, str):
        try:
            base_cfg = _compose_base_dataset_config(base_dataset)
        except Exception as e:
            raise ValueError(f"Failed to load base dataset config '{base_dataset}': {e}") from e
        # compose() nests the dataset under a 'dataset' key when it returns a full config.
        base_dataset_cfg = base_cfg.dataset if "dataset" in base_cfg else base_cfg
        base_loaded_dataset = instantiate(base_dataset_cfg)
    elif isinstance(base_dataset, DictConfig):
        base_loaded_dataset = instantiate(base_dataset)
    else:
        raise ValueError(f"base_dataset must be a string or DictConfig, got {type(base_dataset)}")

    # Reduce a DatasetDict to a single Dataset: 'train' if present, else the first split.
    if isinstance(base_loaded_dataset, DatasetDict):
        if "train" in base_loaded_dataset:
            dataset = base_loaded_dataset["train"]
        else:
            dataset = next(iter(base_loaded_dataset.values()), None)
            if dataset is None:
                raise ValueError("Base dataset loaded as an empty DatasetDict (no splits available).")
    else:
        dataset = base_loaded_dataset

    if derived_column_mapping:
        dataset = apply_column_mapping(dataset, derived_column_mapping)

    if derived_max_samples is not None and derived_max_samples > 0 and len(dataset) > derived_max_samples:
        dataset = dataset.select(range(derived_max_samples))

    if output_format == "evaluation_format":
        dataset = convert_to_evaluation_format(
            dataset,
            system_prompt=system_prompt,
            query_column="query",
            ground_truth_column="ground_truth",
        )
    # output_format == "jsonl": keep as-is, already in a compatible format.

    return dataset
|