nextrec 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.3.dist-info/RECORD +0 -51
  47. {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.3.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/__init__.py CHANGED
@@ -12,18 +12,18 @@ Quick Start
  -----------
  >>> from nextrec.basic.features import DenseFeature, SparseFeature
  >>> from nextrec.models.ranking.deepfm import DeepFM
- >>>
+ >>>
  >>> # Define features
  >>> dense_features = [DenseFeature('age')]
  >>> sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
- >>>
+ >>>
  >>> # Build model
  >>> model = DeepFM(
  ... dense_features=dense_features,
  ... sparse_features=sparse_features,
  ... targets=['label']
  ... )
- >>>
+ >>>
  >>> # Train model
  >>> model.fit(train_data=df_train, valid_data=df_valid)
  """
@@ -31,7 +31,7 @@ Quick Start
  from nextrec.__version__ import __version__

  __all__ = [
- "__version__",
+ '__version__',
  ]

  # Package metadata
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.3"
+ __version__ = "0.1.7"
nextrec/basic/activation.py CHANGED
@@ -14,41 +14,40 @@ class Dice(nn.Module):
  """
  Dice activation function from the paper:
  "Deep Interest Network for Click-Through Rate Prediction" (Zhou et al., 2018)
-
+
  Dice(x) = p(x) * x + (1 - p(x)) * alpha * x
  where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
  """
-
  def __init__(self, emb_size: int, epsilon: float = 1e-9):
  super(Dice, self).__init__()
  self.epsilon = epsilon
  self.alpha = nn.Parameter(torch.zeros(emb_size))
  self.bn = nn.BatchNorm1d(emb_size)
-
+
  def forward(self, x):
  # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
  original_shape = x.shape
-
+
  if x.dim() == 3:
  # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
  batch_size, seq_len, emb_size = x.shape
  x = x.view(-1, emb_size)
-
+
  x_norm = self.bn(x)
  p = torch.sigmoid(x_norm)
  output = p * x + (1 - p) * self.alpha * x
-
+
  if len(original_shape) == 3:
  output = output.view(original_shape)
-
+
  return output


  def activation_layer(activation: str, emb_size: int | None = None):
  """Create an activation layer based on the given activation name."""
-
+
  activation = activation.lower()
-
+
  if activation == "dice":
  if emb_size is None:
  raise ValueError("emb_size is required for Dice activation")
@@ -90,4 +89,4 @@ def activation_layer(activation: str, emb_size: int | None = None):
  elif activation in ["none", "linear", "identity"]:
  return nn.Identity()
  else:
- raise ValueError(f"Unsupported activation function: {activation}")
+ raise ValueError(f"Unsupported activation function: {activation}")
nextrec/basic/callback.py CHANGED
@@ -8,7 +8,6 @@ Author:

  import copy

-
  class EarlyStopper(object):
  def __init__(self, patience: int = 20, mode: str = "max"):
  self.patience = patience
nextrec/basic/dataloader.py CHANGED
@@ -30,18 +30,16 @@ class FileDataset(IterableDataset):
  Iterable dataset for reading multiple files in batches.
  Supports CSV and Parquet files with chunk-based reading.
  """
-
- def __init__(
- self,
- file_paths: list[str], # file paths to read, containing CSV or Parquet files
- dense_features: list[DenseFeature], # dense feature definitions
- sparse_features: list[SparseFeature], # sparse feature definitions
- sequence_features: list[SequenceFeature], # sequence feature definitions
- target_columns: list[str], # target column names
- chunk_size: int = 10000,
- file_type: Literal["csv", "parquet"] = "csv",
- processor: Optional["DataProcessor"] = None,
- ): # optional DataProcessor for transformation
+
+ def __init__(self,
+ file_paths: list[str], # file paths to read, containing CSV or Parquet files
+ dense_features: list[DenseFeature], # dense feature definitions
+ sparse_features: list[SparseFeature], # sparse feature definitions
+ sequence_features: list[SequenceFeature], # sequence feature definitions
+ target_columns: list[str], # target column names
+ chunk_size: int = 10000,
+ file_type: Literal['csv', 'parquet'] = 'csv',
+ processor: Optional['DataProcessor'] = None): # optional DataProcessor for transformation

  self.file_paths = file_paths
  self.dense_features = dense_features
@@ -51,30 +49,30 @@ class FileDataset(IterableDataset):
  self.chunk_size = chunk_size
  self.file_type = file_type
  self.processor = processor
-
+
  self.all_features = dense_features + sparse_features + sequence_features
  self.feature_names = [f.name for f in self.all_features]
  self.current_file_index = 0
  self.total_files = len(file_paths)
-
+
  def __iter__(self) -> Iterator[tuple]:
  self.current_file_index = 0
  self._file_pbar = None
-
+
  # Create progress bar for file processing when multiple files
  if self.total_files > 1:
  self._file_pbar = tqdm.tqdm(
- total=self.total_files,
- desc="Files",
+ total=self.total_files,
+ desc="Files",
  unit="file",
  position=0,
  leave=True,
- bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
+ bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
  )
-
+
  for file_path in self.file_paths:
  self.current_file_index += 1
-
+
  # Update file progress bar
  if self._file_pbar is not None:
  self._file_pbar.update(1)
@@ -82,51 +80,48 @@ class FileDataset(IterableDataset):
  # For single file, log the file name
  file_name = os.path.basename(file_path)
  logging.info(colorize(f"Processing file: {file_name}", color="cyan"))
-
- if self.file_type == "csv":
+
+ if self.file_type == 'csv':
  yield from self._read_csv_chunks(file_path)
- elif self.file_type == "parquet":
+ elif self.file_type == 'parquet':
  yield from self._read_parquet_chunks(file_path)
-
+
  # Close file progress bar
  if self._file_pbar is not None:
  self._file_pbar.close()
-
+
  def _read_csv_chunks(self, file_path: str) -> Iterator[tuple]:
  chunk_iterator = pd.read_csv(file_path, chunksize=self.chunk_size)
-
+
  for chunk in chunk_iterator:
  tensors = self._dataframe_to_tensors(chunk)
  if tensors:
  yield tensors
-
+
  def _read_parquet_chunks(self, file_path: str) -> Iterator[tuple]:
  """
  Read parquet file in chunks to reduce memory footprint.
  Uses pyarrow's batch reading for true streaming.
  """
  import pyarrow.parquet as pq
-
  parquet_file = pq.ParquetFile(file_path)
  for batch in parquet_file.iter_batches(batch_size=self.chunk_size):
- chunk = batch.to_pandas()
+ chunk = batch.to_pandas()
  tensors = self._dataframe_to_tensors(chunk)
  if tensors:
  yield tensors
  del chunk
-
+
  def _dataframe_to_tensors(self, df: pd.DataFrame) -> tuple | None:
  if self.processor is not None:
  if not self.processor.is_fitted:
- raise ValueError(
- "DataProcessor must be fitted before using in streaming mode"
- )
+ raise ValueError("DataProcessor must be fitted before using in streaming mode")
  transformed_data = self.processor.transform(df, return_dict=True)
  else:
  transformed_data = df
-
+
  tensors = []
-
+
  # Process features
  for feature in self.all_features:
  if self.processor is not None:
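For readers unfamiliar with the pyarrow call used by _read_parquet_chunks in the hunk above, the sketch below shows the same batch-reading pattern in isolation. The toy file and its columns are invented for illustration and are not part of the package.

# Standalone illustration of pyarrow's chunked Parquet reading, as used above;
# "toy.parquet" and its columns are placeholders.
import pandas as pd
import pyarrow.parquet as pq

pd.DataFrame({"age": range(25_000), "label": [0, 1] * 12_500}).to_parquet("toy.parquet")

parquet_file = pq.ParquetFile("toy.parquet")
for batch in parquet_file.iter_batches(batch_size=10_000):
    chunk = batch.to_pandas()   # each RecordBatch becomes a small DataFrame
    print(len(chunk))           # 10000, 10000, 5000: the full file is never loaded at once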
@@ -136,15 +131,10 @@ class FileDataset(IterableDataset):
  else:
  # Get data from original dataframe
  if feature.name not in df.columns:
- logging.warning(
- colorize(
- f"Feature column '{feature.name}' not found in DataFrame",
- "yellow",
- )
- )
+ logging.warning(colorize(f"Feature column '{feature.name}' not found in DataFrame", "yellow"))
  continue
  column_data = df[feature.name].values
-
+
  # Handle sequence features: convert to 2D array of shape (batch_size, seq_length)
  if isinstance(feature, SequenceFeature):
  if isinstance(column_data, np.ndarray) and column_data.dtype == object:
@@ -153,29 +143,25 @@ class FileDataset(IterableDataset):
  except (ValueError, TypeError) as e:
  # Fallback: handle variable-length sequences by padding
  sequences = []
- max_len = feature.max_len if hasattr(feature, "max_len") else 0
+ max_len = feature.max_len if hasattr(feature, 'max_len') else 0
  for seq in column_data:
  if isinstance(seq, (list, tuple, np.ndarray)):
  seq_arr = np.asarray(seq, dtype=np.int64)
  else:
  seq_arr = np.array([], dtype=np.int64)
  sequences.append(seq_arr)
-
+
  # Pad sequences to same length
  if max_len == 0:
- max_len = (
- max(len(seq) for seq in sequences) if sequences else 1
- )
-
+ max_len = max(len(seq) for seq in sequences) if sequences else 1
+
  padded = []
  for seq in sequences:
  if len(seq) > max_len:
  padded.append(seq[:max_len])
  else:
  pad_width = max_len - len(seq)
- padded.append(
- np.pad(seq, (0, pad_width), constant_values=0)
- )
+ padded.append(np.pad(seq, (0, pad_width), constant_values=0))
  column_data = np.stack(padded)
  else:
  column_data = np.asarray(column_data, dtype=np.int64)
@@ -184,43 +170,43 @@ class FileDataset(IterableDataset):
  tensor = torch.from_numpy(np.asarray(column_data, dtype=np.float32))
  else: # SparseFeature
  tensor = torch.from_numpy(np.asarray(column_data, dtype=np.int64))
-
+
  tensors.append(tensor)
-
+
  # Process targets
  target_tensors = []
  for target_name in self.target_columns:
  if self.processor is not None:
  target_data = transformed_data.get(target_name)
- if target_data is None:
+ if target_data is None:
  continue
  else:
  if target_name not in df.columns:
  continue
  target_data = df[target_name].values
-
+
  target_tensor = torch.from_numpy(np.asarray(target_data, dtype=np.float32))
-
+
  if target_tensor.dim() == 1:
  target_tensor = target_tensor.view(-1, 1)
-
+
  target_tensors.append(target_tensor)
-
+
  # Combine target tensors
  if target_tensors:
  if len(target_tensors) == 1 and target_tensors[0].shape[1] > 1:
  y_tensor = target_tensors[0]
  else:
  y_tensor = torch.cat(target_tensors, dim=1)
-
+
  if y_tensor.shape[1] == 1:
  y_tensor = y_tensor.squeeze(1)
-
+
  tensors.append(y_tensor)
-
+
  if not tensors:
  return None
-
+
  return tuple(tensors)

@@ -240,15 +226,13 @@ class RecDataLoader:
  >>> processor=processor
  >>> )
  """
-
- def __init__(
- self,
- dense_features: list[DenseFeature] | None = None,
- sparse_features: list[SparseFeature] | None = None,
- sequence_features: list[SequenceFeature] | None = None,
- target: list[str] | None | str = None,
- processor: Optional["DataProcessor"] = None,
- ):
+
+ def __init__(self,
+ dense_features: list[DenseFeature] | None = None,
+ sparse_features: list[SparseFeature] | None = None,
+ sequence_features: list[SequenceFeature] | None = None,
+ target: list[str] | None | str = None,
+ processor: Optional['DataProcessor'] = None):

  self.dense_features = dense_features if dense_features else []
  self.sparse_features = sparse_features if sparse_features else []
@@ -260,48 +244,41 @@ class RecDataLoader:
  else:
  self.target_columns = []
  self.processor = processor
-
- self.all_features = (
- self.dense_features + self.sparse_features + self.sequence_features
- )
-
- def create_dataloader(
- self,
- data: Union[dict, pd.DataFrame, str, DataLoader],
- batch_size: int = 32,
- shuffle: bool = True,
- load_full: bool = True,
- chunk_size: int = 10000,
- ) -> DataLoader:
+
+ self.all_features = self.dense_features + self.sparse_features + self.sequence_features
+
+ def create_dataloader(self,
+ data: Union[dict, pd.DataFrame, str, DataLoader],
+ batch_size: int = 32,
+ shuffle: bool = True,
+ load_full: bool = True,
+ chunk_size: int = 10000) -> DataLoader:
  """
  Create DataLoader from various data sources.
  """
  if isinstance(data, DataLoader):
  return data
-
+
  if isinstance(data, (str, os.PathLike)):
- return self._create_from_path(
- data, batch_size, shuffle, load_full, chunk_size
- )
-
+ return self._create_from_path(data, batch_size, shuffle, load_full, chunk_size)
+
  if isinstance(data, (dict, pd.DataFrame)):
  return self._create_from_memory(data, batch_size, shuffle)

  raise ValueError(f"Unsupported data type: {type(data)}")
-
- def _create_from_memory(
- self, data: Union[dict, pd.DataFrame], batch_size: int, shuffle: bool
- ) -> DataLoader:
+
+ def _create_from_memory(self,
+ data: Union[dict, pd.DataFrame],
+ batch_size: int,
+ shuffle: bool) -> DataLoader:

  if self.processor is not None:
  if not self.processor.is_fitted:
- raise ValueError(
- "DataProcessor must be fitted before using in RecDataLoader"
- )
+ raise ValueError("DataProcessor must be fitted before using in RecDataLoader")
  data = self.processor.transform(data, return_dict=True)
-
+
  tensors = []
-
+
  # Process features
  for feature in self.all_features:
  column = get_column_data(data, feature.name)
@@ -311,111 +288,97 @@ class RecDataLoader:
  if isinstance(feature, SequenceFeature):
  if isinstance(column, pd.Series):
  column = column.values
-
+
  # Handle different input formats for sequence features
  if isinstance(column, np.ndarray):
  # Check if elements are actually sequences (not just object dtype scalars)
- if (
- column.dtype == object
- and len(column) > 0
- and isinstance(column[0], (list, tuple, np.ndarray))
- ):
+ if column.dtype == object and len(column) > 0 and isinstance(column[0], (list, tuple, np.ndarray)):
  # Each element is a sequence (array/list), stack them into 2D array
  try:
  column = np.stack([np.asarray(seq, dtype=np.int64) for seq in column]) # type: ignore
  except (ValueError, TypeError) as e:
  # Fallback: handle variable-length sequences by padding
  sequences = []
- max_len = (
- feature.max_len if hasattr(feature, "max_len") else 0
- )
+ max_len = feature.max_len if hasattr(feature, 'max_len') else 0
  for seq in column:
  if isinstance(seq, (list, tuple, np.ndarray)):
  seq_arr = np.asarray(seq, dtype=np.int64)
  else:
  seq_arr = np.array([], dtype=np.int64)
  sequences.append(seq_arr)
-
+
  # Pad sequences to same length
  if max_len == 0:
- max_len = (
- max(len(seq) for seq in sequences)
- if sequences
- else 1
- )
-
+ max_len = max(len(seq) for seq in sequences) if sequences else 1
+
  padded = []
  for seq in sequences:
  if len(seq) > max_len:
  padded.append(seq[:max_len])
  else:
  pad_width = max_len - len(seq)
- padded.append(
- np.pad(seq, (0, pad_width), constant_values=0)
- )
+ padded.append(np.pad(seq, (0, pad_width), constant_values=0))
  column = np.stack(padded)
  elif column.ndim == 1:
  # 1D array, need to reshape or handle appropriately
  # Assuming each element should be treated as a single-item sequence
  column = column.reshape(-1, 1)
  # else: already a 2D array
-
+
  column = np.asarray(column, dtype=np.int64)
  tensor = torch.from_numpy(column)
-
+
  elif isinstance(feature, DenseFeature):
  tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
  else: # SparseFeature
  tensor = torch.from_numpy(np.asarray(column, dtype=np.int64))
-
+
  tensors.append(tensor)
-
+
  # Process targets
  label_tensors = []
  for target_name in self.target_columns:
  column = get_column_data(data, target_name)
  if column is None:
  continue
-
+
  label_tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
-
+
  if label_tensor.dim() == 1:
  label_tensor = label_tensor.view(-1, 1)
  elif label_tensor.dim() == 2:
  if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
  label_tensor = label_tensor.t()
-
+
  label_tensors.append(label_tensor)
-
+
  # Combine target tensors
  if label_tensors:
  if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
  y_tensor = label_tensors[0]
  else:
  y_tensor = torch.cat(label_tensors, dim=1)
-
+
  if y_tensor.shape[1] == 1:
  y_tensor = y_tensor.squeeze(1)
-
+
  tensors.append(y_tensor)
-
+
  dataset = TensorDataset(*tensors)
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
-
- def _create_from_path(
- self,
- path: str,
- batch_size: int,
- shuffle: bool,
- load_full: bool,
- chunk_size: int,
- ) -> DataLoader:
+
+ def _create_from_path(self,
+ path: str,
+ batch_size: int,
+ shuffle: bool,
+ load_full: bool,
+ chunk_size: int) -> DataLoader:
  """
  Create DataLoader from a file path, supporting CSV and Parquet formats, with options for full loading or streaming.
  """

  path_obj = Path(path)
-
+
  # Determine if it's a file or directory
  if path_obj.is_file():
  file_paths = [str(path_obj)]
424
387
  # Find all CSV and Parquet files in directory
425
388
  csv_files = glob.glob(os.path.join(path, "*.csv"))
426
389
  parquet_files = glob.glob(os.path.join(path, "*.parquet"))
427
-
390
+
428
391
  if csv_files and parquet_files:
429
- raise ValueError(
430
- "Directory contains both CSV and Parquet files. Please use a single format."
431
- )
432
-
392
+ raise ValueError("Directory contains both CSV and Parquet files. Please use a single format.")
393
+
433
394
  file_paths = csv_files if csv_files else parquet_files
434
-
395
+
435
396
  if not file_paths:
436
397
  raise ValueError(f"No CSV or Parquet files found in directory: {path}")
437
-
438
- file_type = "csv" if csv_files else "parquet"
398
+
399
+ file_type = 'csv' if csv_files else 'parquet'
439
400
  file_paths.sort() # Sort for consistent ordering
440
401
  else:
441
402
  raise ValueError(f"Invalid path: {path}")
442
-
403
+
443
404
  # Load full data into memory or use streaming
444
405
  if load_full:
445
406
  dfs = []
446
407
  for file_path in file_paths:
447
- if file_type == "csv":
408
+ if file_type == 'csv':
448
409
  df = pd.read_csv(file_path)
449
410
  else: # parquet
450
411
  df = pd.read_parquet(file_path)
451
412
  dfs.append(df)
452
-
413
+
453
414
  combined_df = pd.concat(dfs, ignore_index=True)
454
415
  return self._create_from_memory(combined_df, batch_size, shuffle)
455
416
  else:
456
- return self._load_files_streaming(
457
- file_paths, file_type, batch_size, chunk_size
458
- )
459
-
460
- def _load_files_streaming(
461
- self,
462
- file_paths: list[str],
463
- file_type: Literal["csv", "parquet"],
464
- batch_size: int,
465
- chunk_size: int,
466
- ) -> DataLoader:
417
+ return self._load_files_streaming(file_paths, file_type, batch_size, chunk_size)
418
+
419
+
420
+ def _load_files_streaming(self,
421
+ file_paths: list[str],
422
+ file_type: Literal['csv', 'parquet'],
423
+ batch_size: int,
424
+ chunk_size: int) -> DataLoader:
467
425
  # Create FileDataset for streaming
468
426
  dataset = FileDataset(
469
427
  file_paths=file_paths,
@@ -473,16 +431,17 @@ class RecDataLoader:
473
431
  target_columns=self.target_columns,
474
432
  chunk_size=chunk_size,
475
433
  file_type=file_type,
476
- processor=self.processor,
434
+ processor=self.processor
477
435
  )
478
-
436
+
479
437
  return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
480
-
481
- def _get_file_type(self, file_path: str) -> Literal["csv", "parquet"]:
438
+
439
+ def _get_file_type(self, file_path: str) -> Literal['csv', 'parquet']:
482
440
  ext = os.path.splitext(file_path)[1].lower()
483
- if ext == ".csv":
484
- return "csv"
485
- elif ext == ".parquet":
486
- return "parquet"
441
+ if ext == '.csv':
442
+ return 'csv'
443
+ elif ext == '.parquet':
444
+ return 'parquet'
487
445
  else:
488
446
  raise ValueError(f"Unsupported file type: {ext}")
447
+