nextrec-0.1.11-py3-none-any.whl → nextrec-0.2.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -2
- nextrec/basic/callback.py +1 -2
- nextrec/basic/features.py +39 -8
- nextrec/basic/layers.py +3 -4
- nextrec/basic/loggers.py +15 -10
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +160 -125
- nextrec/basic/session.py +150 -0
- nextrec/data/__init__.py +13 -2
- nextrec/data/data_utils.py +74 -22
- nextrec/data/dataloader.py +513 -0
- nextrec/data/preprocessor.py +494 -134
- nextrec/loss/__init__.py +31 -24
- nextrec/loss/listwise.py +164 -0
- nextrec/loss/loss_utils.py +133 -106
- nextrec/loss/pairwise.py +105 -0
- nextrec/loss/pointwise.py +198 -0
- nextrec/models/match/dssm.py +26 -17
- nextrec/models/match/dssm_v2.py +20 -2
- nextrec/models/match/mind.py +18 -3
- nextrec/models/match/sdm.py +17 -2
- nextrec/models/match/youtube_dnn.py +23 -10
- nextrec/models/multi_task/esmm.py +8 -8
- nextrec/models/multi_task/mmoe.py +8 -8
- nextrec/models/multi_task/ple.py +8 -8
- nextrec/models/multi_task/share_bottom.py +8 -8
- nextrec/models/ranking/__init__.py +8 -0
- nextrec/models/ranking/afm.py +5 -4
- nextrec/models/ranking/autoint.py +6 -4
- nextrec/models/ranking/dcn.py +6 -4
- nextrec/models/ranking/deepfm.py +5 -4
- nextrec/models/ranking/dien.py +6 -4
- nextrec/models/ranking/din.py +6 -4
- nextrec/models/ranking/fibinet.py +6 -4
- nextrec/models/ranking/fm.py +6 -4
- nextrec/models/ranking/masknet.py +6 -4
- nextrec/models/ranking/pnn.py +6 -4
- nextrec/models/ranking/widedeep.py +6 -4
- nextrec/models/ranking/xdeepfm.py +6 -4
- nextrec/utils/__init__.py +7 -11
- nextrec/utils/embedding.py +2 -4
- nextrec/utils/initializer.py +4 -5
- nextrec/utils/optimizer.py +7 -8
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
- nextrec-0.2.2.dist-info/RECORD +53 -0
- nextrec/basic/dataloader.py +0 -447
- nextrec/loss/match_losses.py +0 -294
- nextrec/utils/common.py +0 -14
- nextrec-0.1.11.dist-info/RECORD +0 -51
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
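
The file list above amounts to a module reorganization: the dataloader moves from nextrec/basic/ to nextrec/data/, and the single nextrec/loss/match_losses.py module is replaced by per-category pointwise, pairwise, and listwise loss modules. Below is a minimal sketch of the import-path changes this implies; the module paths come from the file list, but the imported names are assumptions for illustration (RecDataLoader is taken from the deleted module shown later in this diff, and bpr_loss is purely hypothetical).

# Hypothetical before/after imports implied by the file moves in 0.1.11 -> 0.2.2.
# The imported names (RecDataLoader, bpr_loss) are assumptions and may differ
# from the symbols actually exported by the package.

# nextrec 0.1.11
# from nextrec.basic.dataloader import RecDataLoader   # module removed in 0.2.2
# from nextrec.loss.match_losses import bpr_loss       # module removed in 0.2.2

# nextrec 0.2.2
from nextrec.data.dataloader import RecDataLoader       # assumed new home of the dataloader
from nextrec.loss import pointwise, pairwise, listwise  # new per-category loss modules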

nextrec-0.2.2.dist-info/RECORD
ADDED

@@ -0,0 +1,53 @@
+nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
+nextrec/__version__.py,sha256=m6kyaNpwBcP1XYcqrelX2oS3PJuOnElOcRdBa9pEb8c,22
+nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nextrec/basic/activation.py,sha256=9EfYmwE0brTSKwx_0FIGQ_rybFBT9n_G-UWA7NAhMsI,2804
+nextrec/basic/callback.py,sha256=qkq3k8rP0g4BW2C3FSCdVt_CyCcJwJ-rUXjhT2p4LP8,1035
+nextrec/basic/features.py,sha256=qNrGm74R6K7dw-kJdA2Sbp_Tjb_oDOo1a2JCfyOaYtw,3957
+nextrec/basic/layers.py,sha256=mDNApSlPkmPSnIPj3BDHfDEjviLybWuSGrh61Zog2uk,38290
+nextrec/basic/loggers.py,sha256=x8lzyyK-uqBN5XGOm1Cb33dmfc2bl114n6QeFTtE54k,3752
+nextrec/basic/metrics.py,sha256=w8tGe2tTbBNz9A1TNZF3jSpxcNC6QvFP5I0lWRd0Nw4,20398
+nextrec/basic/model.py,sha256=r3yWw2RMk-_Ap9hXKzyp7-TKpSaPtqRzpABaJkCrT9M,66336
+nextrec/basic/session.py,sha256=2kogEjgKAN1_ygelbwoqOs187BAcUnDTqXG1w_Pgb9I,4791
+nextrec/data/__init__.py,sha256=SOD64AkPykwKgm0CT89aDcHeiiDQlOOAkLDsAjYiqAM,814
+nextrec/data/data_utils.py,sha256=vGZ378YM_JQXO9npRB7JqojJx1ovjbJCWI-7lQJkicA,6298
+nextrec/data/dataloader.py,sha256=T0i2f5KMohd-hkgTPBNCPy9xaziRrUwYd01oI8GIEI8,20450
+nextrec/data/preprocessor.py,sha256=LzuBeBHcp6nnm7oaHDV-SZNk92ev8xjCj4r-thakRMw,42288
+nextrec/loss/__init__.py,sha256=t-wkqxcu5wdYlrb67-CxX9aOGom0CpMJK8Fe8KGDSEE,857
+nextrec/loss/listwise.py,sha256=LcYIPf6PGRtjV_AoWaAyp3rse904S2MghE5t032I07I,5628
+nextrec/loss/loss_utils.py,sha256=LnTkpMTS2bhbq4Lsjf3AUn1uBaOg1TaH5VO2R8hwARc,5324
+nextrec/loss/pairwise.py,sha256=RuQuTE-EkLaHQvT9m0CTAXxneTnVQLF1Pi9wblEClI8,3289
+nextrec/loss/pointwise.py,sha256=6QveizdohzQTxAoBKTVSoCBpp-fy3JC8vCjImXa7jL0,7157
+nextrec/models/generative/hstu.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nextrec/models/match/__init__.py,sha256=ASZB5abqKPhDbk8NErNNNa0DHuWpsVxvUtyEn5XMx6Y,215
+nextrec/models/match/dssm.py,sha256=e0hUqNLJVwTRVz4F4EiO8KLOOprKRBDtI4ID6Y1Tc60,8232
+nextrec/models/match/dssm_v2.py,sha256=ywtqTy3YN9ke_7kzcDp7Fhtldw9RJz6yfewxALJb6Z0,7189
+nextrec/models/match/mind.py,sha256=XSUDlZ-V95JXHHBDUl5sz99SaVuQKDvf3TArVjwUexs,9417
+nextrec/models/match/sdm.py,sha256=96yfMQ6arP6JRhAkDTGEjlBiTteznMykrDV_3jqvvVk,10920
+nextrec/models/match/youtube_dnn.py,sha256=pnrz9LYu65Fj4neOriFF45B5k2-yYiiREtQICxxYXZ0,7546
+nextrec/models/multi_task/esmm.py,sha256=E9B6TlpnPUeyldTofyFg4B7SKByyxbiW2fUGHLOryO4,4883
+nextrec/models/multi_task/mmoe.py,sha256=zhQr43Vfz7Kgi6B9pKPmaenp_38a_D7w4VvlpwCyF6Y,6165
+nextrec/models/multi_task/ple.py,sha256=otP6oLgzrJhwkLFItzNE-AtIPouObDkafRvWzTCxfNo,11335
+nextrec/models/multi_task/share_bottom.py,sha256=LL5HBVlvvBzHV2fLBRQMGIwpqmlxILTgU4c51XyTCo4,4517
+nextrec/models/ranking/__init__.py,sha256=-qe34zQEVwmxeTPGYCa6gbql9quT8DwB7-buHfA7Iig,428
+nextrec/models/ranking/afm.py,sha256=r9m1nEnc0m5d4pMtOxRMqOaXaBNCEkjJBFB-5wSHeFA,4540
+nextrec/models/ranking/autoint.py,sha256=GYzRynjn6Csq4b3qYIFWxLQ4Yl57_OQBeF2IY0Zhr9Q,5654
+nextrec/models/ranking/dcn.py,sha256=dUV5GbHypBGc9vVozk6aGYfIXq23c0deX-HFnIhZueg,4208
+nextrec/models/ranking/deepfm.py,sha256=y28yJxF__TZR3O1G2ufKZVtBRLgCgmlXWqvPgLzwm3U,3510
+nextrec/models/ranking/dien.py,sha256=E6s9TDwQfGSwtzzh8hG2F5gwgVxzVZPcptYvHLNzOLA,8475
+nextrec/models/ranking/din.py,sha256=j5tkT5k91CbsMlMr5vJOySrcY2_rFGxmEgJJ0McW7-Q,7196
+nextrec/models/ranking/fibinet.py,sha256=X6CbQbritvq5jql_Tvs4bn_tRla2zpWPplftZv8k6f0,4853
+nextrec/models/ranking/fm.py,sha256=3Qx_Fgowegr6UPQtEeTmHtOrbWzkvqH94ZTjOqRLu-E,2961
+nextrec/models/ranking/masknet.py,sha256=Bu0mZl2vKqcGnqCuUjPHjPRd1f-cDTeVwFj8Y_6v3C8,4639
+nextrec/models/ranking/pnn.py,sha256=5RxIKdxD0XcGq-b_QDdwGRwk6b_5BQjyMvCw3Ibv2Kk,4957
+nextrec/models/ranking/widedeep.py,sha256=b6ctElaZPv5WSYDA4piYUBo3je0eJpWpWECwcuWavM4,3716
+nextrec/models/ranking/xdeepfm.py,sha256=I00J5tfE4tPluqeW-qrNtE4V_9fC7-rgFvA0Fxqka7o,4274
+nextrec/utils/__init__.py,sha256=6x3OZbqks2gtgJd00y_-Y8QiAT42x5t14ARHQ-ULQDo,350
+nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
+nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
+nextrec/utils/optimizer.py,sha256=85ifoy2IQgjPHOqLqr1ho7XBGE_0ry1yEB9efS6C2lM,2446
+nextrec-0.2.2.dist-info/METADATA,sha256=lAhtUMV17TjgvjOuKZph_it7O02Ld1gKA0mc0k44oNQ,11425
+nextrec-0.2.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nextrec-0.2.2.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+nextrec-0.2.2.dist-info/RECORD,,

nextrec/basic/dataloader.py
DELETED

@@ -1,447 +0,0 @@
-"""
-Dataloader definitions
-
-Date: create on 27/10/2025
-Author:
-    Yang Zhou, zyaztec@gmail.com
-"""
-
-import os
-import glob
-import torch
-import logging
-import numpy as np
-import pandas as pd
-import tqdm
-
-from pathlib import Path
-from typing import Iterator, Literal, Union, Optional
-from torch.utils.data import DataLoader, TensorDataset, IterableDataset
-
-from nextrec.data.preprocessor import DataProcessor
-from nextrec.data import get_column_data, collate_fn
-
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
-from nextrec.basic.loggers import colorize
-
-
-class FileDataset(IterableDataset):
-    """
-    Iterable dataset for reading multiple files in batches.
-    Supports CSV and Parquet files with chunk-based reading.
-    """
-
-    def __init__(self,
-                 file_paths: list[str],  # file paths to read, containing CSV or Parquet files
-                 dense_features: list[DenseFeature],  # dense feature definitions
-                 sparse_features: list[SparseFeature],  # sparse feature definitions
-                 sequence_features: list[SequenceFeature],  # sequence feature definitions
-                 target_columns: list[str],  # target column names
-                 chunk_size: int = 10000,
-                 file_type: Literal['csv', 'parquet'] = 'csv',
-                 processor: Optional['DataProcessor'] = None):  # optional DataProcessor for transformation
-
-        self.file_paths = file_paths
-        self.dense_features = dense_features
-        self.sparse_features = sparse_features
-        self.sequence_features = sequence_features
-        self.target_columns = target_columns
-        self.chunk_size = chunk_size
-        self.file_type = file_type
-        self.processor = processor
-
-        self.all_features = dense_features + sparse_features + sequence_features
-        self.feature_names = [f.name for f in self.all_features]
-        self.current_file_index = 0
-        self.total_files = len(file_paths)
-
-    def __iter__(self) -> Iterator[tuple]:
-        self.current_file_index = 0
-        self._file_pbar = None
-
-        # Create progress bar for file processing when multiple files
-        if self.total_files > 1:
-            self._file_pbar = tqdm.tqdm(
-                total=self.total_files,
-                desc="Files",
-                unit="file",
-                position=0,
-                leave=True,
-                bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
-            )
-
-        for file_path in self.file_paths:
-            self.current_file_index += 1
-
-            # Update file progress bar
-            if self._file_pbar is not None:
-                self._file_pbar.update(1)
-            elif self.total_files == 1:
-                # For single file, log the file name
-                file_name = os.path.basename(file_path)
-                logging.info(colorize(f"Processing file: {file_name}", color="cyan"))
-
-            if self.file_type == 'csv':
-                yield from self._read_csv_chunks(file_path)
-            elif self.file_type == 'parquet':
-                yield from self._read_parquet_chunks(file_path)
-
-        # Close file progress bar
-        if self._file_pbar is not None:
-            self._file_pbar.close()
-
-    def _read_csv_chunks(self, file_path: str) -> Iterator[tuple]:
-        chunk_iterator = pd.read_csv(file_path, chunksize=self.chunk_size)
-
-        for chunk in chunk_iterator:
-            tensors = self._dataframe_to_tensors(chunk)
-            if tensors:
-                yield tensors
-
-    def _read_parquet_chunks(self, file_path: str) -> Iterator[tuple]:
-        """
-        Read parquet file in chunks to reduce memory footprint.
-        Uses pyarrow's batch reading for true streaming.
-        """
-        import pyarrow.parquet as pq
-        parquet_file = pq.ParquetFile(file_path)
-        for batch in parquet_file.iter_batches(batch_size=self.chunk_size):
-            chunk = batch.to_pandas()
-            tensors = self._dataframe_to_tensors(chunk)
-            if tensors:
-                yield tensors
-            del chunk
-
-    def _dataframe_to_tensors(self, df: pd.DataFrame) -> tuple | None:
-        if self.processor is not None:
-            if not self.processor.is_fitted:
-                raise ValueError("DataProcessor must be fitted before using in streaming mode")
-            transformed_data = self.processor.transform(df, return_dict=True)
-        else:
-            transformed_data = df
-
-        tensors = []
-
-        # Process features
-        for feature in self.all_features:
-            if self.processor is not None:
-                column_data = transformed_data.get(feature.name)
-                if column_data is None:
-                    continue
-            else:
-                # Get data from original dataframe
-                if feature.name not in df.columns:
-                    logging.warning(colorize(f"Feature column '{feature.name}' not found in DataFrame", "yellow"))
-                    continue
-                column_data = df[feature.name].values
-
-            # Handle sequence features: convert to 2D array of shape (batch_size, seq_length)
-            if isinstance(feature, SequenceFeature):
-                if isinstance(column_data, np.ndarray) and column_data.dtype == object:
-                    try:
-                        column_data = np.stack([np.asarray(seq, dtype=np.int64) for seq in column_data])  # type: ignore
-                    except (ValueError, TypeError) as e:
-                        # Fallback: handle variable-length sequences by padding
-                        sequences = []
-                        max_len = feature.max_len if hasattr(feature, 'max_len') else 0
-                        for seq in column_data:
-                            if isinstance(seq, (list, tuple, np.ndarray)):
-                                seq_arr = np.asarray(seq, dtype=np.int64)
-                            else:
-                                seq_arr = np.array([], dtype=np.int64)
-                            sequences.append(seq_arr)
-
-                        # Pad sequences to same length
-                        if max_len == 0:
-                            max_len = max(len(seq) for seq in sequences) if sequences else 1
-
-                        padded = []
-                        for seq in sequences:
-                            if len(seq) > max_len:
-                                padded.append(seq[:max_len])
-                            else:
-                                pad_width = max_len - len(seq)
-                                padded.append(np.pad(seq, (0, pad_width), constant_values=0))
-                        column_data = np.stack(padded)
-                else:
-                    column_data = np.asarray(column_data, dtype=np.int64)
-                tensor = torch.from_numpy(column_data)
-            elif isinstance(feature, DenseFeature):
-                tensor = torch.from_numpy(np.asarray(column_data, dtype=np.float32))
-            else:  # SparseFeature
-                tensor = torch.from_numpy(np.asarray(column_data, dtype=np.int64))
-
-            tensors.append(tensor)
-
-        # Process targets
-        target_tensors = []
-        for target_name in self.target_columns:
-            if self.processor is not None:
-                target_data = transformed_data.get(target_name)
-                if target_data is None:
-                    continue
-            else:
-                if target_name not in df.columns:
-                    continue
-                target_data = df[target_name].values
-
-            target_tensor = torch.from_numpy(np.asarray(target_data, dtype=np.float32))
-
-            if target_tensor.dim() == 1:
-                target_tensor = target_tensor.view(-1, 1)
-
-            target_tensors.append(target_tensor)
-
-        # Combine target tensors
-        if target_tensors:
-            if len(target_tensors) == 1 and target_tensors[0].shape[1] > 1:
-                y_tensor = target_tensors[0]
-            else:
-                y_tensor = torch.cat(target_tensors, dim=1)
-
-            if y_tensor.shape[1] == 1:
-                y_tensor = y_tensor.squeeze(1)
-
-            tensors.append(y_tensor)
-
-        if not tensors:
-            return None
-
-        return tuple(tensors)
-
-
-class RecDataLoader:
-    """
-    Custom DataLoader for recommendation models.
-    Supports multiple input formats: dict, DataFrame, CSV files, Parquet files, and directories.
-    Optionally supports DataProcessor for on-the-fly data transformation.
-
-    Examples:
-        >>> # Create a RecDataLoader
-        >>> dataloader = RecDataLoader(
-        >>>     dense_features=dense_features,
-        >>>     sparse_features=sparse_features,
-        >>>     sequence_features=sequence_features,
-        >>>     target_columns=target_columns,
-        >>>     processor=processor
-        >>> )
-    """
-
-    def __init__(self,
-                 dense_features: list[DenseFeature] | None = None,
-                 sparse_features: list[SparseFeature] | None = None,
-                 sequence_features: list[SequenceFeature] | None = None,
-                 target: list[str] | None | str = None,
-                 processor: Optional['DataProcessor'] = None):
-
-        self.dense_features = dense_features if dense_features else []
-        self.sparse_features = sparse_features if sparse_features else []
-        self.sequence_features = sequence_features if sequence_features else []
-        if isinstance(target, str):
-            self.target_columns = [target]
-        elif isinstance(target, list):
-            self.target_columns = target
-        else:
-            self.target_columns = []
-        self.processor = processor
-
-        self.all_features = self.dense_features + self.sparse_features + self.sequence_features
-
-    def create_dataloader(self,
-                          data: Union[dict, pd.DataFrame, str, DataLoader],
-                          batch_size: int = 32,
-                          shuffle: bool = True,
-                          load_full: bool = True,
-                          chunk_size: int = 10000) -> DataLoader:
-        """
-        Create DataLoader from various data sources.
-        """
-        if isinstance(data, DataLoader):
-            return data
-
-        if isinstance(data, (str, os.PathLike)):
-            return self._create_from_path(data, batch_size, shuffle, load_full, chunk_size)
-
-        if isinstance(data, (dict, pd.DataFrame)):
-            return self._create_from_memory(data, batch_size, shuffle)
-
-        raise ValueError(f"Unsupported data type: {type(data)}")
-
-    def _create_from_memory(self,
-                            data: Union[dict, pd.DataFrame],
-                            batch_size: int,
-                            shuffle: bool) -> DataLoader:
-
-        if self.processor is not None:
-            if not self.processor.is_fitted:
-                raise ValueError("DataProcessor must be fitted before using in RecDataLoader")
-            data = self.processor.transform(data, return_dict=True)
-
-        tensors = []
-
-        # Process features
-        for feature in self.all_features:
-            column = get_column_data(data, feature.name)
-            if column is None:
-                raise KeyError(f"Feature {feature.name} not found in provided data.")
-
-            if isinstance(feature, SequenceFeature):
-                if isinstance(column, pd.Series):
-                    column = column.values
-
-                # Handle different input formats for sequence features
-                if isinstance(column, np.ndarray):
-                    # Check if elements are actually sequences (not just object dtype scalars)
-                    if column.dtype == object and len(column) > 0 and isinstance(column[0], (list, tuple, np.ndarray)):
-                        # Each element is a sequence (array/list), stack them into 2D array
-                        try:
-                            column = np.stack([np.asarray(seq, dtype=np.int64) for seq in column])  # type: ignore
-                        except (ValueError, TypeError) as e:
-                            # Fallback: handle variable-length sequences by padding
-                            sequences = []
-                            max_len = feature.max_len if hasattr(feature, 'max_len') else 0
-                            for seq in column:
-                                if isinstance(seq, (list, tuple, np.ndarray)):
-                                    seq_arr = np.asarray(seq, dtype=np.int64)
-                                else:
-                                    seq_arr = np.array([], dtype=np.int64)
-                                sequences.append(seq_arr)
-
-                            # Pad sequences to same length
-                            if max_len == 0:
-                                max_len = max(len(seq) for seq in sequences) if sequences else 1
-
-                            padded = []
-                            for seq in sequences:
-                                if len(seq) > max_len:
-                                    padded.append(seq[:max_len])
-                                else:
-                                    pad_width = max_len - len(seq)
-                                    padded.append(np.pad(seq, (0, pad_width), constant_values=0))
-                            column = np.stack(padded)
-                    elif column.ndim == 1:
-                        # 1D array, need to reshape or handle appropriately
-                        # Assuming each element should be treated as a single-item sequence
-                        column = column.reshape(-1, 1)
-                    # else: already a 2D array
-
-                column = np.asarray(column, dtype=np.int64)
-                tensor = torch.from_numpy(column)
-
-            elif isinstance(feature, DenseFeature):
-                tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
-            else:  # SparseFeature
-                tensor = torch.from_numpy(np.asarray(column, dtype=np.int64))
-
-            tensors.append(tensor)
-
-        # Process targets
-        label_tensors = []
-        for target_name in self.target_columns:
-            column = get_column_data(data, target_name)
-            if column is None:
-                continue
-
-            label_tensor = torch.from_numpy(np.asarray(column, dtype=np.float32))
-
-            if label_tensor.dim() == 1:
-                label_tensor = label_tensor.view(-1, 1)
-            elif label_tensor.dim() == 2:
-                if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
-                    label_tensor = label_tensor.t()
-
-            label_tensors.append(label_tensor)
-
-        # Combine target tensors
-        if label_tensors:
-            if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
-                y_tensor = label_tensors[0]
-            else:
-                y_tensor = torch.cat(label_tensors, dim=1)
-
-            if y_tensor.shape[1] == 1:
-                y_tensor = y_tensor.squeeze(1)
-
-            tensors.append(y_tensor)
-
-        dataset = TensorDataset(*tensors)
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
-
-    def _create_from_path(self,
-                          path: str,
-                          batch_size: int,
-                          shuffle: bool,
-                          load_full: bool,
-                          chunk_size: int) -> DataLoader:
-        """
-        Create DataLoader from a file path, supporting CSV and Parquet formats, with options for full loading or streaming.
-        """
-
-        path_obj = Path(path)
-
-        # Determine if it's a file or directory
-        if path_obj.is_file():
-            file_paths = [str(path_obj)]
-            file_type = self._get_file_type(str(path_obj))
-        elif path_obj.is_dir():
-            # Find all CSV and Parquet files in directory
-            csv_files = glob.glob(os.path.join(path, "*.csv"))
-            parquet_files = glob.glob(os.path.join(path, "*.parquet"))
-
-            if csv_files and parquet_files:
-                raise ValueError("Directory contains both CSV and Parquet files. Please use a single format.")
-
-            file_paths = csv_files if csv_files else parquet_files
-
-            if not file_paths:
-                raise ValueError(f"No CSV or Parquet files found in directory: {path}")
-
-            file_type = 'csv' if csv_files else 'parquet'
-            file_paths.sort()  # Sort for consistent ordering
-        else:
-            raise ValueError(f"Invalid path: {path}")
-
-        # Load full data into memory or use streaming
-        if load_full:
-            dfs = []
-            for file_path in file_paths:
-                if file_type == 'csv':
-                    df = pd.read_csv(file_path)
-                else:  # parquet
-                    df = pd.read_parquet(file_path)
-                dfs.append(df)
-
-            combined_df = pd.concat(dfs, ignore_index=True)
-            return self._create_from_memory(combined_df, batch_size, shuffle)
-        else:
-            return self._load_files_streaming(file_paths, file_type, batch_size, chunk_size)
-
-
-    def _load_files_streaming(self,
-                              file_paths: list[str],
-                              file_type: Literal['csv', 'parquet'],
-                              batch_size: int,
-                              chunk_size: int) -> DataLoader:
-        # Create FileDataset for streaming
-        dataset = FileDataset(
-            file_paths=file_paths,
-            dense_features=self.dense_features,
-            sparse_features=self.sparse_features,
-            sequence_features=self.sequence_features,
-            target_columns=self.target_columns,
-            chunk_size=chunk_size,
-            file_type=file_type,
-            processor=self.processor
-        )
-
-        return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
-
-    def _get_file_type(self, file_path: str) -> Literal['csv', 'parquet']:
-        ext = os.path.splitext(file_path)[1].lower()
-        if ext == '.csv':
-            return 'csv'
-        elif ext == '.parquet':
-            return 'parquet'
-        else:
-            raise ValueError(f"Unsupported file type: {ext}")
-
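
For context on what the removed module provided, here is a short usage sketch reconstructed only from the deleted code above (the RecDataLoader constructor and its create_dataloader signature). The feature lists and file paths are placeholders; in 0.2.2 the equivalent functionality presumably lives in nextrec/data/dataloader.py, whose contents are not shown in this section.

# Illustrative sketch of the 0.1.11 API that this diff removes; it mirrors the
# docstring example and create_dataloader() signature from the deleted file.
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
from nextrec.basic.dataloader import RecDataLoader  # removed in 0.2.2

# Placeholder feature definitions; real usage would construct DenseFeature /
# SparseFeature / SequenceFeature objects for the dataset at hand.
dense_features: list[DenseFeature] = []
sparse_features: list[SparseFeature] = []
sequence_features: list[SequenceFeature] = []

loader = RecDataLoader(
    dense_features=dense_features,
    sparse_features=sparse_features,
    sequence_features=sequence_features,
    target="label",  # a single target column; a list of columns is also accepted
)

# Load a CSV file fully into memory as a TensorDataset-backed DataLoader...
train_loader = loader.create_dataloader("train.csv", batch_size=256, shuffle=True)

# ...or stream a directory of Parquet files in chunks through FileDataset.
stream_loader = loader.create_dataloader(
    "train_parquet/", batch_size=256, load_full=False, chunk_size=50000
)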