nextrec 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/__init__.py CHANGED
@@ -12,18 +12,18 @@ Quick Start
  -----------
  >>> from nextrec.basic.features import DenseFeature, SparseFeature
  >>> from nextrec.models.ranking.deepfm import DeepFM
- >>>
+ >>>
  >>> # Define features
  >>> dense_features = [DenseFeature('age')]
  >>> sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
- >>>
+ >>>
  >>> # Build model
  >>> model = DeepFM(
  ... dense_features=dense_features,
  ... sparse_features=sparse_features,
  ... targets=['label']
  ... )
- >>>
+ >>>
  >>> # Train model
  >>> model.fit(train_data=df_train, valid_data=df_valid)
  """
@@ -31,7 +31,7 @@ Quick Start
  from nextrec.__version__ import __version__

  __all__ = [
- '__version__',
+ "__version__",
  ]

  # Package metadata
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.1"
+ __version__ = "0.1.2"
nextrec/basic/activation.py CHANGED
@@ -14,40 +14,41 @@ class Dice(nn.Module):
  """
  Dice activation function from the paper:
  "Deep Interest Network for Click-Through Rate Prediction" (Zhou et al., 2018)
-
+
  Dice(x) = p(x) * x + (1 - p(x)) * alpha * x
  where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
  """
+
  def __init__(self, emb_size: int, epsilon: float = 1e-9):
  super(Dice, self).__init__()
  self.epsilon = epsilon
  self.alpha = nn.Parameter(torch.zeros(emb_size))
  self.bn = nn.BatchNorm1d(emb_size)
-
+
  def forward(self, x):
  # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
  original_shape = x.shape
-
+
  if x.dim() == 3:
  # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
  batch_size, seq_len, emb_size = x.shape
  x = x.view(-1, emb_size)
-
+
  x_norm = self.bn(x)
  p = torch.sigmoid(x_norm)
  output = p * x + (1 - p) * self.alpha * x
-
+
  if len(original_shape) == 3:
  output = output.view(original_shape)
-
+
  return output


  def activation_layer(activation: str, emb_size: int | None = None):
  """Create an activation layer based on the given activation name."""
-
+
  activation = activation.lower()
-
+
  if activation == "dice":
  if emb_size is None:
  raise ValueError("emb_size is required for Dice activation")
@@ -89,4 +90,4 @@ def activation_layer(activation: str, emb_size: int | None = None):
  elif activation in ["none", "linear", "identity"]:
  return nn.Identity()
  else:
- raise ValueError(f"Unsupported activation function: {activation}")
+ raise ValueError(f"Unsupported activation function: {activation}")
nextrec/basic/callback.py CHANGED
@@ -8,6 +8,7 @@ Author:

  import copy

+
  class EarlyStopper(object):
  def __init__(self, patience: int = 20, mode: str = "max"):
  self.patience = patience
nextrec/basic/dataloader.py CHANGED
@@ -30,16 +30,18 @@ class FileDataset(IterableDataset):
  Iterable dataset for reading multiple files in batches.
  Supports CSV and Parquet files with chunk-based reading.
  """
-
- def __init__(self,
- file_paths: list[str], # file paths to read, containing CSV or Parquet files
- dense_features: list[DenseFeature], # dense feature definitions
- sparse_features: list[SparseFeature], # sparse feature definitions
- sequence_features: list[SequenceFeature], # sequence feature definitions
- target_columns: list[str], # target column names
- chunk_size: int = 10000,
- file_type: Literal['csv', 'parquet'] = 'csv',
- processor: Optional['DataProcessor'] = None): # optional DataProcessor for transformation
+
+ def __init__(
+ self,
+ file_paths: list[str], # file paths to read, containing CSV or Parquet files
+ dense_features: list[DenseFeature], # dense feature definitions
+ sparse_features: list[SparseFeature], # sparse feature definitions
+ sequence_features: list[SequenceFeature], # sequence feature definitions
+ target_columns: list[str], # target column names
+ chunk_size: int = 10000,
+ file_type: Literal["csv", "parquet"] = "csv",
+ processor: Optional["DataProcessor"] = None,
+ ): # optional DataProcessor for transformation

  self.file_paths = file_paths
  self.dense_features = dense_features
@@ -49,30 +51,30 @@ class FileDataset(IterableDataset):
  self.chunk_size = chunk_size
  self.file_type = file_type
  self.processor = processor
-
+
  self.all_features = dense_features + sparse_features + sequence_features
  self.feature_names = [f.name for f in self.all_features]
  self.current_file_index = 0
  self.total_files = len(file_paths)
-
+
  def __iter__(self) -> Iterator[tuple]:
  self.current_file_index = 0
  self._file_pbar = None
-
+
  # Create progress bar for file processing when multiple files
  if self.total_files > 1:
  self._file_pbar = tqdm.tqdm(
- total=self.total_files,
- desc="Files",
+ total=self.total_files,
+ desc="Files",
  unit="file",
  position=0,
  leave=True,
- bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
+ bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
  )
-
+
  for file_path in self.file_paths:
  self.current_file_index += 1
-
+
  # Update file progress bar
  if self._file_pbar is not None:
  self._file_pbar.update(1)
@@ -80,48 +82,51 @@ class FileDataset(IterableDataset):
  # For single file, log the file name
  file_name = os.path.basename(file_path)
  logging.info(colorize(f"Processing file: {file_name}", color="cyan"))
-
- if self.file_type == 'csv':
+
+ if self.file_type == "csv":
  yield from self._read_csv_chunks(file_path)
- elif self.file_type == 'parquet':
+ elif self.file_type == "parquet":
  yield from self._read_parquet_chunks(file_path)
-
+
  # Close file progress bar
  if self._file_pbar is not None:
  self._file_pbar.close()
-
+
  def _read_csv_chunks(self, file_path: str) -> Iterator[tuple]:
  chunk_iterator = pd.read_csv(file_path, chunksize=self.chunk_size)
-
+
  for chunk in chunk_iterator:
  tensors = self._dataframe_to_tensors(chunk)
  if tensors:
  yield tensors
-
+
  def _read_parquet_chunks(self, file_path: str) -> Iterator[tuple]:
  """
  Read parquet file in chunks to reduce memory footprint.
  Uses pyarrow's batch reading for true streaming.
  """
  import pyarrow.parquet as pq
+
  parquet_file = pq.ParquetFile(file_path)
  for batch in parquet_file.iter_batches(batch_size=self.chunk_size):
- chunk = batch.to_pandas()
+ chunk = batch.to_pandas()
  tensors = self._dataframe_to_tensors(chunk)
  if tensors:
  yield tensors
  del chunk
-
+
  def _dataframe_to_tensors(self, df: pd.DataFrame) -> tuple | None:
  if self.processor is not None:
  if not self.processor.is_fitted:
- raise ValueError("DataProcessor must be fitted before using in streaming mode")
+ raise ValueError(
+ "DataProcessor must be fitted before using in streaming mode"
+ )
  transformed_data = self.processor.transform(df, return_dict=True)
  else:
  transformed_data = df
-
+
  tensors = []
-
+
  # Process features
  for feature in self.all_features:
  if self.processor is not None:
@@ -131,10 +136,15 @@ class FileDataset(IterableDataset):
  else:
  # Get data from original dataframe
  if feature.name not in df.columns:
- logging.warning(colorize(f"Feature column '{feature.name}' not found in DataFrame", "yellow"))
+ logging.warning(
+ colorize(
+ f"Feature column '{feature.name}' not found in DataFrame",
+ "yellow",
+ )
+ )
  continue
  column_data = df[feature.name].values
-
+
  # Handle sequence features: convert to 2D array of shape (batch_size, seq_length)
  if isinstance(feature, SequenceFeature):
  if isinstance(column_data, np.ndarray) and column_data.dtype == object:
@@ -143,25 +153,29 @@ class FileDataset(IterableDataset):
  except (ValueError, TypeError) as e:
  # Fallback: handle variable-length sequences by padding
  sequences = []
- max_len = feature.max_len if hasattr(feature, 'max_len') else 0
+ max_len = feature.max_len if hasattr(feature, "max_len") else 0
  for seq in column_data:
  if isinstance(seq, (list, tuple, np.ndarray)):
  seq_arr = np.asarray(seq, dtype=np.int64)
  else:
  seq_arr = np.array([], dtype=np.int64)
  sequences.append(seq_arr)
-
+
  # Pad sequences to same length
  if max_len == 0:
- max_len = max(len(seq) for seq in sequences) if sequences else 1
-
+ max_len = (
+ max(len(seq) for seq in sequences) if sequences else 1
+ )
+
  padded = []
  for seq in sequences:
  if len(seq) > max_len:
  padded.append(seq[:max_len])
  else:
  pad_width = max_len - len(seq)
- padded.append(np.pad(seq, (0, pad_width), constant_values=0))
+ padded.append(
+ np.pad(seq, (0, pad_width), constant_values=0)
+ )
  column_data = np.stack(padded)
  else:
  column_data = np.asarray(column_data, dtype=np.int64)
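The fallback in this hunk right-pads ragged sequences with zeros (or truncates them) to a common length before stacking. A worked sketch of that padding logic with made-up values, independent of the nextrec feature classes:

>>> import numpy as np
>>> sequences = [np.array([5, 3]), np.array([7, 1, 4, 2]), np.array([9])]
>>> max_len = max(len(seq) for seq in sequences)  # 4 here, used when feature.max_len is unset
>>> np.stack([
...     seq[:max_len] if len(seq) > max_len
...     else np.pad(seq, (0, max_len - len(seq)), constant_values=0)
...     for seq in sequences
... ])
array([[5, 3, 0, 0],
       [7, 1, 4, 2],
       [9, 0, 0, 0]])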
@@ -170,43 +184,43 @@ class FileDataset(IterableDataset):
  tensor = torch.from_numpy(np.asarray(column_data, dtype=np.float32))
  else: # SparseFeature
  tensor = torch.from_numpy(np.asarray(column_data, dtype=np.int64))
-
+
  tensors.append(tensor)
-
+
  # Process targets
  target_tensors = []
  for target_name in self.target_columns:
  if self.processor is not None:
  target_data = transformed_data.get(target_name)
- if target_data is None:
+ if target_data is None:
  continue
  else:
  if target_name not in df.columns:
  continue
  target_data = df[target_name].values
-
+
  target_tensor = torch.from_numpy(np.asarray(target_data, dtype=np.float32))
-
+
  if target_tensor.dim() == 1:
  target_tensor = target_tensor.view(-1, 1)
-
+
  target_tensors.append(target_tensor)
-
+
  # Combine target tensors
  if target_tensors:
  if len(target_tensors) == 1 and target_tensors[0].shape[1] > 1:
  y_tensor = target_tensors[0]
  else:
  y_tensor = torch.cat(target_tensors, dim=1)
-
+
  if y_tensor.shape[1] == 1:
  y_tensor = y_tensor.squeeze(1)
-
+
  tensors.append(y_tensor)
-
+
  if not tensors:
  return None
-
+
  return tuple(tensors)


@@ -226,13 +240,15 @@ class RecDataLoader:
  >>> processor=processor
  >>> )
  """
-
- def __init__(self,
- dense_features: list[DenseFeature] | None = None,
- sparse_features: list[SparseFeature] | None = None,
- sequence_features: list[SequenceFeature] | None = None,
- target: list[str] | None | str = None,
- processor: Optional['DataProcessor'] = None):
+
+ def __init__(
+ self,
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ target: list[str] | None | str = None,
+ processor: Optional["DataProcessor"] = None,
+ ):

  self.dense_features = dense_features if dense_features else []
  self.sparse_features = sparse_features if sparse_features else []
@@ -244,41 +260,48 @@ class RecDataLoader:
  else:
  self.target_columns = []
  self.processor = processor
-
- self.all_features = self.dense_features + self.sparse_features + self.sequence_features
-
- def create_dataloader(self,
- data: Union[dict, pd.DataFrame, str, DataLoader],
- batch_size: int = 32,
- shuffle: bool = True,
- load_full: bool = True,
- chunk_size: int = 10000) -> DataLoader:
+
+ self.all_features = (
+ self.dense_features + self.sparse_features + self.sequence_features
+ )
+
+ def create_dataloader(
+ self,
+ data: Union[dict, pd.DataFrame, str, DataLoader],
+ batch_size: int = 32,
+ shuffle: bool = True,
+ load_full: bool = True,
+ chunk_size: int = 10000,
+ ) -> DataLoader:
  """
  Create DataLoader from various data sources.
  """
  if isinstance(data, DataLoader):
  return data
-
+
  if isinstance(data, (str, os.PathLike)):
- return self._create_from_path(data, batch_size, shuffle, load_full, chunk_size)
-
+ return self._create_from_path(
+ data, batch_size, shuffle, load_full, chunk_size
+ )
+
  if isinstance(data, (dict, pd.DataFrame)):
  return self._create_from_memory(data, batch_size, shuffle)

  raise ValueError(f"Unsupported data type: {type(data)}")
-
- def _create_from_memory(self,
- data: Union[dict, pd.DataFrame],
- batch_size: int,
- shuffle: bool) -> DataLoader:
+
+ def _create_from_memory(
+ self, data: Union[dict, pd.DataFrame], batch_size: int, shuffle: bool
+ ) -> DataLoader:

  if self.processor is not None:
  if not self.processor.is_fitted:
- raise ValueError("DataProcessor must be fitted before using in RecDataLoader")
+ raise ValueError(
+ "DataProcessor must be fitted before using in RecDataLoader"
+ )
  data = self.processor.transform(data, return_dict=True)
-
+
  tensors = []
-
+
  # Process features
  for feature in self.all_features:
  column = get_column_data(data, feature.name)
@@ -288,97 +311,111 @@ class RecDataLoader:
  if isinstance(feature, SequenceFeature):
  if isinstance(column, pd.Series):
  column = column.values
-
+
  # Handle different input formats for sequence features
  if isinstance(column, np.ndarray):
  # Check if elements are actually sequences (not just object dtype scalars)
- if column.dtype == object and len(column) > 0 and isinstance(column[0], (list, tuple, np.ndarray)):
+ if (
+ column.dtype == object
+ and len(column) > 0
+ and isinstance(column[0], (list, tuple, np.ndarray))
+ ):
  # Each element is a sequence (array/list), stack them into 2D array
  try:
  column = np.stack([np.asarray(seq, dtype=np.int64) for seq in column]) # type: ignore
  except (ValueError, TypeError) as e:
  # Fallback: handle variable-length sequences by padding
  sequences = []
- max_len = feature.max_len if hasattr(feature, 'max_len') else 0
+ max_len = (
+ feature.max_len if hasattr(feature, "max_len") else 0
+ )
  for seq in column:
  if isinstance(seq, (list, tuple, np.ndarray)):
  seq_arr = np.asarray(seq, dtype=np.int64)
  else:
  seq_arr = np.array([], dtype=np.int64)
  sequences.append(seq_arr)
-
+
  # Pad sequences to same length
  if max_len == 0:
- max_len = max(len(seq) for seq in sequences) if sequences else 1
-
+ max_len = (
+ max(len(seq) for seq in sequences)
+ if sequences
+ else 1
+ )
+
  padded = []
  for seq in sequences:
  if len(seq) > max_len:
  padded.append(seq[:max_len])
  else:
  pad_width = max_len - len(seq)
- padded.append(np.pad(seq, (0, pad_width), constant_values=0))
+ padded.append(
+ np.pad(seq, (0, pad_width), constant_values=0)
+ )
  column = np.stack(padded)
  elif column.ndim == 1:
  # 1D array, need to reshape or handle appropriately
  # Assuming each element should be treated as a single-item sequence
  column = column.reshape(-1, 1)
  # else: already a 2D array
-
+
  column = np.asarray(column, dtype=np.int64)
  tensor = torch.from_numpy(column)
-
+
  elif isinstance(feature, DenseFeature):
  tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
  else: # SparseFeature
  tensor = torch.from_numpy(np.asarray(column, dtype=np.int64))
-
+
  tensors.append(tensor)
-
+
  # Process targets
  label_tensors = []
  for target_name in self.target_columns:
  column = get_column_data(data, target_name)
  if column is None:
  continue
-
+
  label_tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
-
+
  if label_tensor.dim() == 1:
  label_tensor = label_tensor.view(-1, 1)
  elif label_tensor.dim() == 2:
  if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
  label_tensor = label_tensor.t()
-
+
  label_tensors.append(label_tensor)
-
+
  # Combine target tensors
  if label_tensors:
  if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
  y_tensor = label_tensors[0]
  else:
  y_tensor = torch.cat(label_tensors, dim=1)
-
+
  if y_tensor.shape[1] == 1:
  y_tensor = y_tensor.squeeze(1)
-
+
  tensors.append(y_tensor)
-
+
  dataset = TensorDataset(*tensors)
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
-
- def _create_from_path(self,
- path: str,
- batch_size: int,
- shuffle: bool,
- load_full: bool,
- chunk_size: int) -> DataLoader:
+
+ def _create_from_path(
+ self,
+ path: str,
+ batch_size: int,
+ shuffle: bool,
+ load_full: bool,
+ chunk_size: int,
+ ) -> DataLoader:
  """
  Create DataLoader from a file path, supporting CSV and Parquet formats, with options for full loading or streaming.
  """

  path_obj = Path(path)
-
+
  # Determine if it's a file or directory
  if path_obj.is_file():
  file_paths = [str(path_obj)]
@@ -387,41 +424,46 @@ class RecDataLoader:
  # Find all CSV and Parquet files in directory
  csv_files = glob.glob(os.path.join(path, "*.csv"))
  parquet_files = glob.glob(os.path.join(path, "*.parquet"))
-
+
  if csv_files and parquet_files:
- raise ValueError("Directory contains both CSV and Parquet files. Please use a single format.")
-
+ raise ValueError(
+ "Directory contains both CSV and Parquet files. Please use a single format."
+ )
+
  file_paths = csv_files if csv_files else parquet_files
-
+
  if not file_paths:
  raise ValueError(f"No CSV or Parquet files found in directory: {path}")
-
- file_type = 'csv' if csv_files else 'parquet'
+
+ file_type = "csv" if csv_files else "parquet"
  file_paths.sort() # Sort for consistent ordering
  else:
  raise ValueError(f"Invalid path: {path}")
-
+
  # Load full data into memory or use streaming
  if load_full:
  dfs = []
  for file_path in file_paths:
- if file_type == 'csv':
+ if file_type == "csv":
  df = pd.read_csv(file_path)
  else: # parquet
  df = pd.read_parquet(file_path)
  dfs.append(df)
-
+
  combined_df = pd.concat(dfs, ignore_index=True)
  return self._create_from_memory(combined_df, batch_size, shuffle)
  else:
- return self._load_files_streaming(file_paths, file_type, batch_size, chunk_size)
-
-
- def _load_files_streaming(self,
- file_paths: list[str],
- file_type: Literal['csv', 'parquet'],
- batch_size: int,
- chunk_size: int) -> DataLoader:
+ return self._load_files_streaming(
+ file_paths, file_type, batch_size, chunk_size
+ )
+
+ def _load_files_streaming(
+ self,
+ file_paths: list[str],
+ file_type: Literal["csv", "parquet"],
+ batch_size: int,
+ chunk_size: int,
+ ) -> DataLoader:
  # Create FileDataset for streaming
  dataset = FileDataset(
  file_paths=file_paths,
@@ -431,17 +473,16 @@ class RecDataLoader:
  target_columns=self.target_columns,
  chunk_size=chunk_size,
  file_type=file_type,
- processor=self.processor
+ processor=self.processor,
  )
-
+
  return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
-
- def _get_file_type(self, file_path: str) -> Literal['csv', 'parquet']:
+
+ def _get_file_type(self, file_path: str) -> Literal["csv", "parquet"]:
  ext = os.path.splitext(file_path)[1].lower()
- if ext == '.csv':
- return 'csv'
- elif ext == '.parquet':
- return 'parquet'
+ if ext == ".csv":
+ return "csv"
+ elif ext == ".parquet":
+ return "parquet"
  else:
  raise ValueError(f"Unsupported file type: {ext}")
-
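For context on the dataloader hunks above: create_dataloader accepts an in-memory dict or DataFrame, a CSV/Parquet file or directory path, or an existing DataLoader, and load_full switches between loading everything into a TensorDataset and streaming chunks through FileDataset. A minimal usage sketch based only on the signatures visible in this diff; the feature definitions, df_train, and the 'data/train/' directory are hypothetical:

>>> from nextrec.basic.features import DenseFeature, SparseFeature
>>> from nextrec.basic.dataloader import RecDataLoader
>>> loader = RecDataLoader(
...     dense_features=[DenseFeature('age')],
...     sparse_features=[SparseFeature('category', vocab_size=100, embedding_dim=16)],
...     target='label',
... )
>>> # In-memory DataFrame: fully materialized as a TensorDataset-backed DataLoader
>>> train_dl = loader.create_dataloader(df_train, batch_size=256, shuffle=True)
>>> # Directory of Parquet files: streamed in chunks through FileDataset
>>> stream_dl = loader.create_dataloader('data/train/', batch_size=256,
...                                      load_full=False, chunk_size=10000)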