nextrec-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/__init__.py ADDED
@@ -0,0 +1,41 @@
+ """
+ NextRec - A Unified Deep Learning Framework for Recommender Systems
+ ===================================================================
+
+ NextRec provides a comprehensive suite of recommendation models including:
+ - Ranking models (CTR prediction)
+ - Matching models (retrieval)
+ - Multi-task learning models
+ - Generative recommendation models
+
+ Quick Start
+ -----------
+ >>> from nextrec.basic.features import DenseFeature, SparseFeature
+ >>> from nextrec.models.ranking.deepfm import DeepFM
+ >>>
+ >>> # Define features
+ >>> dense_features = [DenseFeature('age')]
+ >>> sparse_features = [SparseFeature('category', vocab_size=100, embedding_dim=16)]
+ >>>
+ >>> # Build model
+ >>> model = DeepFM(
+ ...     dense_features=dense_features,
+ ...     sparse_features=sparse_features,
+ ...     targets=['label']
+ ... )
+ >>>
+ >>> # Train model
+ >>> model.fit(train_data=df_train, valid_data=df_valid)
+ """
+
+ from nextrec.__version__ import __version__
+
+ __all__ = [
+     '__version__',
+ ]
+
+ # Package metadata
+ __author__ = "zerolovesea"
+ __email__ = "zyaztec@gmail.com"
+ __license__ = "Apache 2.0"
+ __url__ = "https://github.com/zerolovesea/NextRec"
nextrec/__version__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.1"
nextrec/basic/__init__.py ADDED
File without changes
nextrec/basic/activation.py ADDED
@@ -0,0 +1,92 @@
+ """
+ Activation function definitions
+
+ Date: create on 27/10/2025
+ Author:
+     Yang Zhou,zyaztec@gmail.com
+ """
+
+ import torch
+ import torch.nn as nn
+
+
+ class Dice(nn.Module):
+     """
+     Dice activation function from the paper:
+     "Deep Interest Network for Click-Through Rate Prediction" (Zhou et al., 2018)
+
+     Dice(x) = p(x) * x + (1 - p(x)) * alpha * x
+     where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
+     """
+     def __init__(self, emb_size: int, epsilon: float = 1e-9):
+         super(Dice, self).__init__()
+         self.epsilon = epsilon
+         self.alpha = nn.Parameter(torch.zeros(emb_size))
+         self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)  # forward epsilon so the variance term matches the docstring
+
+     def forward(self, x):
+         # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
+         original_shape = x.shape
+
+         if x.dim() == 3:
+             # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
+             batch_size, seq_len, emb_size = x.shape
+             x = x.view(-1, emb_size)
+
+         x_norm = self.bn(x)
+         p = torch.sigmoid(x_norm)
+         output = p * x + (1 - p) * self.alpha * x
+
+         if len(original_shape) == 3:
+             output = output.view(original_shape)
+
+         return output
+
+
+ def activation_layer(activation: str, emb_size: int | None = None):
+     """Create an activation layer based on the given activation name."""
+
+     activation = activation.lower()
+
+     if activation == "dice":
+         if emb_size is None:
+             raise ValueError("emb_size is required for Dice activation")
+         return Dice(emb_size)
+     elif activation == "relu":
+         return nn.ReLU()
+     elif activation == "relu6":
+         return nn.ReLU6()
+     elif activation == "elu":
+         return nn.ELU()
+     elif activation == "selu":
+         return nn.SELU()
+     elif activation == "leaky_relu":
+         return nn.LeakyReLU()
+     elif activation == "prelu":
+         return nn.PReLU()
+     elif activation == "gelu":
+         return nn.GELU()
+     elif activation == "sigmoid":
+         return nn.Sigmoid()
+     elif activation == "tanh":
+         return nn.Tanh()
+     elif activation == "softplus":
+         return nn.Softplus()
+     elif activation == "softsign":
+         return nn.Softsign()
+     elif activation == "hardswish":
+         return nn.Hardswish()
+     elif activation == "mish":
+         return nn.Mish()
+     elif activation in ["silu", "swish"]:
+         return nn.SiLU()
+     elif activation == "hardsigmoid":
+         return nn.Hardsigmoid()
+     elif activation == "tanhshrink":
+         return nn.Tanhshrink()
+     elif activation == "softshrink":
+         return nn.Softshrink()
+     elif activation in ["none", "linear", "identity"]:
+         return nn.Identity()
+     else:
+         raise ValueError(f"Unsupported activation function: {activation}")
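A minimal usage sketch of the factory above (not part of the package diff; shapes and sizes here are illustrative):

import torch
from nextrec.basic.activation import activation_layer

# Dice needs the embedding width because it owns a BatchNorm1d and an alpha parameter
act = activation_layer("dice", emb_size=16)
x = torch.randn(8, 16)           # (batch_size, emb_size)
y = act(x)                       # same shape as the input
assert y.shape == x.shape

relu = activation_layer("relu")  # plain activations need no emb_size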
nextrec/basic/callback.py ADDED
@@ -0,0 +1,35 @@
+ """
+ EarlyStopper definitions
+
+ Date: create on 27/10/2025
+ Author:
+     Yang Zhou,zyaztec@gmail.com
+ """
+
+ import copy
+
+ class EarlyStopper(object):
+     def __init__(self, patience: int = 20, mode: str = "max"):
+         self.patience = patience
+         self.trial_counter = 0
+         # Start from the worst possible value so the first evaluation always counts as an improvement
+         self.best_metrics = float("-inf") if mode == "max" else float("inf")
+         self.best_weights = None
+         self.mode = mode
+
+     def stop_training(self, val_metrics, weights):
+         if self.mode == "max":
+             if val_metrics > self.best_metrics:
+                 self.best_metrics = val_metrics
+                 self.trial_counter = 0
+                 self.best_weights = copy.deepcopy(weights)
+                 return False
+         elif self.mode == "min":
+             if val_metrics < self.best_metrics:
+                 self.best_metrics = val_metrics
+                 self.trial_counter = 0
+                 self.best_weights = copy.deepcopy(weights)
+                 return False
+
+         if self.trial_counter + 1 < self.patience:
+             self.trial_counter += 1
+             return False
+         else:
+             return True
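A hedged sketch of how this class is meant to be driven from a training loop (the evaluate step and model object below are placeholders, not part of the diff):

from nextrec.basic.callback import EarlyStopper

stopper = EarlyStopper(patience=5, mode="max")   # e.g. monitoring validation AUC
for epoch in range(100):
    val_auc = evaluate(model)                    # placeholder validation step
    if stopper.stop_training(val_auc, model.state_dict()):
        # No improvement for `patience` evaluations: restore the best weights and stop
        model.load_state_dict(stopper.best_weights)
        break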
nextrec/basic/dataloader.py ADDED
@@ -0,0 +1,447 @@
+ """
+ Dataloader definitions
+
+ Date: create on 27/10/2025
+ Author:
+     Yang Zhou,zyaztec@gmail.com
+ """
+
+ import os
+ import glob
+ import torch
+ import logging
+ import numpy as np
+ import pandas as pd
+ import tqdm
+
+ from pathlib import Path
+ from typing import Iterator, Literal, Union, Optional
+ from torch.utils.data import DataLoader, TensorDataset, IterableDataset
+
+ from nextrec.data.preprocessor import DataProcessor
+ from nextrec.data import get_column_data, collate_fn
+
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+ from nextrec.basic.loggers import colorize
+
+
+ class FileDataset(IterableDataset):
+     """
+     Iterable dataset for reading multiple files in batches.
+     Supports CSV and Parquet files with chunk-based reading.
+     """
+
+     def __init__(self,
+                  file_paths: list[str],  # file paths to read, containing CSV or Parquet files
+                  dense_features: list[DenseFeature],  # dense feature definitions
+                  sparse_features: list[SparseFeature],  # sparse feature definitions
+                  sequence_features: list[SequenceFeature],  # sequence feature definitions
+                  target_columns: list[str],  # target column names
+                  chunk_size: int = 10000,
+                  file_type: Literal['csv', 'parquet'] = 'csv',
+                  processor: Optional['DataProcessor'] = None):  # optional DataProcessor for transformation
+
+         self.file_paths = file_paths
+         self.dense_features = dense_features
+         self.sparse_features = sparse_features
+         self.sequence_features = sequence_features
+         self.target_columns = target_columns
+         self.chunk_size = chunk_size
+         self.file_type = file_type
+         self.processor = processor
+
+         self.all_features = dense_features + sparse_features + sequence_features
+         self.feature_names = [f.name for f in self.all_features]
+         self.current_file_index = 0
+         self.total_files = len(file_paths)
+
+     def __iter__(self) -> Iterator[tuple]:
+         self.current_file_index = 0
+         self._file_pbar = None
+
+         # Create progress bar for file processing when multiple files
+         if self.total_files > 1:
+             self._file_pbar = tqdm.tqdm(
+                 total=self.total_files,
+                 desc="Files",
+                 unit="file",
+                 position=0,
+                 leave=True,
+                 bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
+             )
+
+         for file_path in self.file_paths:
+             self.current_file_index += 1
+
+             # Update file progress bar
+             if self._file_pbar is not None:
+                 self._file_pbar.update(1)
+             elif self.total_files == 1:
+                 # For single file, log the file name
+                 file_name = os.path.basename(file_path)
+                 logging.info(colorize(f"Processing file: {file_name}", color="cyan"))
+
+             if self.file_type == 'csv':
+                 yield from self._read_csv_chunks(file_path)
+             elif self.file_type == 'parquet':
+                 yield from self._read_parquet_chunks(file_path)
+
+         # Close file progress bar
+         if self._file_pbar is not None:
+             self._file_pbar.close()
+
+     def _read_csv_chunks(self, file_path: str) -> Iterator[tuple]:
+         chunk_iterator = pd.read_csv(file_path, chunksize=self.chunk_size)
+
+         for chunk in chunk_iterator:
+             tensors = self._dataframe_to_tensors(chunk)
+             if tensors:
+                 yield tensors
+
+     def _read_parquet_chunks(self, file_path: str) -> Iterator[tuple]:
+         """
+         Read parquet file in chunks to reduce memory footprint.
+         Uses pyarrow's batch reading for true streaming.
+         """
+         import pyarrow.parquet as pq
+         parquet_file = pq.ParquetFile(file_path)
+         for batch in parquet_file.iter_batches(batch_size=self.chunk_size):
+             chunk = batch.to_pandas()
+             tensors = self._dataframe_to_tensors(chunk)
+             if tensors:
+                 yield tensors
+             del chunk
+
+     def _dataframe_to_tensors(self, df: pd.DataFrame) -> tuple | None:
+         if self.processor is not None:
+             if not self.processor.is_fitted:
+                 raise ValueError("DataProcessor must be fitted before using in streaming mode")
+             transformed_data = self.processor.transform(df, return_dict=True)
+         else:
+             transformed_data = df
+
+         tensors = []
+
+         # Process features
+         for feature in self.all_features:
+             if self.processor is not None:
+                 column_data = transformed_data.get(feature.name)
+                 if column_data is None:
+                     continue
+             else:
+                 # Get data from original dataframe
+                 if feature.name not in df.columns:
+                     logging.warning(colorize(f"Feature column '{feature.name}' not found in DataFrame", "yellow"))
+                     continue
+                 column_data = df[feature.name].values
+
+             # Handle sequence features: convert to 2D array of shape (batch_size, seq_length)
+             if isinstance(feature, SequenceFeature):
+                 if isinstance(column_data, np.ndarray) and column_data.dtype == object:
+                     try:
+                         column_data = np.stack([np.asarray(seq, dtype=np.int64) for seq in column_data])  # type: ignore
+                     except (ValueError, TypeError) as e:
+                         # Fallback: handle variable-length sequences by padding
+                         sequences = []
+                         max_len = feature.max_len if hasattr(feature, 'max_len') else 0
+                         for seq in column_data:
+                             if isinstance(seq, (list, tuple, np.ndarray)):
+                                 seq_arr = np.asarray(seq, dtype=np.int64)
+                             else:
+                                 seq_arr = np.array([], dtype=np.int64)
+                             sequences.append(seq_arr)
+
+                         # Pad sequences to same length
+                         if max_len == 0:
+                             max_len = max(len(seq) for seq in sequences) if sequences else 1
+
+                         padded = []
+                         for seq in sequences:
+                             if len(seq) > max_len:
+                                 padded.append(seq[:max_len])
+                             else:
+                                 pad_width = max_len - len(seq)
+                                 padded.append(np.pad(seq, (0, pad_width), constant_values=0))
+                         column_data = np.stack(padded)
+                 else:
+                     column_data = np.asarray(column_data, dtype=np.int64)
+                 tensor = torch.from_numpy(column_data)
+             elif isinstance(feature, DenseFeature):
+                 tensor = torch.from_numpy(np.asarray(column_data, dtype=np.float32))
+             else:  # SparseFeature
+                 tensor = torch.from_numpy(np.asarray(column_data, dtype=np.int64))
+
+             tensors.append(tensor)
+
+         # Process targets
+         target_tensors = []
+         for target_name in self.target_columns:
+             if self.processor is not None:
+                 target_data = transformed_data.get(target_name)
+                 if target_data is None:
+                     continue
+             else:
+                 if target_name not in df.columns:
+                     continue
+                 target_data = df[target_name].values
+
+             target_tensor = torch.from_numpy(np.asarray(target_data, dtype=np.float32))
+
+             if target_tensor.dim() == 1:
+                 target_tensor = target_tensor.view(-1, 1)
+
+             target_tensors.append(target_tensor)
+
+         # Combine target tensors
+         if target_tensors:
+             if len(target_tensors) == 1 and target_tensors[0].shape[1] > 1:
+                 y_tensor = target_tensors[0]
+             else:
+                 y_tensor = torch.cat(target_tensors, dim=1)
+
+             if y_tensor.shape[1] == 1:
+                 y_tensor = y_tensor.squeeze(1)
+
+             tensors.append(y_tensor)
+
+         if not tensors:
+             return None
+
+         return tuple(tensors)
+
+
+ class RecDataLoader:
+     """
+     Custom DataLoader for recommendation models.
+     Supports multiple input formats: dict, DataFrame, CSV files, Parquet files, and directories.
+     Optionally supports DataProcessor for on-the-fly data transformation.
+
+     Examples:
+         >>> # Create a RecDataLoader
+         >>> dataloader = RecDataLoader(
+         ...     dense_features=dense_features,
+         ...     sparse_features=sparse_features,
+         ...     sequence_features=sequence_features,
+         ...     target=target_columns,
+         ...     processor=processor
+         ... )
+     """
+
+     def __init__(self,
+                  dense_features: list[DenseFeature] | None = None,
+                  sparse_features: list[SparseFeature] | None = None,
+                  sequence_features: list[SequenceFeature] | None = None,
+                  target: list[str] | None | str = None,
+                  processor: Optional['DataProcessor'] = None):
+
+         self.dense_features = dense_features if dense_features else []
+         self.sparse_features = sparse_features if sparse_features else []
+         self.sequence_features = sequence_features if sequence_features else []
+         if isinstance(target, str):
+             self.target_columns = [target]
+         elif isinstance(target, list):
+             self.target_columns = target
+         else:
+             self.target_columns = []
+         self.processor = processor
+
+         self.all_features = self.dense_features + self.sparse_features + self.sequence_features
+
+     def create_dataloader(self,
+                           data: Union[dict, pd.DataFrame, str, DataLoader],
+                           batch_size: int = 32,
+                           shuffle: bool = True,
+                           load_full: bool = True,
+                           chunk_size: int = 10000) -> DataLoader:
+         """
+         Create DataLoader from various data sources.
+         """
+         if isinstance(data, DataLoader):
+             return data
+
+         if isinstance(data, (str, os.PathLike)):
+             return self._create_from_path(data, batch_size, shuffle, load_full, chunk_size)
+
+         if isinstance(data, (dict, pd.DataFrame)):
+             return self._create_from_memory(data, batch_size, shuffle)
+
+         raise ValueError(f"Unsupported data type: {type(data)}")
+
+     def _create_from_memory(self,
+                             data: Union[dict, pd.DataFrame],
+                             batch_size: int,
+                             shuffle: bool) -> DataLoader:
+
+         if self.processor is not None:
+             if not self.processor.is_fitted:
+                 raise ValueError("DataProcessor must be fitted before using in RecDataLoader")
+             data = self.processor.transform(data, return_dict=True)
+
+         tensors = []
+
+         # Process features
+         for feature in self.all_features:
+             column = get_column_data(data, feature.name)
+             if column is None:
+                 raise KeyError(f"Feature {feature.name} not found in provided data.")
+
+             if isinstance(feature, SequenceFeature):
+                 if isinstance(column, pd.Series):
+                     column = column.values
+
+                 # Handle different input formats for sequence features
+                 if isinstance(column, np.ndarray):
+                     # Check if elements are actually sequences (not just object dtype scalars)
+                     if column.dtype == object and len(column) > 0 and isinstance(column[0], (list, tuple, np.ndarray)):
+                         # Each element is a sequence (array/list), stack them into 2D array
+                         try:
+                             column = np.stack([np.asarray(seq, dtype=np.int64) for seq in column])  # type: ignore
+                         except (ValueError, TypeError) as e:
+                             # Fallback: handle variable-length sequences by padding
+                             sequences = []
+                             max_len = feature.max_len if hasattr(feature, 'max_len') else 0
+                             for seq in column:
+                                 if isinstance(seq, (list, tuple, np.ndarray)):
+                                     seq_arr = np.asarray(seq, dtype=np.int64)
+                                 else:
+                                     seq_arr = np.array([], dtype=np.int64)
+                                 sequences.append(seq_arr)
+
+                             # Pad sequences to same length
+                             if max_len == 0:
+                                 max_len = max(len(seq) for seq in sequences) if sequences else 1
+
+                             padded = []
+                             for seq in sequences:
+                                 if len(seq) > max_len:
+                                     padded.append(seq[:max_len])
+                                 else:
+                                     pad_width = max_len - len(seq)
+                                     padded.append(np.pad(seq, (0, pad_width), constant_values=0))
+                             column = np.stack(padded)
+                     elif column.ndim == 1:
+                         # 1D array, need to reshape or handle appropriately
+                         # Assuming each element should be treated as a single-item sequence
+                         column = column.reshape(-1, 1)
+                     # else: already a 2D array
+
+                 column = np.asarray(column, dtype=np.int64)
+                 tensor = torch.from_numpy(column)
+
+             elif isinstance(feature, DenseFeature):
+                 tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
+             else:  # SparseFeature
+                 tensor = torch.from_numpy(np.asarray(column, dtype=np.int64))
+
+             tensors.append(tensor)
+
+         # Process targets
+         label_tensors = []
+         for target_name in self.target_columns:
+             column = get_column_data(data, target_name)
+             if column is None:
+                 continue
+
+             label_tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
+
+             if label_tensor.dim() == 1:
+                 label_tensor = label_tensor.view(-1, 1)
+             elif label_tensor.dim() == 2:
+                 if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
+                     label_tensor = label_tensor.t()
+
+             label_tensors.append(label_tensor)
+
+         # Combine target tensors
+         if label_tensors:
+             if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
+                 y_tensor = label_tensors[0]
+             else:
+                 y_tensor = torch.cat(label_tensors, dim=1)
+
+             if y_tensor.shape[1] == 1:
+                 y_tensor = y_tensor.squeeze(1)
+
+             tensors.append(y_tensor)
+
+         dataset = TensorDataset(*tensors)
+         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+
+     def _create_from_path(self,
+                           path: str,
+                           batch_size: int,
+                           shuffle: bool,
+                           load_full: bool,
+                           chunk_size: int) -> DataLoader:
+         """
+         Create DataLoader from a file path, supporting CSV and Parquet formats, with options for full loading or streaming.
+         """
+
+         path_obj = Path(path)
+
+         # Determine if it's a file or directory
+         if path_obj.is_file():
+             file_paths = [str(path_obj)]
+             file_type = self._get_file_type(str(path_obj))
+         elif path_obj.is_dir():
+             # Find all CSV and Parquet files in directory
+             csv_files = glob.glob(os.path.join(path, "*.csv"))
+             parquet_files = glob.glob(os.path.join(path, "*.parquet"))
+
+             if csv_files and parquet_files:
+                 raise ValueError("Directory contains both CSV and Parquet files. Please use a single format.")
+
+             file_paths = csv_files if csv_files else parquet_files
+
+             if not file_paths:
+                 raise ValueError(f"No CSV or Parquet files found in directory: {path}")
+
+             file_type = 'csv' if csv_files else 'parquet'
+             file_paths.sort()  # Sort for consistent ordering
+         else:
+             raise ValueError(f"Invalid path: {path}")
+
+         # Load full data into memory or use streaming
+         if load_full:
+             dfs = []
+             for file_path in file_paths:
+                 if file_type == 'csv':
+                     df = pd.read_csv(file_path)
+                 else:  # parquet
+                     df = pd.read_parquet(file_path)
+                 dfs.append(df)
+
+             combined_df = pd.concat(dfs, ignore_index=True)
+             return self._create_from_memory(combined_df, batch_size, shuffle)
+         else:
+             return self._load_files_streaming(file_paths, file_type, batch_size, chunk_size)
+
+
+     def _load_files_streaming(self,
+                               file_paths: list[str],
+                               file_type: Literal['csv', 'parquet'],
+                               batch_size: int,
+                               chunk_size: int) -> DataLoader:
+         # Create FileDataset for streaming
+         dataset = FileDataset(
+             file_paths=file_paths,
+             dense_features=self.dense_features,
+             sparse_features=self.sparse_features,
+             sequence_features=self.sequence_features,
+             target_columns=self.target_columns,
+             chunk_size=chunk_size,
+             file_type=file_type,
+             processor=self.processor
+         )
+
+         return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
+
+     def _get_file_type(self, file_path: str) -> Literal['csv', 'parquet']:
+         ext = os.path.splitext(file_path)[1].lower()
+         if ext == '.csv':
+             return 'csv'
+         elif ext == '.parquet':
+             return 'parquet'
+         else:
+             raise ValueError(f"Unsupported file type: {ext}")
+
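To illustrate how RecDataLoader ties these pieces together, a small in-memory sketch (column names, vocab size, and embedding dimension are illustrative, not taken from the diff):

import pandas as pd
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.basic.dataloader import RecDataLoader

df = pd.DataFrame({
    "age": [0.1, 0.5, 0.9],   # dense feature
    "category": [3, 7, 1],    # already label-encoded sparse feature
    "label": [1, 0, 1],       # binary target
})

loader = RecDataLoader(
    dense_features=[DenseFeature("age")],
    sparse_features=[SparseFeature("category", vocab_size=100, embedding_dim=16)],
    target="label",
)
dataloader = loader.create_dataloader(df, batch_size=2, shuffle=False)
for dense, sparse, labels in dataloader:
    # one tensor per feature, plus the target, batched by TensorDataset
    print(dense.shape, sparse.shape, labels.shape)
    break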