camel-ai 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +151 -4
- camel/datasets/__init__.py +2 -6
- camel/datasets/base.py +269 -801
- camel/environments/base.py +3 -5
- {camel_ai-0.2.28.dist-info → camel_ai-0.2.29.dist-info}/METADATA +2 -2
- {camel_ai-0.2.28.dist-info → camel_ai-0.2.29.dist-info}/RECORD +9 -9
- {camel_ai-0.2.28.dist-info → camel_ai-0.2.29.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.28.dist-info → camel_ai-0.2.29.dist-info}/licenses/LICENSE +0 -0
camel/datasets/base.py
CHANGED
@@ -13,28 +13,26 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 
 import json
-import os
 import random
+from datetime import datetime
 from pathlib import Path
 from typing import (
     Any,
-    Callable,
     Dict,
     List,
     Optional,
     Sized,
-    TypeVar,
     Union,
 )
 
-import torch
 from datasets import Dataset as HFDataset
 from pydantic import BaseModel, Field, ValidationError
-from torch.utils.data import
+from torch.utils.data import Dataset
 
 from camel.agents import ChatAgent
 from camel.logger import get_logger
 from camel.verifiers import BaseVerifier
+from camel.verifiers.models import VerifierInput
 
 logger = get_logger(__name__)
 
@@ -44,13 +42,9 @@ class DataPoint(BaseModel):
 
     Attributes:
         question (str): The primary question or issue to be addressed.
-        rationale (str): Logical reasoning or explanation behind the
-            answer.
+        rationale (Optional[str]): Logical reasoning or explanation behind the
+            answer. (default: :obj:`None`)
         final_answer (str): The final answer.
-        raw_markdown (Optional[str]): Raw markdown content for generating
-            rewards/hints. (default: :obj:`None`)
-        difficulty (Optional[str]): Difficulty level of the question.
-            (default: :obj:`None`)
        metadata Optional[Dict[str, Any]]: Additional metadata about the data
            point. (default: :obj:`None`)
     """
@@ -58,13 +52,12 @@ class DataPoint(BaseModel):
     question: str = Field(
         ..., description="The primary question or issue to be addressed."
     )
-    rationale: str = Field(
-
+    rationale: Optional[str] = Field(
+        default=None,
+        description="Logical reasoning or explanation behind the answer.",
     )
     final_answer: str = Field(..., description="The final answer.")
-
-        None, description="Difficulty level of the question."
-    )
+
     metadata: Optional[Dict[str, Any]] = Field(
         default=None, description="Additional metadata about the data point."
     )
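Since `rationale` now defaults to `None`, a question/answer pair validates on its own. A minimal sketch of the new schema in use (assuming `DataPoint` is imported from `camel.datasets.base` as shown in this diff; whether unknown extra fields are rejected or silently ignored depends on the pydantic model config, which this diff does not show):

from camel.datasets.base import DataPoint

# `rationale` is now optional, so this validates.
dp = DataPoint(question="What is 2 + 2?", final_answer="4")
assert dp.rationale is None

# The removed `difficulty` and `raw_markdown` fields are no longer part of
# the schema -- store such extras in `metadata` instead.
dp2 = DataPoint(
    question="What is 2 + 2?",
    rationale="Add two and two.",
    final_answer="4",
    metadata={"difficulty": "easy"},
)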
@@ -90,244 +83,11 @@ class DataPoint(BaseModel):
         return cls(**data)
 
 
-class
-    r"""A dataset
-    It can be either static (e.g., MATH dataset) or generative
-    (using an LLM to generate questions).
-    """
-
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        cache_dir: Optional[str] = None,
-        **kwargs,
-    ):
-        r"""Initialize the dataset.
-
-        Args:
-            data (List[Dict[str, str]]): List of dictionary items to
-                create the dataset from.
-            cache_dir (Optional[str]): Directory to cache dataset files.
-                (default: :obj:`None`)
-            **kwargs: Additional dataset parameters.
-
-        Note:
-            The dataset must be initialized by calling setup() before use.
-        """
-        self._is_setup = False
-        self._raw_data: List[Dict[str, str]] = data if data is not None else []
-        self._cache_dir = str(cache_dir) if cache_dir is not None else None
-
-        # Store all parameters in metadata dict for compatibility
-        self._metadata = {
-            'cache_dir': self._cache_dir,
-            **kwargs,
-        }
-
-        self.data: List[DataPoint] = []  # Will be populated in setup()
-
-    async def setup(self) -> None:
-        r"""Set up the dataset with necessary resources.
-
-        This method:
-        1. Creates cache directory if needed
-        2. Processes raw data into DataPoint objects using vectorized
-           operations
-        3. Validates dataset integrity
-
-        Raises:
-            OSError: If cache directory creation fails.
-            Exception: If dataset initialization fails.
-        """
-        if self._is_setup:
-            logger.debug(f"{self.__class__.__name__} already initialized")
-            return
-
-        try:
-            # Create cache directory if specified
-            if self._cache_dir:
-                try:
-                    os.makedirs(self._cache_dir, exist_ok=True)
-                    logger.debug(f"Created cache directory: {self._cache_dir}")
-                except OSError as e:
-                    logger.error(
-                        f"Failed to create cache directory "
-                        f"{self._cache_dir}: {e}"
-                    )
-                    raise
-
-            # Process raw data into DataPoint objects using vectorized
-            # operations
-            if not self._raw_data:
-                self.data = []
-                logger.debug("No raw data to process")
-            else:
-                try:
-                    # Helper function for validation that can be used with map
-                    def create_datapoint(item, idx=None):
-                        try:
-                            return DataPoint(
-                                question=item.get('question', ''),
-                                rationale=item.get('rationale', ''),
-                                final_answer=item.get('final_answer', ''),
-                                metadata=item.get('metadata', {})
-                                if isinstance(item.get('metadata'), dict)
-                                else {},
-                                raw_markdown='',
-                                difficulty='',
-                            )
-                        except ValidationError as e:
-                            idx_str = (
-                                f" at index {idx}" if idx is not None else ""
-                            )
-                            error_msg = (
-                                f"Sample{idx_str} validation error: {e}"
-                            )
-                            logger.error(error_msg)
-                            raise ValueError(error_msg)
-
-                    # If raw_data is already a HF dataset, use its map function
-                    if hasattr(self._raw_data, 'map') and callable(
-                        self._raw_data.map
-                    ):
-                        # Using HF dataset's map for vectorized processing
-                        processed_data = self._raw_data.map(
-                            lambda example, idx: {
-                                'datapoint': create_datapoint(example, idx)
-                            },
-                            with_indices=True,
-                        )
-                        self.data = [
-                            item['datapoint'] for item in processed_data
-                        ]
-                    else:
-                        # Bulk create datapoints
-                        self.data = [
-                            create_datapoint(item, i)
-                            for i, item in enumerate(self._raw_data)
-                        ]
-
-                    logger.debug(f"Processed {len(self.data)} data points")
-                except Exception as e:
-                    logger.error(f"Error processing data: {e}")
-                    raise
-
-            self._is_setup = True
-            logger.info(f"{self.__class__.__name__} initialized successfully")
-
-        except Exception as e:
-            logger.error(f"Error during {self.__class__.__name__} setup: {e}")
-            await self.cleanup()
-            raise
-
-    async def cleanup(self) -> None:
-        r"""Clean up dataset resources.
-
-        This method handles cleanup of resources and resets the dataset state.
-        """
-        if not self._is_setup:
-            return
-
-        try:
-            # Clear metadata while preserving init config
-            init_config = {
-                'cache_dir': self._cache_dir,
-            }
-            self._metadata = init_config
-
-            logger.info(f"{self.__class__.__name__} cleaned up successfully")
-
-        except Exception as e:
-            logger.error(
-                f"Error during {self.__class__.__name__} cleanup: {e}"
-            )
-            raise
-
-        finally:
-            # Always mark as uninitialized, even if cleanup fails
-            self._is_setup = False
-
-    def sample(self) -> DataPoint:
-        r"""Sample a random datapoint from the dataset.
-
-        Returns:
-            DataPoint: A randomly sampled DataPoint.
-
-        Raises:
-            RuntimeError: If dataset is not initialized.
-        """
-        if not self._is_setup:
-            raise RuntimeError(
-                f"{self.__class__.__name__} must be initialized "
-                "before sampling"
-            )
-
-        idx = random.randint(0, len(self) - 1)
-        return self[idx]
-
-    def __len__(self) -> int:
-        r"""Return the size of the dataset."""
-        return len(self.data)
-
-    def __getitem__(self, idx: int) -> DataPoint:
-        r"""Get an item from the dataset.
-
-        Args:
-            idx (int): Index of the item to get.
-
-        Returns:
-            DataPoint: DataPoint from the dataset with the given index.
-
-        Raises:
-            IndexError: If idx is out of bounds.
-        """
-        if idx < 0 or idx >= len(self):
-            raise IndexError(
-                f"Index {idx} out of bounds for dataset of size {len(self)}"
-            )
-
-        return self.data[idx]
-
-    @property
-    def metadata(self) -> Dict[str, Any]:
-        r"""Get dataset metadata."""
-        return self._metadata.copy()
-
-    def to_pytorch_dataset(
-        self,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        batch_size: Optional[int] = None,
-    ) -> Union["PyTorchDataset", "DataLoader"]:
-        r"""Convert to a PyTorch dataset or DataLoader.
-
-        Args:
-            transform (Optional[Callable]): Transform to apply to samples.
-            target_transform (Optional[Callable]): Transform to apply to
-                targets.
-            batch_size (Optional[int]): If provided, returns a DataLoader with
-                the specified batch size instead of a PyTorchDataset.
-                (default: :obj:`None`)
-
-        Returns:
-            Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
-                PyTorch format or DataLoader if batch_size is provided.
-        """
-        dataset = PyTorchDataset.from_datapoints(
-            self.data,
-            transform=transform,
-            target_transform=target_transform,
-        )
-
-        if batch_size is not None:
-            return dataset.get_dataloader(batch_size=batch_size)
-
-        return dataset
-
-
-class SeedDataset(BaseDataset):
-    r"""A dataset containing validated seed examples for data generation.
+class StaticDataset(Dataset):
+    r"""A static dataset containing a list of datapoints.
     Ensures that all items adhere to the DataPoint schema.
+    This dataset extends :obj:`Dataset` from PyTorch and should
+    be used when its size is fixed at runtime.
 
     This class can initialize from Hugging Face Datasets,
     PyTorch Datasets, JSON file paths, or lists of dictionaries,
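Because `StaticDataset` now subclasses `torch.utils.data.Dataset` directly, it can be fed straight to a plain `DataLoader`. A minimal sketch under the constructor signature shown in this diff (the identity `collate_fn` is an assumption: `DataPoint` instances are pydantic models, not tensors, so the default collate would not stack them):

from torch.utils.data import DataLoader

from camel.datasets.base import StaticDataset

dataset = StaticDataset(
    data=[
        {"question": "What is 2 + 2?", "final_answer": "4"},
        {"question": "What is 3 * 3?", "final_answer": "9"},
    ],
    seed=42,        # sampling is reproducible via an internal random.Random
    min_samples=1,  # fewer validated rows than this raises ValueError
    strict=False,   # invalid rows are filtered instead of raising
)

# Identity collate_fn keeps DataPoint objects intact instead of stacking.
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
for batch in loader:
    print([dp.question for dp in batch])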
@@ -337,46 +97,46 @@ class SeedDataset(BaseDataset):
     def __init__(
         self,
         data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]],
-
-        seed: Optional[int] = None,
+        seed: int = 42,
         min_samples: int = 1,
         strict: bool = False,
         **kwargs,
     ):
-        r"""Initialize the
+        r"""Initialize the static dataset and validate integrity.
 
         Args:
-            data (Union[HFDataset, Dataset,
-
-
-                - A
-                - A
-                - A
-
-
-
-
-
-
+            data (:obj:`Union[HFDataset, Dataset,
+                Path, List[Dict[str, Any]]]`):
+                Input data, which can be one of the following:
+                - A Hugging Face Dataset (:obj:`HFDataset`).
+                - A PyTorch Dataset (:obj:`torch.utils.data.Dataset`).
+                - A :obj:`Path` object representing a JSON file.
+                - A list of dictionaries with :obj:`DataPoint`-compatible
+                  fields.
+            seed (:obj:`int`): Random seed for reproducibility.
+                Default is :obj:`42`.
+            min_samples (:obj:`int`): Minimum required number of samples.
+                Default is :obj:`1`.
+            strict (:obj:`bool`): Whether to raise an error on invalid
+                datapoints (:obj:`True`) or skip/filter them (:obj:`False`).
+                Default is :obj:`False`.
             **kwargs: Additional dataset parameters.
 
         Raises:
-            TypeError: If the data type is
-            ValueError: If dataset
-
-            FileNotFoundError: If the JSON file path
-            json.JSONDecodeError: If the JSON file
+            TypeError: If the input data type is unsupported.
+            ValueError: If the dataset contains fewer than :obj:`min_samples`
+                datapoints or if validation fails.
+            FileNotFoundError: If the specified JSON file path does not exist.
+            json.JSONDecodeError: If the JSON file contains invalid formatting.
         """
-        # Initialize BaseDataset with empty data, we'll populate it ourselves
-        super().__init__(data=[], cache_dir=cache_dir, **kwargs)
 
+        # Store all parameters in metadata dict for compatibility
+        self._metadata = {
+            **kwargs,
+        }
         self._rng = random.Random(seed)
         self._strict = strict
 
-        # Type checking and conversion into list of dicts to have a
-        # consistent internal format. Since Seed Dataset should be
-        # small, we can load it entirely into memory
-
         self.data: List[DataPoint] = self._init_data(data)
         self._length = len(self.data)
 
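The same constructor accepts all four input forms enumerated above. A hedged sketch (the file name is hypothetical; the JSON path form expects a file containing a top-level list of dicts, per the docstring in the diff):

from pathlib import Path

from datasets import Dataset as HFDataset

from camel.datasets.base import StaticDataset

records = [{"question": "1 + 1?", "final_answer": "2"}]

ds_from_list = StaticDataset(records)
ds_from_hf = StaticDataset(HFDataset.from_list(records))
ds_from_json = StaticDataset(Path("seed_data.json"))  # hypothetical file
# A torch.utils.data.Dataset yielding dicts also works (see the wrapper
# sketch further below).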
@@ -389,6 +149,26 @@ class SeedDataset(BaseDataset):
     def _init_data(
         self, data: Union[HFDataset, Dataset, Path, List[Dict[str, Any]]]
     ) -> List[DataPoint]:
+        r"""Convert input data from various formats into a list of
+        :obj:`DataPoint` instances.
+
+        Args:
+            data (:obj:`Union[
+                HFDataset,
+                Dataset,
+                Path,
+                List[Dict[str, Any]]
+            ]`):
+                Input dataset in one of the supported formats.
+
+        Returns:
+            :obj:`List[DataPoint]`: A list of validated :obj:`DataPoint`
+                instances.
+
+        Raises:
+            TypeError: If the input data type is unsupported.
+        """
+
         if isinstance(data, HFDataset):
             raw_data = self._init_from_hf_dataset(data)
         elif isinstance(data, Dataset):
@@ -450,7 +230,6 @@ class SeedDataset(BaseDataset):
                     rationale=rationale,
                     final_answer=final_answer,
                     metadata=item.get('metadata'),
-                    difficulty=item.get('difficulty'),
                 )
             except ValidationError as e:
                 if self._strict:
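In strict mode a row that fails `DataPoint` validation is escalated; otherwise it is dropped. A small sketch of the difference (the invalid row is hypothetical, and the exact exception type re-raised in strict mode is cut off in this diff, so the catch below is deliberately broad):

from camel.datasets.base import StaticDataset

rows = [
    {"question": "ok", "final_answer": "fine"},
    {"question": "no answer"},  # missing final_answer -> fails validation
]

lenient = StaticDataset(rows, strict=False)  # invalid row is filtered out
print(len(lenient))  # 1

try:
    StaticDataset(rows, strict=True)  # same row now raises
except Exception as e:  # diff truncates the exact exception re-raised
    print(f"rejected: {e}")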
@@ -474,17 +253,19 @@ class SeedDataset(BaseDataset):
         return self._length
 
     def __getitem__(self, idx: int) -> DataPoint:
-        r"""
+        r"""Retrieve a datapoint by index.
 
         Args:
-            idx (int): Index of the
+            idx (:obj:`int`): Index of the datapoint.
 
         Returns:
-
+            :obj:`DataPoint`: The datapoint corresponding to the given index.
 
         Raises:
-            IndexError: If idx is out of bounds
+            IndexError: If :obj:`idx` is out of bounds (negative or greater
+                than dataset length - 1).
         """
+
         if idx < 0 or idx >= self._length:
             raise IndexError(
                 f"Index {idx} out of bounds for dataset of size {self._length}"
@@ -495,11 +276,12 @@ class SeedDataset(BaseDataset):
         r"""Sample a random datapoint from the dataset.
 
         Returns:
-            DataPoint
+            :obj:`DataPoint`: A randomly sampled :obj:`DataPoint`.
 
         Raises:
-            RuntimeError: If the dataset is empty.
+            RuntimeError: If the dataset is empty and no samples can be drawn.
         """
+
         if self._length == 0:
             raise RuntimeError("Dataset is empty, cannot sample.")
         idx = self._rng.randint(0, self._length - 1)
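Sampling draws from the instance's own `random.Random(seed)` (set up in `__init__` above), so two datasets built with the same seed replay the same sequence without touching global random state. A minimal sketch:

from camel.datasets.base import StaticDataset

rows = [{"question": f"q{i}", "final_answer": str(i)} for i in range(10)]

a = StaticDataset(rows, seed=7)
b = StaticDataset(rows, seed=7)

# Same seed, same draw order.
assert [a.sample().question for _ in range(5)] == [
    b.sample().question for _ in range(5)
]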
@@ -507,15 +289,42 @@ class SeedDataset(BaseDataset):
 
     @property
     def metadata(self) -> Dict[str, Any]:
-        r"""
+        r"""Retrieve dataset metadata.
+
+        Returns:
+            :obj:`Dict[str, Any]`: A copy of the dataset metadata dictionary.
+        """
+
         return self._metadata.copy()
 
     def _init_from_hf_dataset(self, data: HFDataset) -> List[Dict[str, Any]]:
+        r"""Convert a Hugging Face dataset into a list of dictionaries.
+
+        Args:
+            data (:obj:`HFDataset`): A Hugging Face dataset.
+
+        Returns:
+            :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
+                the dataset, where each dictionary corresponds to a datapoint.
+        """
         return [dict(item) for item in data]
 
     def _init_from_pytorch_dataset(
         self, data: Dataset
     ) -> List[Dict[str, Any]]:
+        r"""Convert a PyTorch dataset into a list of dictionaries.
+
+        Args:
+            data (:obj:`Dataset`): A PyTorch dataset.
+
+        Returns:
+            :obj:`List[Dict[str, Any]]`: A list of dictionaries representing
+                the dataset.
+
+        Raises:
+            TypeError: If the dataset does not implement :obj:`__len__()`
+                or contains non-dictionary elements.
+        """
         if not isinstance(data, Sized):
             raise TypeError(
                 f"{type(data).__name__} does not implement `__len__()`."
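The PyTorch path requires a map-style dataset that is `Sized` and yields dicts. A hedged sketch of a compatible wrapper (the `DictDataset` class is hypothetical, written only to satisfy the constraints the diff states):

from torch.utils.data import Dataset

from camel.datasets.base import StaticDataset


class DictDataset(Dataset):
    """Hypothetical map-style dataset yielding DataPoint-shaped dicts."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):  # required: the loader checks for Sized
        return len(self.rows)

    def __getitem__(self, idx):  # each element must be a dict
        return self.rows[idx]


source = DictDataset([{"question": "1 + 1?", "final_answer": "2"}])
static = StaticDataset(source)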
@@ -533,6 +342,20 @@ class SeedDataset(BaseDataset):
         return raw_data
 
     def _init_from_json_path(self, data: Path) -> List[Dict[str, Any]]:
+        r"""Load and parse a dataset from a JSON file.
+
+        Args:
+            data (:obj:`Path`): Path to the JSON file.
+
+        Returns:
+            :obj:`List[Dict[str, Any]]`: A list of datapoint dictionaries.
+
+        Raises:
+            FileNotFoundError: If the specified JSON file does not exist.
+            ValueError: If the JSON content is not a list of dictionaries.
+            json.JSONDecodeError: If the JSON file has invalid formatting.
+        """
+
         if not data.exists():
             raise FileNotFoundError(f"JSON file not found: {data}")
         try:
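The JSON loader expects the file to hold a top-level list of dict records; anything else raises `ValueError`. A sketch of writing and loading such a file (the temp-file handling is illustrative):

import json
from pathlib import Path
from tempfile import TemporaryDirectory

from camel.datasets.base import StaticDataset

with TemporaryDirectory() as tmp:
    path = Path(tmp) / "seed.json"
    # Top-level JSON list of dicts, one per datapoint.
    path.write_text(json.dumps([
        {"question": "1 + 1?", "final_answer": "2"},
    ]))
    ds = StaticDataset(path)
    print(ds[0].final_answer)  # "2"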
@@ -557,6 +380,18 @@ class SeedDataset(BaseDataset):
     def _init_from_list(
         self, data: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
+        r"""Validate and convert a list of dictionaries into a dataset.
+
+        Args:
+            data (:obj:`List[Dict[str, Any]]`): A list of dictionaries where
+                each dictionary must be a valid :obj:`DataPoint`.
+
+        Returns:
+            :obj:`List[Dict[str, Any]]`: The validated list of dictionaries.
+
+        Raises:
+            ValueError: If any item in the list is not a dictionary.
+        """
         for i, item in enumerate(data):
             if not isinstance(item, dict):
                 raise ValueError(
@@ -566,111 +401,7 @@ class SeedDataset(BaseDataset):
         return data
 
 
-class
-    r"""A dataset for storing synthetically generated data points.
-
-    This class is used to store datapoints that are generated through
-    a generative process, such as using an agent.
-    """
-
-    def __init__(
-        self,
-        data: Optional[List[Dict[str, str]]] = None,
-        cache_dir: Optional[str] = None,
-        **kwargs,
-    ):
-        r"""Initialize the synthetic dataset.
-
-        Args:
-            data (Optional[List[Dict[str, str]]]): List of dictionary items to
-                create the dataset from. (default: :obj:`None`)
-            cache_dir (Optional[str]): Directory to cache dataset files.
-                (default: :obj:`None`)
-            **kwargs: Additional dataset parameters.
-        """
-        super().__init__(
-            data=data if data is not None else [],
-            cache_dir=cache_dir,
-            **kwargs,
-        )
-        self.data: List[DataPoint] = []
-
-    def add(self, item: DataPoint) -> None:
-        r"""Add a new data point to the dataset.
-
-        Args:
-            item (DataPoint): The datapoint to add to the dataset.
-        """
-        self.data.append(item)
-
-    def add_batch(self, items: List[DataPoint]) -> None:
-        r"""Add multiple data points to the dataset.
-
-        Args:
-            items (List[DataPoint]): The datapoints to add to the dataset.
-        """
-        self.data.extend(items)
-
-    def to_pytorch_dataset(
-        self,
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-        batch_size: Optional[int] = None,
-    ) -> Union["PyTorchDataset", "DataLoader"]:
-        r"""Convert to a PyTorch dataset or DataLoader.
-
-        Args:
-            transform (Optional[Callable]): Transform to apply to samples.
-            target_transform (Optional[Callable]): Transform to apply to
-                targets.
-            batch_size (Optional[int]): If provided, returns a DataLoader with
-                the specified batch size instead of a PyTorchDataset.
-                (default: :obj:`None`)
-
-        Returns:
-            Union[PyTorchDataset, torch.utils.data.DataLoader]: Dataset in
-                PyTorch format or DataLoader if batch_size is provided.
-        """
-        return convert_synthetic_to_pytorch(
-            self,
-            transform=transform,
-            target_transform=target_transform,
-            batch_size=batch_size,
-        )
-
-    def save_pytorch_format(self, path: str, compression: bool = True) -> None:
-        r"""Save the dataset to disk in PyTorch format.
-
-        Args:
-            path (str): Path to save the dataset to.
-            compression (bool): Whether to use compression to reduce file size.
-                (default: :obj:`True`)
-        """
-        save_synthetic_dataset(self, path, compression=compression)
-
-    def filter(
-        self, predicate: Callable[[DataPoint], bool]
-    ) -> 'SyntheticDataset':
-        r"""Filter the dataset using a predicate function.
-
-        Args:
-            predicate (Callable[[DataPoint], bool]): Function that takes a
-                DataPoint and returns True if it should be kept, False
-                otherwise.
-
-        Returns:
-            SyntheticDataset: A new dataset containing only the filtered items.
-        """
-        filtered_data = [dp for dp in self.data if predicate(dp)]
-
-        # Create a new dataset with the filtered data
-        new_dataset = SyntheticDataset()
-        new_dataset.add_batch(filtered_data)
-
-        return new_dataset
-
-
-class GenerativeDataset(BaseDataset):
+class GenerativeDataset(Dataset):
     r"""A dataset for generating synthetic datapoints using external agents and
     verifiers.
 
@@ -680,30 +411,22 @@ class GenerativeDataset(BaseDataset):
 
     def __init__(
         self,
-        seed_dataset:
+        seed_dataset: StaticDataset,
         verifier: BaseVerifier,
         agent: ChatAgent,
-        cache_dir: Optional[str] = None,
         seed: int = 42,
         **kwargs,
     ):
         r"""Initialize the generative dataset.
 
         Args:
-            seed_dataset (
+            seed_dataset (StaticDataset): Validated static dataset to
+                use for examples.
             verifier (BaseVerifier): Verifier to validate generated content.
             agent (ChatAgent): Agent to generate new datapoints.
-            cache_dir (Optional[str]): Directory to cache dataset files.
-                (default: :obj:`None`)
             seed (int): Random seed for reproducibility. (default: :obj:`42`)
             **kwargs: Additional dataset parameters.
         """
-        # Initialize with empty data since we'll generate content dynamically
-        super().__init__(
-            data=[],
-            cache_dir=cache_dir,
-            **kwargs,
-        )
 
         self.seed_dataset = seed_dataset
         self.verifier = verifier
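Wiring the three collaborators together now takes a `StaticDataset` instead of the old `SeedDataset`. A hedged sketch (the agent construction is minimal and the verifier is a placeholder: camel's concrete `BaseVerifier` subclasses and `ChatAgent` setup vary by version and are not shown in this diff):

from camel.agents import ChatAgent
from camel.datasets.base import GenerativeDataset, StaticDataset

seed = StaticDataset([
    {"question": "1 + 1?", "rationale": "Add.", "final_answer": "2"},
    {"question": "2 + 2?", "rationale": "Add.", "final_answer": "4"},
    {"question": "3 + 3?", "rationale": "Add.", "final_answer": "6"},
])

agent = ChatAgent("You write new math datapoints.")  # minimal construction
verifier = ...  # any concrete BaseVerifier implementation (placeholder)

gen = GenerativeDataset(seed_dataset=seed, verifier=verifier, agent=agent)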
@@ -712,8 +435,11 @@ class GenerativeDataset(BaseDataset):
         self.seed = seed
         random.seed(self.seed)
 
+        self._data: List[DataPoint] = []
+
     def _construct_prompt(self, examples: List[DataPoint]) -> str:
-        r"""Construct a prompt for generating new datapoints
+        r"""Construct a prompt for generating new datapoints
+        using a fixed sample of 3 examples from the seed dataset.
 
         Args:
             examples (List[DataPoint]): Examples to include in the prompt.
@@ -732,440 +458,182 @@ class GenerativeDataset(BaseDataset):
         prompt += "New datapoint:"
         return prompt
 
-    async def generate_new(
-
+    async def generate_new(
+        self, n: int, max_retries: int = 10
+    ) -> List[DataPoint]:
+        r"""Generates and validates `n` new datapoints through
+        few-shot prompting, with a retry limit.
+
+        Steps:
+        1. Samples 3 examples from the seed dataset.
+        2. Constructs a prompt using the selected examples.
+        3. Uses an agent to generate a new datapoint.
+        4. Verifies the datapoint using a verifier.
+        5. Stores valid datapoints in memory.
 
         Args:
             n (int): Number of valid datapoints to generate.
+            max_retries (int): Maximum number of retries before stopping.
+
+        Returns:
+            List[DataPoint]: A list of newly generated valid datapoints.
 
-
-
-
-
-
-
+        Raises:
+            TypeError: If the agent's output is not a dictionary (or does not
+                match the expected format).
+            KeyError: If required keys are missing from the response.
+            AttributeError: If the verifier response lacks attributes.
+            ValidationError: If a datapoint fails schema validation.
+            RuntimeError: If retries are exhausted before `n` valid datapoints
+                are generated.
+
+        Notes:
+            - Retries on validation failures until `n` valid datapoints exist
+              or `max_retries` is reached, whichever comes first.
+            - If retries are exhausted before reaching `n`, a `RuntimeError`
+              is raised.
+            - Metadata includes a timestamp for tracking datapoint creation.
+            - This method can be overridden to implement custom generation.
         """
         valid_data_points: List[DataPoint] = []
+        retries = 0
 
-        while len(valid_data_points) < n:
+        while len(valid_data_points) < n and retries < max_retries:
             try:
-
-                examples = [self.seed_dataset[i] for i in indices]
+                examples = [self.seed_dataset.sample() for _ in range(3)]
                 prompt = self._construct_prompt(examples)
 
-
-
-
-
-
-                )
-
-                if not isinstance(agent_output, dict):
-                    raise TypeError("Agent output must be a dictionary")
-                if (
-                    'question' not in agent_output
-                    or 'rationale' not in agent_output
-                ):
-                    raise KeyError(
-                        "Agent output missing required keys: "
-                        "'question' or 'rationale'"
+                try:
+                    agent_output = (
+                        self.agent.step(prompt, response_format=DataPoint)
+                        .msgs[0]
+                        .parsed
                     )
-
-
-
-
-
-
-
-
+                    if not isinstance(agent_output, dict):
+                        raise TypeError("Agent output must be a dictionary")
+                    if (
+                        "question" not in agent_output
+                        or "rationale" not in agent_output
+                    ):
+                        raise KeyError(
+                            "Missing 'question' or 'rationale' in agent output"
+                        )
+                except (TypeError, KeyError) as e:
+                    logger.warning(
+                        f"Agent output issue: {e}, retrying... "
+                        f"({retries + 1}/{max_retries})"
                     )
-
-                if not verifier_response.result:
+                    retries += 1
                     continue
 
-
-
-                # Create and validate the new datapoint
-                new_datapoint = {
-                    'question': agent_output['question'],
-                    'rationale': rationale,
-                    'final_answer': final_answer,
-                }
-
-                datapoint = DataPoint(**new_datapoint)
-                valid_data_points.append(datapoint)
-
-            except (TypeError, KeyError, AttributeError, ValidationError) as e:
-                logger.warning(
-                    f"Error encountered during generation: {e}, retrying..."
-                )
-
-        # Add all valid datapoints to the dataset
-        for datapoint in valid_data_points:
-            self.data.append(datapoint)
-            logger.debug("Added new datapoint to dataset")
-
-
-# Define a type variable for return type flexibility
-T_co = TypeVar('T_co', covariant=True)
-
+                rationale = agent_output["rationale"]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                sample and returns a transformed version for features.
-                (default: :obj:`None`)
-            target_transform (Optional[Callable]): A function/transform that
-                takes a target and returns a transformed version. (default:
-                :obj:`None`)
-            validate (bool): Whether to validate data points using DataPoint
-                schema. (default: :obj:`True`)
-
-        Raises:
-            ValidationError: If validation is enabled and data doesn't match
-                DataPoint schema.
-        """
-        self.transform = transform
-        self.target_transform = target_transform
-
-        # Validate and store data
-        self._raw_data = data
-        self.data = []
+                try:
+                    verifier_response = await self.verifier.verify(
+                        VerifierInput(
+                            llm_response=rationale, ground_truth=None
+                        )
+                    )
+                    if not verifier_response or not verifier_response.result:
+                        raise ValueError(
+                            "Verifier unsuccessful, response: "
+                            f"{verifier_response}"
+                        )
+                except (ValueError, AttributeError) as e:
+                    logger.warning(
+                        f"Verifier issue: {e}, "
+                        f"retrying... ({retries + 1}/{max_retries})"
+                    )
+                    retries += 1
+                    continue
 
-        if validate:
-            for i, item in enumerate(self._raw_data):
                 try:
-
-
-
+                    new_datapoint = DataPoint(
+                        question=agent_output["question"],
+                        rationale=rationale,
+                        final_answer=verifier_response.result,
+                        metadata={
+                            "synthetic": str(True),
+                            "created": datetime.now().isoformat(),
+                        },
+                    )
                 except ValidationError as e:
-                    logger.
-
-
-
-
-
-    def __getitem__(self, index: int) -> T_co:
-        r"""Get an item from the dataset.
+                    logger.warning(
+                        f"Datapoint validation failed: {e}, "
+                        f"retrying... ({retries + 1}/{max_retries})"
+                    )
+                    retries += 1
+                    continue
 
-
-            index (int): Index of the item to get.
+                valid_data_points.append(new_datapoint)
 
-
-
-
-
+            except Exception as e:
+                logger.warning(
+                    f"Unexpected error: {e}, retrying..."
+                    f" ({retries + 1}/{max_retries})"
+                )
+                retries += 1
 
-
-
-
+        if len(valid_data_points) < n:
+            raise RuntimeError(
+                f"Failed to generate {n} valid datapoints "
+                f"after {max_retries} retries."
+            )
 
-
+        self._data.extend(valid_data_points)
+        return valid_data_points
 
     def __len__(self) -> int:
-        r"""Return the size of the dataset.
-
-        Returns:
-            int: Number of samples in the dataset.
-        """
-        return len(self.data)
+        r"""Return the size of the dataset."""
+        return len(self._data)
 
-
-
-        cls, datapoints: List[DataPoint], **kwargs
-    ) -> 'PyTorchDataset':
-        r"""Create a PyTorchDataset from a list of DataPoints.
+    def __getitem__(self, idx: int) -> DataPoint:
+        r"""Retrieve a datapoint by index.
 
         Args:
-
-            **kwargs: Additional arguments to pass to the constructor.
+            idx (int): Index of the datapoint.
 
         Returns:
-
-        """
-        data = [dp.to_dict() for dp in datapoints]
-        # We can skip validation since datapoints are already validated
-        return cls(data, validate=False, **kwargs)
-
-    def to_hf_dataset(self) -> HFDataset:
-        r"""Convert to a HuggingFace dataset.
-
-        Returns:
-            HFDataset: Dataset in HuggingFace format.
-        """
-        return HFDataset.from_list(self.data)
+            DataPoint: The datapoint corresponding to the given index.
 
-
-
-
-        Args:
-            path (str): Path to save the dataset to.
+        Raises:
+            IndexError: If idx is out of bounds.
         """
-
-
-
-
-
-
-        transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None,
-    ) -> 'PyTorchDataset':
-        r"""Load a dataset from disk.
-
-        Args:
-            path (str): Path to load the dataset from.
-            transform (Optional[Callable]): Transform to apply to samples.
-                (default: :obj:`None`)
-            target_transform (Optional[Callable]): Transform to apply to
-                targets. (default: :obj:`None`)
+        if idx < 0 or idx >= len(self._data):
+            raise IndexError(
+                f"Index {idx} out of bounds for dataset of "
+                f"size {len(self._data)}"
+            )
+        return self._data[idx]
 
-
-
-        """
-        data = torch.load(path)
-        return cls(
-            data,
-            transform=transform,
-            target_transform=target_transform,
-            validate=False,
-        )
+    def save_to_jsonl(self, file_path: Union[str, Path]) -> None:
+        r"""Saves the dataset to a JSONL (JSON Lines) file.
 
-
-    def collate_fn(
-        batch: List[Dict[str, Any]],
-    ) -> Dict[str, Union[List[Any], torch.Tensor]]:
-        r"""Collate function for PyTorch DataLoader.
+        Each datapoint is stored as a separate JSON object on a new line.
 
         Args:
-
+            file_path (Union[str, Path]): Path to save the JSONL file.
 
-
-
-
-        """
-        if not batch:
-            return {}
-
-        # Initialize result dictionary with keys from first item - start with
-        # lists only
-        result: Dict[str, List[Any]] = {k: [] for k in batch[0].keys()}
-
-        # Collect values by key
-        for item in batch:
-            for k, v in item.items():
-                result[k].append(v)
-
-        # Convert numeric/boolean lists to tensors where possible
-        result_with_tensors: Dict[str, Union[List[Any], torch.Tensor]] = {}
-        for k, v in result.items():
-            if all(isinstance(x, (int, float, bool)) for x in v):
-                try:
-                    result_with_tensors[k] = torch.tensor(v)
-                except (ValueError, TypeError):
-                    # Keep as list if tensor conversion fails
-                    result_with_tensors[k] = v
-            else:
-                result_with_tensors[k] = v
-
-        return result_with_tensors
-
-    def get_dataloader(
-        self,
-        batch_size: int = 32,
-        shuffle: bool = True,
-        num_workers: int = 0,
-        pin_memory: bool = False,
-        **kwargs,
-    ) -> "DataLoader":
-        r"""Create a PyTorch DataLoader for this dataset.
-
-        Args:
-            batch_size (int): Batch size. (default: :obj:`32`)
-            shuffle (bool): Whether to shuffle the dataset. (default:
-                :obj:`True`)
-            num_workers (int): Number of workers for data loading. (default:
-                :obj:`0`)
-            pin_memory (bool): Whether to pin memory for faster GPU transfer.
-                (default: :obj:`False`)
-            **kwargs: Additional arguments to pass to DataLoader.
+        Raises:
+            ValueError: If the dataset is empty.
+            IOError: If there is an issue writing to the file.
 
-
-
+        Notes:
+            - Uses `self._data`, which contains the generated datapoints.
+            - Overwrites the file if it already exists.
+            - Ensures compatibility with large datasets by using JSONL format.
         """
-
-
-        return DataLoader(
-            self,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            num_workers=num_workers,
-            collate_fn=self.collate_fn,
-            pin_memory=pin_memory,
-            **kwargs,
-        )
-
-
-def convert_hf_to_pytorch(
-    hf_dataset: HFDataset,
-    transform: Optional[Callable] = None,
-    target_transform: Optional[Callable] = None,
-    column_mapping: Optional[Dict[str, str]] = None,
-    validate: bool = True,
-    batch_size: Optional[int] = None,
-) -> Union["PyTorchDataset", "DataLoader"]:
-    r"""Convert a HuggingFace dataset to a PyTorchDataset or DataLoader.
-
-    This function maps HuggingFace dataset columns to the expected DataPoint
-    format, validates the data, and creates a PyTorchDataset or DataLoader.
-
-    Args:
-        hf_dataset (HFDataset): HuggingFace dataset to convert.
-        transform (Optional[Callable]): Transform to apply to samples.
-        target_transform (Optional[Callable]): Transform to apply to targets.
-        column_mapping (Optional[Dict[str, str]]): Mapping from HuggingFace
-            column names to DataPoint field names. If None, assumes columns
-            already match DataPoint fields.
-        validate (bool): Whether to validate data points using DataPoint
-            schema. (default: :obj:`True`)
-        batch_size (Optional[int]): If provided, returns a DataLoader with the
-            specified batch size instead of a PyTorchDataset. (default:
-            :obj:`None`)
-
-    Returns:
-        Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
-            or DataLoader if batch_size is provided.
-    """
-    # Convert HuggingFace dataset to list of dicts more efficiently
-    mapped_dataset = []
-
-    for i in range(len(hf_dataset)):
-        item = hf_dataset[i]
-        if column_mapping is not None:
-            # Apply column mapping if provided
-            mapped_item = {}
-            for hf_col, dp_field in column_mapping.items():
-                if hf_col in item:
-                    mapped_item[dp_field] = item[hf_col]
-            mapped_dataset.append(mapped_item)
-        else:
-            # Otherwise use item directly
-            mapped_dataset.append(dict(item))
-
-    # Create PyTorchDataset
-    dataset: PyTorchDataset = PyTorchDataset(
-        mapped_dataset,
-        transform=transform,
-        target_transform=target_transform,
-        validate=validate,
-    )
-
-    # Return DataLoader if batch_size is provided
-    if batch_size is not None:
-        return dataset.get_dataloader(batch_size=batch_size)
+        if not self._data:
+            raise ValueError("Dataset is empty. No data to save.")
 
-
+        file_path = Path(file_path)
 
-
-
-
-
-
-
-
-
-
-    Args:
-        synthetic_dataset (SyntheticDataset): Synthetic dataset to convert.
-        transform (Optional[Callable]): Transform to apply to samples.
-        target_transform (Optional[Callable]): Transform to apply to targets.
-        batch_size (Optional[int]): If provided, returns a DataLoader with the
-            specified batch size instead of a PyTorchDataset. (default:
-            :obj:`None`)
-
-    Returns:
-        Union[PyTorchDataset, torch.utils.data.DataLoader]: Converted dataset
-            or DataLoader if batch_size is provided.
-    """
-    dataset = PyTorchDataset.from_datapoints(
-        synthetic_dataset.data,
-        transform=transform,
-        target_transform=target_transform,
-    )
-
-    # Return DataLoader if batch_size is provided
-    if batch_size is not None:
-        return dataset.get_dataloader(batch_size=batch_size)
-
-    return dataset
-
-
-def save_synthetic_dataset(
-    synthetic_dataset: 'SyntheticDataset',
-    path: str,
-    compression: bool = True,
-) -> None:
-    r"""Save a synthetic dataset to disk using PyTorch format.
-
-    Args:
-        synthetic_dataset (SyntheticDataset): Dataset to save.
-        path (str): Path to save the dataset to.
-        compression (bool): Whether to use compression to reduce file size.
-            (default: :obj:`True`)
-    """
-    pytorch_dataset = convert_synthetic_to_pytorch(synthetic_dataset)
-
-    # Save with compression if enabled (uses less disk space)
-    if compression:
-        torch.save(
-            pytorch_dataset.data,  # type: ignore[union-attr]
-            path,
-            _use_new_zipfile_serialization=True,
-        )
-    else:
-        pytorch_dataset.save_to_disk(path)  # type: ignore[union-attr]
-
-
-def load_pytorch_dataset(
-    path: str,
-    transform: Optional[Callable] = None,
-    target_transform: Optional[Callable] = None,
-    batch_size: Optional[int] = None,
-) -> Union["PyTorchDataset", "DataLoader"]:
-    r"""Load a PyTorchDataset from disk.
-
-    Args:
-        path (str): Path to load the dataset from.
-        transform (Optional[Callable]): Transform to apply to samples.
-        target_transform (Optional[Callable]): Transform to apply to targets.
-        batch_size (Optional[int]): If provided, returns a DataLoader with the
-            specified batch size instead of a PyTorchDataset. (default:
-            :obj:`None`)
-
-    Returns:
-        Union[PyTorchDataset, torch.utils.data.DataLoader]: Loaded dataset or
-            DataLoader if batch_size is provided.
-    """
-    dataset = PyTorchDataset.load_from_disk(
-        path, transform=transform, target_transform=target_transform
-    )
-
-    # Return DataLoader if batch_size is provided
-    if batch_size is not None:
-        return dataset.get_dataloader(batch_size=batch_size)
-
-    return dataset
+        try:
+            with file_path.open("w", encoding="utf-8") as f:
+                for datapoint in self._data:
+                    json.dump(datapoint.to_dict(), f)
+                    f.write("\n")  # Ensure each entry is on a new line
+            logger.info(f"Dataset saved successfully to {file_path}")
+        except IOError as e:
+            logger.error(f"Error writing to file {file_path}: {e}")
+            raise