data-forager 0.1.6__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,117 @@
+ """
+ Common interfaces and utilities for sample generators.
+
+ This module provides the base protocol and data structures used by all sample generators.
+ """
+
+ from typing import List, Protocol, TYPE_CHECKING
+
+ from dataclasses import dataclass
+
+ if TYPE_CHECKING:
+     from data_forager.sample_generators.schema import SampleSchema
+
+
+ @dataclass
+ class SampleData:
+     """
+     Data class containing sample bytes and file location.
+
+     :param sample_bytes: The sample data as bytes.
+     :param file_path: Path to the file where the sample is stored.
+     """
+
+     sample_bytes: bytes
+     file_path: str
+
+
+ class SampleGeneratorInterface(Protocol):
+     """
+     Protocol defining the interface for sample generators.
+
+     Sample generators transform input data (e.g., text lines) into samples
+     that can be indexed and stored for random access during training.
+     """
+
+     def prepare(self, text_file_path: str):
+         """
+         Prepare sample generation from a new input text file.
+
+         :param text_file_path: Path to the input text file.
+         """
+         ...
+
+     def create_samples(self, text_line: bytes) -> List[SampleData]:
+         """
+         Create one or more samples from the given text_line and store them.
+
+         IMPORTANT: It is assumed that each sample returned is stored in a file
+         sequentially, in the same order. This must also hold over multiple function
+         calls, because the byte offset of a sample is derived from the order in
+         which the samples are returned.
+
+         :param text_line: Text line, as bytes, from the text_file_path provided in
+             the prepare phase. The implementation is responsible for choosing the
+             text encoding.
+
+         :return: List of SampleData objects. For each created sample:
+             - Its representation in bytes, as used to store the sample
+             - The file path to where the sample is stored
+         """
+         ...
+
+     def finish(self, is_last_file: bool):
+         """
+         Finish generation of samples from the current input file.
+
+         Called after all text lines from the input file have been processed.
+
+         :param is_last_file: Indicates if the input text file was the last file
+             to be processed.
+         """
+         ...
+
+     def get_sample_schema(self) -> "SampleSchema | None":
+         """
+         Return a schema describing the sample structure, or None for unstructured samples.
+
+         The schema describes how sample bytes should be interpreted (e.g., array
+         names, dtypes, and offsets for multi-array samples).
+
+         :return: SampleSchema if samples have a structured format, None otherwise.
+         """
+         ...
+
+
+ class NOOPSampleGenerator:
+     """
+     A no-operation sample generator that returns input lines unchanged.
+
+     Used as the default generator when no transformation is needed.
+     """
+
+     def __init__(self):
+         self._current_text_file = None
+
+     def prepare(self, text_file_path: str):
+         self._current_text_file = text_file_path
+
+     def create_samples(self, text_line: bytes) -> List[SampleData]:
+         return [SampleData(text_line, self._current_text_file)]
+
+     def finish(self, is_last_file: bool):
+         self._current_text_file = None
+
+     def get_sample_schema(self) -> "SampleSchema | None":
+         return None
+
+
+ def noop_sample_processing(text_line: bytes, text_file_path: str) -> List[SampleData]:
+     """
+     Simple function that wraps a text line as a sample without transformation.
+
+     :param text_line: The input text line as bytes.
+     :param text_file_path: The file path to associate with the sample.
+
+     :return: List containing a single SampleData with the unchanged text line.
+     """
+     return [SampleData(text_line, text_file_path)]
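
For orientation, here is a minimal sketch of a custom generator that satisfies SampleGeneratorInterface (a structural Protocol, so no subclassing is needed). It is purely illustrative and not part of the package: the class name, the JSONL "text" field, and the output-file naming are assumptions. It also shows the IMPORTANT note above in practice: samples are written sequentially in exactly the order they are returned, so byte offsets derived from that order match the file contents.

import json
import os
from pathlib import Path
from typing import List, Optional


class UppercaseSampleGenerator:
    """Hypothetical example implementer of SampleGeneratorInterface."""

    def __init__(self, base_output_path: str):
        self._base_output_path = base_output_path
        self._out_path: Optional[str] = None
        self._out_file = None

    def prepare(self, text_file_path: str):
        os.makedirs(self._base_output_path, exist_ok=True)
        self._out_path = os.path.join(
            self._base_output_path, f"{Path(text_file_path).stem}-upper.bin"
        )
        self._out_file = open(self._out_path, "wb")

    def create_samples(self, text_line: bytes) -> List[SampleData]:
        # The generator chooses the text encoding itself (UTF-8 assumed here).
        record = json.loads(text_line.decode("utf-8"))
        sample_bytes = record["text"].upper().encode("utf-8")
        # Write sequentially, in the same order as returned, so byte offsets
        # derived from the return order match the file contents.
        self._out_file.write(sample_bytes)
        return [SampleData(sample_bytes, self._out_path)]

    def finish(self, is_last_file: bool):
        if self._out_file:
            self._out_file.close()
            self._out_file = None

    def get_sample_schema(self):
        # Plain byte samples, no structured layout.
        return None
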
@@ -0,0 +1,54 @@
+ """
+ Schema definitions for describing sample structure.
+
+ This module contains dataclasses that describe how sample bytes should be
+ interpreted, particularly for samples containing multiple arrays (e.g., tokens
+ and loss masks).
+ """
+
+ from typing import List, Optional
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class ArraySpec:
+     """
+     Specification for a single array within a sample.
+
+     :param name: Name of the array (e.g., "tokens", "loss_mask").
+     :param dtype: NumPy dtype string (e.g., "uint32", "uint8").
+     :param offset: Byte offset within the sample where this array starts.
+     """
+
+     name: str
+     dtype: str
+     offset: int
+
+
+ @dataclass
+ class SampleSchema:
+     """
+     Schema describing the structure of samples.
+
+     Used when samples contain multiple arrays (e.g., tokens + auxiliary data).
+     Each array has a name, dtype, and byte offset within the sample.
+
+     :param sample_size: Number of elements per array (e.g., context length).
+     :param arrays: List of ArraySpec describing each array in the sample.
+     :param total_bytes_per_sample: Total size of each sample in bytes.
+         If None, computed automatically from arrays.
+     """
+
+     sample_size: int
+     arrays: List[ArraySpec] = field(default_factory=list)
+     total_bytes_per_sample: Optional[int] = None
+
+     def __post_init__(self):
+         """Compute total_bytes_per_sample if not provided."""
+         if self.total_bytes_per_sample is None and self.arrays:
+             import numpy as np
+             total = 0
+             for arr in self.arrays:
+                 total += self.sample_size * np.dtype(arr.dtype).itemsize
+             self.total_bytes_per_sample = total
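
As a quick illustration of how this layout is meant to be consumed, the following sketch decodes the raw bytes of one sample back into named arrays using the ArraySpec offsets. The decode_sample helper and the concrete 4096-token layout are assumptions made for this example, not part of the package; they assume each array holds sample_size elements laid out back to back, as computed in __post_init__ above.

import numpy as np


def decode_sample(sample_bytes: bytes, schema: SampleSchema) -> dict:
    """Split one sample's bytes into the named arrays described by the schema."""
    return {
        spec.name: np.frombuffer(
            sample_bytes,
            dtype=np.dtype(spec.dtype),
            count=schema.sample_size,
            offset=spec.offset,
        )
        for spec in schema.arrays
    }


# Example layout: 4096 uint32 tokens followed by 4096 uint8 loss-mask values.
schema = SampleSchema(
    sample_size=4096,
    arrays=[
        ArraySpec(name="tokens", dtype="uint32", offset=0),
        ArraySpec(name="loss_mask", dtype="uint8", offset=4096 * 4),
    ],
)
assert schema.total_bytes_per_sample == 4096 * (4 + 1)
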
@@ -0,0 +1,210 @@
+ """
+ Tokenized sample generator for converting text into fixed-length token samples.
+
+ This module provides TokenizedSampleGenerator which tokenizes text and produces
+ samples of a fixed context length, suitable for language model training.
+ """
+
+ from typing import Callable, List, Optional
+
+ import os
+ from pathlib import Path
+
+ import numpy as np
+
+ from basics.base import Base
+
+ from data_forager.sample_generators.common import SampleData, SampleGeneratorInterface
+
+
+ TokenizerFunc = Callable[[str], List[int]]
+ ProcessTextLineFunc = Callable[[bytes], str]
+
+
+ class TokenizedSampleGenerator(Base, SampleGeneratorInterface):
+     """
+     Tokenizes text into fixed-length samples for language model training.
+
+     This generator:
+     1. Processes input text lines using a configurable function
+     2. Tokenizes the text using a provided tokenizer
+     3. Splits tokens into fixed-length samples
+     4. Handles document boundaries with EOS tokens
+     5. Carries over remainder tokens across documents
+     """
+
+     def __init__(
+         self,
+         process_text_line_func: ProcessTextLineFunc,
+         tokenizer_func: TokenizerFunc,
+         eos_idx: int,
+         token_dtype: np.dtype = np.uint16,
+         sample_size: Optional[int] = None,
+         base_output_path: Optional[str] = None,
+         file_name_postfix: str = "tokenized-samples",
+         name: Optional[str] = None
+     ):
+         """
+         Initialize the tokenized sample generator.
+
+         Tokenizes and indexes text into fixed-length samples (when `sample_size` is
+         not None) or samples of variable size, depending on the text (when
+         `sample_size` is None).
+
+         This generator performs the following steps:
+
+         ## prepare ##
+         * In the preparation step, create a file to store tokenized samples, based on
+           the input `text_file_path` and the given `base_output_path` and
+           `file_name_postfix`
+         * If `base_output_path` is not given, the directory of `text_file_path` +
+           "/tokenized-samples" is used
+
+         ## create_samples ##
+         * To create tokenized text samples, process the incoming text line using
+           `process_text_line_func`, e.g. convert a JSONL line into a dict and retrieve
+           the sample text from it.
+         * The resulting text is tokenized using `tokenizer_func`.
+         * If a `sample_size` is given:
+           The tokenized text is split into samples of length `sample_size` and stored
+           in the file opened in the prepare step. Here `token_dtype` is used.
+           - Trailing tokens are carried over and combined with tokens from the next
+             text line
+           - Tokens from different text samples are separated by `eos_idx`
+         * If a `sample_size` is not given:
+           The tokenized text is immediately stored as is, in the file opened in the
+           prepare step. Here `token_dtype` is used.
+
+         ## finish ##
+         * After all text lines are processed, the file holding the tokenized text
+           samples is closed. When `sample_size` is not None: any final trailing tokens
+           are discarded, but only when the last input text file was processed.
+
+         :param process_text_line_func: Function to extract text from input bytes.
+         :param tokenizer_func: Function to tokenize text into token IDs.
+         :param eos_idx: End-of-sequence token ID for document boundaries.
+         :param token_dtype: NumPy dtype for storing tokens (default: uint16).
+         :param sample_size: Fixed sample length, or None for variable-length samples.
+         :param base_output_path: Base path for output files.
+         :param file_name_postfix: Postfix for output file names.
+         :param name: Name for logging purposes.
+         """
+         super().__init__(pybase_logger_name=name)
+
+         if sample_size is None:
+             self._log.info("Tokenized text will NOT be broken into samples of fixed length.")
+
+         self._process_text_line_func = process_text_line_func
+         self._tokenizer_func = tokenizer_func
+         self._eos_idx = eos_idx
+         self._token_dtype = token_dtype
+         self._sample_size = sample_size
+         self._base_output_path = base_output_path
+         self._file_name_postfix = file_name_postfix
+
+         self._current_samples_path = None
+         self._current_samples_file = None
+
+         self._rest_tokens = None
+
+     def prepare(self, text_file_path: str):
+         """
+         Prepare for processing a new input file.
+
+         Creates the output file for storing tokenized samples based on the input
+         `text_file_path` and the configured `base_output_path` and `file_name_postfix`.
+
+         :param text_file_path: Path to the input text file.
+         """
+         input_file_path = os.path.dirname(text_file_path)
+         input_file_name = Path(text_file_path).stem
+         output_file_name = f"{input_file_name}-{self._file_name_postfix}.bin"
+
+         output_path = self._base_output_path
+         if self._base_output_path is None:
+             output_path = os.path.join(input_file_path, "tokenized-samples")
+
+         os.makedirs(output_path, exist_ok=True)
+
+         output_file_path = os.path.join(output_path, output_file_name)
+         if os.path.exists(output_file_path):
+             raise FileExistsError(f"Tokenized samples file already exists: \n{output_file_path}")
+
+         self._current_samples_path = output_file_path
+         self._current_samples_file = open(output_file_path, "wb")
+
+         self._log.debug(f"Tokenized samples file opened: \n{output_file_path}")
+
+     def create_samples(self, text_line: bytes) -> List[SampleData]:
+         """
+         Create tokenized samples from a text line.
+
+         Processes the input text line, tokenizes it, and splits it into fixed-length
+         samples (if sample_size is set). Handles document boundaries with EOS tokens
+         and carries over remainder tokens.
+
+         :param text_line: JSONL text line as bytes.
+
+         :return: List of SampleData objects containing the tokenized samples.
+         """
+         input_text = self._process_text_line_func(text_line)
+         tokenized_text = self._tokenizer_func(input_text)
+
+         if self._sample_size is not None:
+             # Always append EOS after each document to mark document boundary
+             tokenized_text = tokenized_text + [self._eos_idx]
+
+             # Prepend any leftover tokens from previous document
+             if self._rest_tokens is not None:
+                 tokenized_text = self._rest_tokens + tokenized_text
+                 self._rest_tokens = None
+
+             num_tokens = len(tokenized_text)
+             num_samples = num_tokens // self._sample_size
+             num_rest_tokens = num_tokens % self._sample_size
+
+             if num_rest_tokens > 0:
+                 # Store remainder tokens (includes EOS from this document)
+                 self._rest_tokens = tokenized_text[-num_rest_tokens:]
+                 tokenized_text = tokenized_text[:num_samples * self._sample_size]
+
+             tokenized_samples = np.array(tokenized_text, dtype=self._token_dtype)
+             tokenized_samples = tokenized_samples.reshape(-1, self._sample_size)
+         else:
+             tokenized_samples = np.array([tokenized_text], dtype=self._token_dtype)
+
+         # Store tokenized_samples
+         sample_data = []
+         for sample_idx in range(tokenized_samples.shape[0]):
+             sample_bytes = tokenized_samples[sample_idx, :].tobytes()
+             sample_data.append(SampleData(
+                 sample_bytes, self._current_samples_path,
+             ))
+
+             self._current_samples_file.write(sample_bytes)
+
+         return sample_data
+
+     def finish(self, is_last_file: bool):
+         """
+         Finish processing the current input file.
+
+         Closes the current samples file. When `sample_size` is not None, any final
+         trailing tokens will be discarded only when the last input file was processed.
+
+         :param is_last_file: Whether this is the last input file.
+         """
+         self._close_current_samples_file()
+
+         if is_last_file and self._rest_tokens is not None:
+             self._log.debug(f"Cut off {len(self._rest_tokens)} unused tokens")
+
+     def get_sample_schema(self):
+         """Return None as this generator produces simple token arrays."""
+         return None
+
+     def _close_current_samples_file(self):
+         if self._current_samples_file:
+             self._log.debug(f"Closing tokenized samples file: \n{self._current_samples_path}")
+             self._current_samples_file.close()
+             self._current_samples_file = None
+
+     def __del__(self):
+         self._close_current_samples_file()
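
A hedged usage sketch of the generator above, assuming a JSONL corpus with a "text" field and using a toy whitespace tokenizer as a stand-in for a real one; the paths, vocabulary size, and EOS id are made up for illustration and are not defaults of the package.

import json

import numpy as np


def extract_text(line: bytes) -> str:
    return json.loads(line.decode("utf-8"))["text"]


def toy_tokenizer(text: str) -> list:
    # Stand-in for a real tokenizer: one pseudo token ID per whitespace-separated word.
    return [hash(word) % 50000 for word in text.split()]


generator = TokenizedSampleGenerator(
    process_text_line_func=extract_text,
    tokenizer_func=toy_tokenizer,
    eos_idx=50000,
    token_dtype=np.uint16,
    sample_size=512,
    base_output_path="/tmp/tokenized",  # hypothetical output location
)

generator.prepare("/tmp/corpus/shard-000.jsonl")  # hypothetical input file
with open("/tmp/corpus/shard-000.jsonl", "rb") as f:
    for line in f:
        samples = generator.create_samples(line)
        # Each returned sample holds exactly 512 uint16 tokens, i.e. 1024 bytes.
generator.finish(is_last_file=True)
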
@@ -0,0 +1,250 @@
+ """
+ Tokenized sample generator with auxiliary data support.
+
+ This module provides TokenizedSampleWithAuxGenerator which tokenizes structured
+ samples (with typed parts) and generates auxiliary data (e.g., loss masks)
+ alongside the tokens.
+ """
+
+ from typing import Callable, Dict, List, Optional
+
+ import os
+ from pathlib import Path
+
+ import numpy as np
+
+ from basics.base import Base
+
+ from data_forager.sample_generators.common import SampleData, SampleGeneratorInterface
+ from data_forager.sample_generators.schema import SampleSchema, ArraySpec
+ from data_forager.sample_generators.aux.common import AuxDataGenerator, Part
+
+
+ ProcessPartsFunc = Callable[[bytes], List[Part]]
+ TokenizerFunc = Callable[[str], List[int]]
+
+
+ class TokenizedSampleWithAuxGenerator(Base, SampleGeneratorInterface):
+     """
+     Tokenizes structured samples and generates auxiliary data.
+
+     This generator handles samples with typed parts (e.g., prompt, response)
+     and produces both tokens and auxiliary data (e.g., loss masks) in a single
+     concatenated output.
+
+     Input format: JSONL with {"parts": [{"type": "...", "text": "..."}, ...]}
+     Output format: Concatenated bytes [tokens][aux1][aux2]...
+     """
+
+     def __init__(
+         self,
+         process_parts_func: ProcessPartsFunc,
+         tokenizer_func: TokenizerFunc,
+         eos_idx: int,
+         aux_generators: Dict[str, AuxDataGenerator],
+         token_dtype: np.dtype = np.uint32,
+         sample_size: int = 4096,
+         base_output_path: Optional[str] = None,
+         file_name_postfix: str = "tokenized-samples",
+         name: Optional[str] = None,
+     ):
+         """
+         Initialize the tokenized sample generator with auxiliary data support.
+
+         :param process_parts_func: Function to extract typed parts from input bytes.
+             Takes JSONL bytes and returns List[Part].
+         :param tokenizer_func: Function to tokenize text into token IDs.
+         :param eos_idx: End-of-sequence token ID for document boundaries.
+         :param aux_generators: Dict mapping names to AuxDataGenerator instances.
+             Example: {"loss_mask": LossMaskGenerator()}
+         :param token_dtype: NumPy dtype for storing tokens (default: uint32).
+         :param sample_size: Fixed sample length in tokens.
+         :param base_output_path: Base path for output files.
+         :param file_name_postfix: Postfix for output file names.
+         :param name: Name for logging purposes.
+         """
+         super().__init__(pybase_logger_name=name)
+
+         if not aux_generators:
+             raise ValueError("aux_generators must not be empty")
+
+         self._process_parts_func = process_parts_func
+         self._tokenizer_func = tokenizer_func
+         self._eos_idx = eos_idx
+         self._aux_generators = aux_generators
+         self._token_dtype = np.dtype(token_dtype)
+         self._sample_size = sample_size
+         self._base_output_path = base_output_path
+         self._file_name_postfix = file_name_postfix
+
+         # Build schema
+         self._sample_schema = self._build_schema()
+
+         # Current output state
+         self._current_samples_path: Optional[str] = None
+         self._current_samples_file = None
+
+         # Remainder from previous document
+         self._rest_tokens: Optional[List[int]] = None
+         self._rest_aux: Optional[Dict[str, List[int]]] = None
+
+     def _build_schema(self) -> SampleSchema:
+         """Build the sample schema from token dtype and aux generators."""
+         arrays = []
+         offset = 0
+
+         # Tokens array first
+         token_bytes = self._sample_size * self._token_dtype.itemsize
+         arrays.append(ArraySpec(name="tokens", dtype=str(self._token_dtype), offset=offset))
+         offset += token_bytes
+
+         # Auxiliary arrays in sorted order (for deterministic output)
+         for name in sorted(self._aux_generators.keys()):
+             gen = self._aux_generators[name]
+             aux_bytes = self._sample_size * gen.dtype.itemsize
+             arrays.append(ArraySpec(name=name, dtype=str(gen.dtype), offset=offset))
+             offset += aux_bytes
+
+         return SampleSchema(
+             sample_size=self._sample_size,
+             arrays=arrays,
+             total_bytes_per_sample=offset,
+         )
+
+     def get_sample_schema(self) -> SampleSchema:
+         """Return the sample schema describing the output structure."""
+         return self._sample_schema
+
+     def prepare(self, text_file_path: str):
+         """
+         Prepare for processing a new input file.
+
+         Creates the output file for storing tokenized samples with aux data.
+
+         :param text_file_path: Path to the input text file.
+         """
+         input_file_path = os.path.dirname(text_file_path)
+         input_file_name = Path(text_file_path).stem
+         output_file_name = f"{input_file_name}-{self._file_name_postfix}.bin"
+
+         output_path = self._base_output_path
+         if self._base_output_path is None:
+             output_path = os.path.join(input_file_path, "tokenized-samples")
+
+         os.makedirs(output_path, exist_ok=True)
+
+         output_file_path = os.path.join(output_path, output_file_name)
+         if os.path.exists(output_file_path):
+             raise FileExistsError(f"Tokenized samples file already exists: \n{output_file_path}")
+
+         self._current_samples_path = output_file_path
+         self._current_samples_file = open(output_file_path, "wb")
+
+         self._log.debug(f"Tokenized samples file opened: \n{output_file_path}")
+
+     def create_samples(self, text_line: bytes) -> List[SampleData]:
+         """
+         Create tokenized samples with auxiliary data from a structured input.
+
+         Processes the input parts, tokenizes each part, generates auxiliary data,
+         and splits into fixed-length samples.
+
+         :param text_line: JSONL line as bytes containing structured parts.
+
+         :return: List of SampleData objects containing the tokenized samples.
+         """
+         parts = self._process_parts_func(text_line)
+
+         # Tokenize all parts and generate aux data
+         doc_tokens: List[int] = []
+         doc_aux: Dict[str, List[int]] = {name: [] for name in self._aux_generators}
+
+         for part in parts:
+             part_tokens = self._tokenizer_func(part.text)
+             doc_tokens.extend(part_tokens)
+
+             for name, gen in self._aux_generators.items():
+                 aux_values = gen.generate(
+                     part_type=part.type,
+                     num_tokens=len(part_tokens),
+                     part_tokens=part_tokens,
+                 )
+                 doc_aux[name].extend(aux_values)
+
+         # Add EOS token and aux values
+         doc_tokens.append(self._eos_idx)
+         for name, gen in self._aux_generators.items():
+             doc_aux[name].append(gen.generate_for_eos())
+
+         # Prepend any leftover from previous document
+         # Invariant: rest_tokens and all rest_aux[*] have the same length
+         if self._rest_tokens is not None:
+             assert all(
+                 len(self._rest_aux[name]) == len(self._rest_tokens)
+                 for name in self._aux_generators
+             ), "Invariant violated: rest_tokens and rest_aux must have same length"
+
+             doc_tokens = self._rest_tokens + doc_tokens
+             for name in self._aux_generators:
+                 doc_aux[name] = self._rest_aux[name] + doc_aux[name]
+             self._rest_tokens = None
+             self._rest_aux = None
+
+         # Split into samples
+         num_tokens = len(doc_tokens)
+         num_samples = num_tokens // self._sample_size
+         num_rest_tokens = num_tokens % self._sample_size
+
+         if num_rest_tokens > 0:
+             # Store remainder
+             self._rest_tokens = doc_tokens[-num_rest_tokens:]
+             self._rest_aux = {name: doc_aux[name][-num_rest_tokens:] for name in self._aux_generators}
+             doc_tokens = doc_tokens[:num_samples * self._sample_size]
+             for name in self._aux_generators:
+                 doc_aux[name] = doc_aux[name][:num_samples * self._sample_size]
+
+         # Write samples
+         sample_data = []
+         for i in range(num_samples):
+             start = i * self._sample_size
+             end = start + self._sample_size
+
+             # Build concatenated sample bytes
+             sample_tokens = doc_tokens[start:end]
+             sample_bytes = np.array(sample_tokens, dtype=self._token_dtype).tobytes()
+
+             # Add aux data in sorted order
+             for name in sorted(self._aux_generators.keys()):
+                 gen = self._aux_generators[name]
+                 aux_values = doc_aux[name][start:end]
+                 sample_bytes += np.array(aux_values, dtype=gen.dtype).tobytes()
+
+             sample_data.append(SampleData(sample_bytes, self._current_samples_path))
+             self._current_samples_file.write(sample_bytes)
+
+         return sample_data
+
+     def finish(self, is_last_file: bool):
+         """
+         Finish processing the current input file.
+
+         Closes the current samples file. When is_last_file is True, any
+         remaining tokens are discarded.
+
+         :param is_last_file: Whether this is the last input file.
+         """
+         self._close_current_samples_file()
+
+         if is_last_file and self._rest_tokens is not None:
+             self._log.debug(f"Cut off {len(self._rest_tokens)} unused tokens")
+             self._rest_tokens = None
+             self._rest_aux = None
+
+     def _close_current_samples_file(self):
+         if self._current_samples_file:
+             self._log.debug(f"Closing tokenized samples file: \n{self._current_samples_path}")
+             self._current_samples_file.close()
+             self._current_samples_file = None
+
+     def __del__(self):
+         self._close_current_samples_file()
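
The AuxDataGenerator and Part types are imported from data_forager.sample_generators.aux.common, which is not part of this diff, so their exact interface is only inferred here from how they are called above (a dtype attribute, generate(part_type=..., num_tokens=..., part_tokens=...), and generate_for_eos()). With that caveat, a structurally compatible loss-mask generator could look roughly like the sketch below; the class name and constructor argument are hypothetical, and a real implementation would presumably subclass the package's AuxDataGenerator.

from typing import List

import numpy as np


class SimpleLossMaskGenerator:
    """Emits 1 for tokens that should contribute to the loss, 0 otherwise."""

    def __init__(self, masked_part_types=("prompt",)):
        self._masked_part_types = set(masked_part_types)

    @property
    def dtype(self) -> np.dtype:
        # Used by _build_schema for the per-element byte size.
        return np.dtype(np.uint8)

    def generate(self, part_type: str, num_tokens: int, part_tokens: List[int]) -> List[int]:
        # Mask out prompt-like parts; keep everything else in the loss.
        value = 0 if part_type in self._masked_part_types else 1
        return [value] * num_tokens

    def generate_for_eos(self) -> int:
        # The EOS token ends the response, so keep it in the loss.
        return 1
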
@@ -1,12 +1,29 @@
- from typing import List, Optional, Dict, Protocol
+ """
+ Sample index data structures for random access to samples.

- from dataclasses import dataclass
+ This module provides the core data structures for indexing samples:
+ - SampleLocation: Location of a single sample (file path, byte offset, size)
+ - SampleIndex: Index containing all sample locations and optional schema
+ """
+
+ from typing import List, Optional
+
+ from dataclasses import dataclass, field

  import numpy as np

+ from data_forager.sample_generators.schema import SampleSchema
+

  @dataclass
  class SampleLocation:
+     """
+     Location information for a single sample.
+
+     :param file_path: Path to the file containing the sample.
+     :param byte_offset: Byte offset within the file where the sample starts.
+     :param num_bytes: Size of the sample in bytes.
+     """

      file_path: str
      byte_offset: int
@@ -20,6 +37,18 @@ class SampleLocation:

  @dataclass
  class SampleIndex:
+     """
+     Index for random access to samples stored on disk.
+
+     Contains file locations and byte offsets for all samples, enabling O(1)
+     random access via seek operations.
+
+     :param file_locations: List of file paths containing samples.
+     :param sample_locations: 2D numpy array of shape (num_samples, 3) where
+         each row contains (file_index, byte_offset, num_bytes).
+     :param sample_schema: Optional schema describing sample structure for
+         samples with auxiliary data (e.g., tokens + loss_mask).
+     """

      file_locations: List[str]

@@ -30,6 +59,9 @@ class SampleIndex:
      # num_bytes : uint64
      sample_locations: np.ndarray

+     # Optional schema for structured samples (e.g., tokens + auxiliary data)
+     sample_schema: Optional[SampleSchema] = field(default=None)
+
      def __len__(self) -> int:
          return self.sample_locations.shape[0]
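
To show how the three columns of sample_locations and the optional sample_schema fit together, here is a small read-side sketch (not taken from the package): it seeks to a sample, reads its bytes, and, when a schema is attached, splits them into named arrays. The read_sample helper is an assumption for illustration only.

import numpy as np


def read_sample(index: SampleIndex, i: int):
    """Read sample i via its (file_index, byte_offset, num_bytes) row."""
    file_idx, byte_offset, num_bytes = index.sample_locations[i]
    file_path = index.file_locations[int(file_idx)]

    with open(file_path, "rb") as f:
        f.seek(int(byte_offset))
        sample_bytes = f.read(int(num_bytes))

    if index.sample_schema is None:
        # Unstructured sample: return the raw bytes.
        return sample_bytes

    # Structured sample: split into named arrays using the attached schema.
    return {
        spec.name: np.frombuffer(
            sample_bytes,
            dtype=np.dtype(spec.dtype),
            count=index.sample_schema.sample_size,
            offset=spec.offset,
        )
        for spec in index.sample_schema.arrays
    }
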