data-forager 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_forager/datasets/tokens_with_aux.py +91 -0
- data_forager/index_stores/fs_based.py +91 -8
- data_forager/indexers/text_lines.py +28 -73
- data_forager/indexers/tokenization_indexer.py +158 -191
- data_forager/sample_generators/__init__.py +30 -0
- data_forager/sample_generators/aux/__init__.py +18 -0
- data_forager/sample_generators/aux/common.py +77 -0
- data_forager/sample_generators/aux/loss_mask.py +78 -0
- data_forager/sample_generators/common.py +117 -0
- data_forager/sample_generators/schema.py +54 -0
- data_forager/sample_generators/tokenization.py +210 -0
- data_forager/sample_generators/tokenization_with_aux.py +250 -0
- data_forager/sample_index.py +34 -2
- {data_forager-0.1.5.dist-info → data_forager-0.2.0.dist-info}/METADATA +1 -1
- data_forager-0.2.0.dist-info/RECORD +29 -0
- {data_forager-0.1.5.dist-info → data_forager-0.2.0.dist-info}/WHEEL +1 -1
- data_forager-0.1.5.dist-info/RECORD +0 -20
- {data_forager-0.1.5.dist-info → data_forager-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {data_forager-0.1.5.dist-info → data_forager-0.2.0.dist-info}/top_level.txt +0 -0
data_forager/indexers/tokenization_indexer.py:

````diff
@@ -1,28 +1,46 @@
-
+"""
+Factory function for creating tokenization and indexing pipelines.
 
-
-
-
+This module provides a convenience function for setting up the complete pipeline
+to tokenize JSONL text files and create an index for random access.
+"""
 
-import
+from typing import Callable, Dict, List, Optional
 
-import
+import json
+import logging
+import os
 
-from basics.base import Base
 from basics.logging import get_logger
 
-module_logger = get_logger(os.path.basename(__file__))
-
 from data_forager.index_stores.common import IndexStoreInterface
 from data_forager.index_stores.fs_based import IndexStore as FSBasedIndexStore
-from data_forager.indexers.text_lines import
+from data_forager.indexers.text_lines import FileTextLinesIndexer
+from data_forager.sample_generators.tokenization import (
+    TokenizedSampleGenerator,
+    TokenizerFunc,
+    ProcessTextLineFunc,
+)
+from data_forager.sample_generators.aux.common import Part, AuxDataGenerator
 from data_forager.utils import find_files_recursive, natural_sort
 
-
-
+
+ProcessPartsFunc = Callable[[bytes], List[Part]]
+
+
+module_logger = get_logger(os.path.basename(__file__))
 
 
 def get_text_from_jsonl(jsonl_bytes: bytes, text_key: str = "text", text_encoding: str = "utf-8") -> str:
+    """
+    Extract text from a JSONL line.
+
+    :param jsonl_bytes: Raw bytes of the JSONL line.
+    :param text_key: Key in the JSON object containing the text.
+    :param text_encoding: Text encoding to use for decoding.
+
+    :return: The extracted text string.
+    """
     jsonl_text = jsonl_bytes.decode(text_encoding)
     data = json.loads(jsonl_text)
    return data[text_key]
````
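For reference, a minimal sketch of how the `get_text_from_jsonl` helper above behaves; the input bytes are invented for illustration, and the module path is assumed from the per-file summary at the top of this diff:

```python
# Module path assumed from the file list above (data_forager/indexers/tokenization_indexer.py)
from data_forager.indexers.tokenization_indexer import get_text_from_jsonl

line = b'{"text": "Hello world", "lang": "en"}'
assert get_text_from_jsonl(line) == "Hello world"

# A different JSON key can be selected via text_key
assert get_text_from_jsonl(b'{"content": "Bonjour"}', text_key="content") == "Bonjour"
```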
data_forager/indexers/tokenization_indexer.py (continued):

````diff
@@ -44,9 +62,9 @@ def create_tokenize_and_index_jsonl_text_func(
     Create a pipeline to tokenize text from JSONL files and create an index for random access.
 
     The pipeline:
-    …
-    …
-    …
+    * Tokenizes text from input JSONL objects
+    * Stores the token data in bin files under "tokenized-samples" folder
+    * Stores index data under "index" folder
 
     Usage:
     ```python
````
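The unchanged `Usage:` block referenced above is not part of this hunk. As a rough orientation only, wiring up this factory might look like the sketch below; every argument name here is an assumption based on the parallel `create_tokenize_and_index_with_aux_func()` added further down, not a confirmed signature:

```python
# Sketch only: argument names are assumed, not taken from the actual signature.
# `tokenizer` stands for any tokenizer object with encode() and eos_token_id.
indexer = create_tokenize_and_index_jsonl_text_func(
    tokenizer_func=tokenizer.encode,   # assumed TokenizerFunc-style callable
    eos_idx=tokenizer.eos_token_id,    # assumed
    input_base_path="./data",          # assumed: directory scanned for *.jsonl files
    sample_size=4096,                  # assumed: fixed-length token samples
)
indexer()  # writes the "tokenized-samples" and "index" folders described above
```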
data_forager/indexers/tokenization_indexer.py (continued):

````diff
@@ -165,179 +183,128 @@ def create_tokenize_and_index_jsonl_text_func(
     )
 
 
-        …  (old lines 168-292 removed; their content is not rendered in this diff view)
-            self._rest_tokens = None
-
-            num_tokens = len(tokenized_text)
-            num_samples = num_tokens // self._sample_size
-            num_rest_tokens = num_tokens % self._sample_size
-
-            if num_rest_tokens > 0:
-                # Store remainder tokens (includes EOS from this document)
-                self._rest_tokens = tokenized_text[-num_rest_tokens:]
-                tokenized_text = tokenized_text[:num_samples * self._sample_size]
-
-            tokenized_samples = np.array(tokenized_text, dtype=self._token_dtype)
-            tokenized_samples = tokenized_samples.reshape(-1, self._sample_size)
-        else:
-            tokenized_samples = np.array([tokenized_text], dtype=self._token_dtype)
-
-        # Store tokenized_samples
-        sample_data = []
-        for sample_idx in range(tokenized_samples.shape[0]):
-            sample_bytes = tokenized_samples[sample_idx, :].tobytes()
-            sample_data.append(SampleData(
-                sample_bytes, self._current_samples_path,
-            ))
-
-            self._current_samples_file.write(sample_bytes)
-
-        return sample_data
-
-    def finish(self, is_last_file: bool):
-        """
-        ## finish ##
-        * After all text lines are processed, the file holding the tokenized text samples is closed.
-          When `sample_size` not None: Any final trailing tokens will be discarded, but only when the last
-          input text file was processed.
-
-        :param is_last_file:
-        :return:
-        """
-        self._close_current_samples_file()
-
-        if is_last_file and self._rest_tokens is not None:
-            self._log.debug(f"Cut off {len(self._rest_tokens)} unused tokens")
-
-    def _close_current_samples_file(self):
-        if self._current_samples_file:
-            self._log.debug(f"Closing tokenized samples file: \n{self._current_samples_path}")
-            self._current_samples_file.close()
-            self._current_samples_file = None
-
-    def __del__(self):
-        self._close_current_samples_file()
+def create_tokenize_and_index_with_aux_func(
+    process_parts_func: ProcessPartsFunc,
+    tokenizer_func: TokenizerFunc,
+    eos_idx: int,
+    aux_generators: Dict[str, AuxDataGenerator],
+    input_base_path: Optional[str] = None,
+    input_file_paths: Optional[List[str]] = None,
+    output_base_path: Optional[str] = None,
+    index_store: Optional[IndexStoreInterface] = None,
+    logger: Optional[logging.Logger] = None,
+    name: Optional[str] = None,
+    **sample_generator_kwargs,
+) -> FileTextLinesIndexer:
+    """
+    Create a pipeline to tokenize structured samples with auxiliary data.
+
+    This function creates a pipeline that:
+    * Processes structured input (parts with types) from JSONL files
+    * Tokenizes each part and generates auxiliary data (e.g., loss masks)
+    * Stores concatenated token + aux data in bin files
+    * Creates an index with schema for random access
+
+    Usage:
+    ```python
+    from data_forager.sample_generators.aux import Part, LossMaskGenerator
+
+    def parse_parts(line_bytes: bytes) -> List[Part]:
+        data = json.loads(line_bytes.decode('utf-8'))
+        return [Part(type=p['type'], text=p['text']) for p in data['parts']]
+
+    indexer = create_tokenize_and_index_with_aux_func(
+        process_parts_func=parse_parts,
+        tokenizer_func=tokenizer.encode,
+        eos_idx=tokenizer.eos_token_id,
+        aux_generators={'loss_mask': LossMaskGenerator()},
+        input_base_path='./data',
+        sample_size=4096,
+    )
+
+    indexer()
+    ```
+
+    :param process_parts_func: Function to extract typed parts from input bytes.
+        Takes JSONL bytes and returns List[Part].
+    :param tokenizer_func: Function used to tokenize text.
+    :param eos_idx: EOS token index, known by the used Tokenizer.
+    :param aux_generators: Dict mapping names to AuxDataGenerator instances.
+        Example: {'loss_mask': LossMaskGenerator()}
+    :param input_base_path: Path to directory containing JSONL files.
+    :param input_file_paths: List of file paths to process.
+    :param output_base_path: Base path for output (index and tokenized samples).
+    :param index_store: Index store to use. Must support set_sample_schema().
+    :param logger: Logger to use.
+    :param name: Name of the indexer for logging.
+    :param sample_generator_kwargs: Other kwargs passed to TokenizedSampleWithAuxGenerator
+        (e.g., sample_size, token_dtype).
+
+    :raises ValueError: If both input_base_path and input_file_paths are None.
+    :raises ValueError: If output destination cannot be determined.
+
+    :return: FileTextLinesIndexer instance that can be called to run the pipeline.
+    """
+    # Import here to avoid circular dependency
+    from data_forager.sample_generators.tokenization_with_aux import (
+        TokenizedSampleWithAuxGenerator,
+    )
+
+    if logger is None:
+        logger = module_logger
+
+    # Validate input source
+    if input_base_path is None and input_file_paths is None:
+        raise ValueError(
+            "Either input_base_path or input_file_paths must be provided"
+        )
+
+    # Determine output base path
+    effective_output_base_path = output_base_path or input_base_path
+
+    # Validate output destination
+    if index_store is None and effective_output_base_path is None:
+        raise ValueError(
+            "Either index_store, output_base_path, or input_base_path must be provided "
+            "to determine where to store the index"
+        )
+
+    logger.info(f"Output base path: {effective_output_base_path}")
+
+    if index_store is None:
+        index_store = FSBasedIndexStore(
+            base_path=effective_output_base_path,
+        )
+
+    if input_file_paths is None:
+        logger.info(f"Scanning for JSONL files in: {input_base_path}")
+        input_file_paths = find_files_recursive(
+            input_base_path,
+            extension_patterns=['*.jsonl', '*.JSONL']
+        )
+        input_file_paths = natural_sort(input_file_paths)
+        logger.info(f"Found {len(input_file_paths)} JSONL file(s)")
+
+    # Set default base_output_path for tokenized samples if not provided
+    if 'base_output_path' not in sample_generator_kwargs:
+        default_base_output_path = os.path.join(
+            effective_output_base_path, "tokenized-samples"
+        )
+        logger.info(f"Tokenized samples output path: {default_base_output_path}")
+        sample_generator_kwargs['base_output_path'] = default_base_output_path
+
+    sample_generator = TokenizedSampleWithAuxGenerator(
+        process_parts_func=process_parts_func,
+        tokenizer_func=tokenizer_func,
+        eos_idx=eos_idx,
+        aux_generators=aux_generators,
+        **sample_generator_kwargs
+    )
+
+    return FileTextLinesIndexer(
+        input_file_paths=input_file_paths,
+        index_store=index_store,
+        sample_generator=sample_generator,
+        description="Tokenizing with aux data and indexing",
+        name=name,
+    )
````
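The shape of the JSONL input is defined entirely by the `process_parts_func` you pass in. For the `parse_parts` example shown in the docstring above, each input line would be a JSON object like the one below; the concrete field values here are invented for illustration:

```python
import json

# One JSON object per line in each *.jsonl input file. The "parts"/"type"/"text"
# layout is what the docstring's parse_parts example expects, not a fixed format.
line = json.dumps({
    "parts": [
        {"type": "system",   "text": "You are a helpful assistant."},
        {"type": "prompt",   "text": "What is 2 + 2?"},
        {"type": "response", "text": "4"},
    ]
})

# Using the parse_parts function from the docstring example:
parts = parse_parts(line.encode("utf-8"))  # -> [Part(type="system", ...), ...]
```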
data_forager/sample_generators/__init__.py (new file):

````diff
@@ -0,0 +1,30 @@
+"""
+Sample generators for transforming input data into samples.
+
+This package contains:
+- SampleGeneratorInterface: Protocol for sample generators
+- SampleData: Data class for sample information
+- SampleSchema, ArraySpec: Schema classes for structured samples
+- TokenizedSampleGenerator: Tokenizes text into fixed-length samples
+- TokenizedSampleWithAuxGenerator: Tokenizes with auxiliary data (loss masks, etc.)
+"""
+
+from data_forager.sample_generators.common import (
+    SampleData,
+    SampleGeneratorInterface,
+    NOOPSampleGenerator,
+    noop_sample_processing,
+)
+from data_forager.sample_generators.schema import (
+    ArraySpec,
+    SampleSchema,
+)
+
+__all__ = [
+    "ArraySpec",
+    "NOOPSampleGenerator",
+    "SampleData",
+    "SampleGeneratorInterface",
+    "SampleSchema",
+    "noop_sample_processing",
+]
````
data_forager/sample_generators/aux/__init__.py (new file):

````diff
@@ -0,0 +1,18 @@
+"""
+Auxiliary data generators for sample generators.
+
+This subpackage provides generators for auxiliary data that accompanies
+tokenized samples, such as loss masks for selective training.
+"""
+
+from data_forager.sample_generators.aux.common import (
+    AuxDataGenerator,
+    Part,
+)
+from data_forager.sample_generators.aux.loss_mask import LossMaskGenerator
+
+__all__ = [
+    "AuxDataGenerator",
+    "LossMaskGenerator",
+    "Part",
+]
````
data_forager/sample_generators/aux/common.py (new file):

````diff
@@ -0,0 +1,77 @@
+"""
+Common interfaces for auxiliary data generators.
+
+This module provides the protocol and data structures for generating auxiliary
+data (e.g., loss masks) alongside tokenized samples.
+"""
+
+from typing import List, Protocol
+
+from dataclasses import dataclass
+
+import numpy as np
+
+
+@dataclass
+class Part:
+    """
+    A typed part of a structured sample.
+
+    Used to represent segments of text with semantic types (e.g., prompt,
+    response, system) that determine how auxiliary data is generated.
+
+    :param type: Semantic type of this part (e.g., "system", "prompt",
+        "response", "thinking", "text").
+    :param text: The text content of this part.
+    """
+
+    type: str
+    text: str
+
+
+class AuxDataGenerator(Protocol):
+    """
+    Protocol for generating auxiliary data for tokenized samples.
+
+    Auxiliary data generators produce per-token data (e.g., loss masks)
+    based on the semantic type of each part in a structured sample.
+    """
+
+    @property
+    def dtype(self) -> np.dtype:
+        """
+        Return the NumPy dtype for this auxiliary data.
+
+        :return: NumPy dtype (e.g., np.uint8 for loss masks).
+        """
+        ...
+
+    def generate(
+        self,
+        part_type: str,
+        num_tokens: int,
+        *,
+        part_tokens: List[int] | None = None,
+    ) -> List[int]:
+        """
+        Generate auxiliary data values for a tokenized part.
+
+        :param part_type: Semantic type of the part (e.g., "prompt", "response").
+        :param num_tokens: Number of tokens in this part.
+        :param part_tokens: Optional list of actual token IDs, for generators
+            that need token-level information.
+
+        :return: List of auxiliary data values, one per token.
+        """
+        ...
+
+    def generate_for_eos(self) -> int:
+        """
+        Generate the auxiliary data value for an EOS token.
+
+        EOS tokens are inserted between documents. This method returns the
+        value to use for these boundary tokens.
+
+        :return: Single auxiliary data value for the EOS token.
+        """
+        ...
````
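The protocol is small enough that custom generators are straightforward to add. Below is a minimal, hypothetical sketch (not part of the package) of a generator that records a per-token part-type ID; the class name and constructor arguments are illustrative only:

```python
from typing import Dict, List

import numpy as np

from data_forager.sample_generators.aux.common import AuxDataGenerator


class PartTypeIdGenerator(AuxDataGenerator):
    """Hypothetical generator: emits an integer part-type ID for every token."""

    def __init__(self, type_ids: Dict[str, int], unknown_id: int = 0, eos_id: int = 0):
        self._type_ids = type_ids
        self._unknown_id = unknown_id
        self._eos_id = eos_id

    @property
    def dtype(self) -> np.dtype:
        # Small integer IDs fit in a single byte
        return np.dtype(np.uint8)

    def generate(
        self,
        part_type: str,
        num_tokens: int,
        *,
        part_tokens: List[int] | None = None,
    ) -> List[int]:
        # One value per token, based only on the part's semantic type
        return [self._type_ids.get(part_type, self._unknown_id)] * num_tokens

    def generate_for_eos(self) -> int:
        # Value used for the EOS tokens inserted between documents
        return self._eos_id
```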
data_forager/sample_generators/aux/loss_mask.py (new file):

````diff
@@ -0,0 +1,78 @@
+"""
+Loss mask generator for selective training.
+
+This module provides LossMaskGenerator which creates loss masks based on
+part types, enabling selective training on specific parts of samples
+(e.g., train on responses but not prompts).
+"""
+
+from typing import List, Set
+
+import numpy as np
+
+from data_forager.sample_generators.aux.common import AuxDataGenerator
+
+
+class LossMaskGenerator(AuxDataGenerator):
+    """
+    Generates loss masks based on part types.
+
+    Loss mask semantics:
+    - mask=1: Excluded from loss (masked out, don't train)
+    - mask=0: Included in loss (train on these tokens)
+
+    Default masked types: "system", "prompt"
+    Default unmasked types: "text", "response", "thinking"
+
+    :param masked_types: Set of part types to mask (exclude from loss).
+    :param mask_eos: Whether to mask EOS tokens. Default False (train on EOS).
+    """
+
+    def __init__(
+        self,
+        masked_types: Set[str] | None = None,
+        mask_eos: bool = False,
+    ):
+        """
+        Initialize the loss mask generator.
+
+        :param masked_types: Set of part types to mask. Defaults to {"system", "prompt"}.
+        :param mask_eos: Whether to mask EOS tokens. Default False.
+        """
+        if masked_types is None:
+            masked_types = {"system", "prompt"}
+
+        self._masked_types = masked_types
+        self._mask_eos = mask_eos
+
+    @property
+    def dtype(self) -> np.dtype:
+        """Return uint8 dtype for loss masks."""
+        return np.dtype(np.uint8)
+
+    def generate(
+        self,
+        part_type: str,
+        num_tokens: int,
+        *,
+        part_tokens: List[int] | None = None,
+    ) -> List[int]:
+        """
+        Generate loss mask values for a tokenized part.
+
+        :param part_type: Semantic type of the part.
+        :param num_tokens: Number of tokens in this part.
+        :param part_tokens: Unused, accepted for protocol compatibility.
+
+        :return: List of mask values (1=masked, 0=train).
+        """
+        mask_value = 1 if part_type in self._masked_types else 0
+        return [mask_value] * num_tokens
+
+    def generate_for_eos(self) -> int:
+        """
+        Generate the loss mask value for an EOS token.
+
+        :return: 1 if mask_eos is True, 0 otherwise.
+        """
+        return 1 if self._mask_eos else 0
````