nextrec 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/data/preprocessor.py
ADDED

@@ -0,0 +1,662 @@

"""
DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.

Date: created on 13/11/2025
Author:
    Yang Zhou, zyaztec@gmail.com
"""

import os
import pandas as pd
import numpy as np
import pickle
import hashlib
import logging

from typing import Dict, Union, Optional, Literal, Any
from sklearn.preprocessing import (
    StandardScaler,
    MinMaxScaler,
    RobustScaler,
    MaxAbsScaler,
    LabelEncoder
)

from nextrec.basic.loggers import setup_logger, colorize


class DataProcessor:
    """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.

    Examples:
        >>> processor = DataProcessor()
        >>> processor.add_numeric_feature('age', scaler='standard')
        >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
        >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
        >>> processor.add_target('label', target_type='binary')
        >>>
        >>> # Fit and transform data
        >>> processor.fit(train_df)
        >>> processed_data = processor.transform(test_df)  # Returns dict of numpy arrays
        >>>
        >>> # Save and load processor
        >>> processor.save('processor.pkl')
        >>> loaded_processor = DataProcessor.load('processor.pkl')
        >>>
        >>> # Get vocabulary sizes for embedding layers
        >>> vocab_sizes = processor.get_vocab_sizes()
    """
    def __init__(self):
        self.numeric_features: Dict[str, Dict[str, Any]] = {}
        self.sparse_features: Dict[str, Dict[str, Any]] = {}
        self.sequence_features: Dict[str, Dict[str, Any]] = {}
        self.target_features: Dict[str, Dict[str, Any]] = {}

        self.is_fitted = False
        self._transform_summary_printed = False  # Track if summary has been printed during transform

        self.scalers: Dict[str, Any] = {}
        self.label_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoders: Dict[str, Dict[str, int]] = {}

        # Initialize logger if not already initialized
        self._logger_initialized = False
        if not logging.getLogger().hasHandlers():
            setup_logger()
            self._logger_initialized = True
    def add_numeric_feature(
        self,
        name: str,
        scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
        fill_na: Optional[float] = None
    ):
        self.numeric_features[name] = {
            'scaler': scaler,
            'fill_na': fill_na
        }

    def add_sparse_feature(
        self,
        name: str,
        encode_method: Literal['hash', 'label'] = 'label',
        hash_size: Optional[int] = None,
        fill_na: str = '<UNK>'
    ):
        if encode_method == 'hash' and hash_size is None:
            raise ValueError("hash_size must be specified when encode_method='hash'")

        self.sparse_features[name] = {
            'encode_method': encode_method,
            'hash_size': hash_size,
            'fill_na': fill_na
        }

    def add_sequence_feature(
        self,
        name: str,
        encode_method: Literal['hash', 'label'] = 'label',
        hash_size: Optional[int] = None,
        max_len: Optional[int] = 50,
        pad_value: int = 0,
        truncate: Literal['pre', 'post'] = 'pre',  # pre: keep last max_len items, post: keep first max_len items
        separator: str = ','
    ):

        if encode_method == 'hash' and hash_size is None:
            raise ValueError("hash_size must be specified when encode_method='hash'")

        self.sequence_features[name] = {
            'encode_method': encode_method,
            'hash_size': hash_size,
            'max_len': max_len,
            'pad_value': pad_value,
            'truncate': truncate,
            'separator': separator
        }

    def add_target(
        self,
        name: str,  # example: 'click'
        target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
        label_map: Optional[Dict[str, int]] = None  # example: {'click': 1, 'no_click': 0}
    ):
        self.target_features[name] = {
            'target_type': target_type,
            'label_map': label_map
        }
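    # Note: the add_* methods above only record configuration; no scaler or
    # encoder is created until fit() sees data, so features can be registered
    # in any order before any data is available.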
    def _hash_string(self, s: str, hash_size: int) -> int:
        return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
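    # Note: md5 (rather than Python's built-in hash()) keeps bucket ids stable
    # across processes and runs; built-in hash() is randomized per process
    # unless PYTHONHASHSEED is pinned, so a given string always lands in the
    # same bucket in [0, hash_size).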
    def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

        name = str(data.name)
        scaler_type = config['scaler']
        fill_na = config['fill_na']

        if data.isna().any():
            if fill_na is None:
                # By default, fill missing numeric values with the column mean
                fill_na = data.mean()
            config['fill_na_value'] = fill_na

        if scaler_type == 'standard':
            scaler = StandardScaler()
        elif scaler_type == 'minmax':
            scaler = MinMaxScaler()
        elif scaler_type == 'robust':
            scaler = RobustScaler()
        elif scaler_type == 'maxabs':
            scaler = MaxAbsScaler()
        elif scaler_type == 'log':
            scaler = None
        elif scaler_type == 'none':
            scaler = None
        else:
            raise ValueError(f"Unknown scaler type: {scaler_type}")

        if scaler is not None and scaler_type != 'log':
            filled_data = data.fillna(config.get('fill_na_value', 0))
            values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
            scaler.fit(values)
            self.scalers[name] = scaler

    def _process_numeric_feature_transform(
        self,
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:
        logger = logging.getLogger()

        name = str(data.name)
        scaler_type = config['scaler']
        fill_na_value = config.get('fill_na_value', 0)

        filled_data = data.fillna(fill_na_value)
        values = np.array(filled_data.values, dtype=np.float64)

        if scaler_type == 'log':
            result = np.log1p(np.maximum(values, 0))
        elif scaler_type == 'none':
            result = values
        else:
            scaler = self.scalers.get(name)
            if scaler is None:
                logger.warning(f"Scaler for {name} not fitted, returning original values")
                result = values
            else:
                result = scaler.transform(values.reshape(-1, 1)).ravel()

        return result
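    # Note: with scaler='log' the transform is np.log1p(np.maximum(x, 0)), so
    # negative inputs are clamped to zero before the log and nothing is fitted.
    # Also, the mean-based fill value is only computed when the fit data itself
    # contains NaNs; if NaNs first appear at transform time, they fall back to 0.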
    def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

        name = str(data.name)
        encode_method = config['encode_method']
        fill_na = config['fill_na']  # <UNK>

        filled_data = data.fillna(fill_na).astype(str)

        if encode_method == 'label':
            le = LabelEncoder()
            le.fit(filled_data)
            self.label_encoders[name] = le
            config['vocab_size'] = len(le.classes_)
        elif encode_method == 'hash':
            config['vocab_size'] = config['hash_size']

    def _process_sparse_feature_transform(
        self,
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:

        name = str(data.name)
        encode_method = config['encode_method']
        fill_na = config['fill_na']

        filled_data = data.fillna(fill_na).astype(str)

        if encode_method == 'label':
            le = self.label_encoders.get(name)
            if le is None:
                raise ValueError(f"LabelEncoder for {name} not fitted")

            result = []
            for val in filled_data:
                if val in le.classes_:
                    encoded = le.transform([val])
                    result.append(int(encoded[0]))
                else:
                    result.append(0)
            return np.array(result, dtype=np.int64)

        elif encode_method == 'hash':
            hash_size = config['hash_size']
            return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)

        return np.array([], dtype=np.int64)
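    # Note: the label branch above runs `val in le.classes_` (a linear scan of
    # a numpy array) plus a one-element le.transform() call per row, which gets
    # slow on large frames. A precomputed dict lookup is equivalent (a sketch,
    # not part of the released code):
    #   mapping = {cls: idx for idx, cls in enumerate(le.classes_)}
    #   result = np.array([mapping.get(v, 0) for v in filled_data], dtype=np.int64)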
    def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

        name = str(data.name)
        encode_method = config['encode_method']
        separator = config['separator']

        if encode_method == 'label':
            all_tokens = set()
            for seq in data:
                # Skip None, np.nan, and empty strings
                if seq is None:
                    continue
                if isinstance(seq, (float, np.floating)) and np.isnan(seq):
                    continue
                if isinstance(seq, str) and seq.strip() == '':
                    continue

                if isinstance(seq, str):
                    tokens = seq.split(separator)
                elif isinstance(seq, (list, tuple)):
                    tokens = [str(t) for t in seq]
                elif isinstance(seq, np.ndarray):
                    tokens = [str(t) for t in seq.tolist()]
                else:
                    continue

                all_tokens.update(tokens)

            if len(all_tokens) == 0:
                all_tokens.add('<PAD>')

            le = LabelEncoder()
            le.fit(list(all_tokens))
            self.label_encoders[name] = le
            config['vocab_size'] = len(le.classes_)
        elif encode_method == 'hash':
            config['vocab_size'] = config['hash_size']

    def _process_sequence_feature_transform(
        self,
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:
        name = str(data.name)
        encode_method = config['encode_method']
        max_len = config['max_len']
        pad_value = config['pad_value']
        truncate = config['truncate']
        separator = config['separator']

        result = []
        for seq in data:
            tokens = []

            if seq is None:
                tokens = []
            elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
                tokens = []
            elif isinstance(seq, str):
                if seq.strip() == '':
                    tokens = []
                else:
                    tokens = seq.split(separator)
            elif isinstance(seq, (list, tuple)):
                tokens = [str(t) for t in seq]
            elif isinstance(seq, np.ndarray):
                tokens = [str(t) for t in seq.tolist()]
            else:
                tokens = []

            if encode_method == 'label':
                le = self.label_encoders.get(name)
                if le is None:
                    raise ValueError(f"LabelEncoder for {name} not fitted")

                encoded = []
                for token in tokens:
                    token_str = str(token).strip()
                    if token_str and token_str in le.classes_:
                        encoded_val = le.transform([token_str])
                        encoded.append(int(encoded_val[0]))
                    else:
                        encoded.append(0)  # UNK
            elif encode_method == 'hash':
                hash_size = config['hash_size']
                encoded = [self._hash_string(str(token), hash_size) for token in tokens if str(token).strip()]
            else:
                encoded = []

            if len(encoded) > max_len:
                if truncate == 'pre':  # keep last max_len items
                    encoded = encoded[-max_len:]
                else:  # keep first max_len items
                    encoded = encoded[:max_len]
            elif len(encoded) < max_len:
                padding = [pad_value] * (max_len - len(encoded))
                encoded = encoded + padding

            result.append(encoded)

        return np.array(result, dtype=np.int64)
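    # Note: index 0 is overloaded for label-encoded sequences: it is the default
    # pad_value, the fallback for unseen tokens, and the id of the first
    # vocabulary entry, so downstream embeddings cannot tell these apart. Note
    # also that max_len must be an int here; max_len=None (allowed by the
    # signature) would raise a TypeError at the `len(encoded) > max_len` check.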
    def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
        name = str(data.name)
        target_type = config['target_type']
        label_map = config['label_map']

        if target_type in ['binary', 'multiclass']:
            if label_map is None:
                unique_values = data.dropna().unique()
                sorted_values = sorted(unique_values)

                try:
                    int_values = [int(v) for v in sorted_values]
                    if int_values == list(range(len(int_values))):
                        label_map = {str(val): int(val) for val in sorted_values}
                    else:
                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
                except (ValueError, TypeError):
                    label_map = {str(val): idx for idx, val in enumerate(sorted_values)}

                config['label_map'] = label_map

            self.target_encoders[name] = label_map

    def _process_target_transform(
        self,
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:
        logger = logging.getLogger()

        name = str(data.name)
        target_type = config['target_type']

        if target_type == 'regression':
            values = np.array(data.values, dtype=np.float32)
            return values
        else:
            label_map = self.target_encoders.get(name)
            if label_map is None:
                raise ValueError(f"Target encoder for {name} not fitted")

            result = []
            for val in data:
                str_val = str(val)
                if str_val in label_map:
                    result.append(label_map[str_val])
                else:
                    logger.warning(f"Unknown target value: {val}, mapping to 0")
                    result.append(0)

            return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
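    # Note: binary targets come back as float32 (the dtype that losses such as
    # binary cross-entropy expect), while multiclass targets come back as int64
    # class indices.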
    # fit only registers statistics from the data so that transform can be applied later
    def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
        logger = logging.getLogger()

        if isinstance(data, dict):
            data = pd.DataFrame(data)

        logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))

        for name, config in self.numeric_features.items():
            if name not in data.columns:
                logger.warning(f"Numeric feature {name} not found in data")
                continue
            self._process_numeric_feature_fit(data[name], config)

        for name, config in self.sparse_features.items():
            if name not in data.columns:
                logger.warning(f"Sparse feature {name} not found in data")
                continue
            self._process_sparse_feature_fit(data[name], config)

        for name, config in self.sequence_features.items():
            if name not in data.columns:
                logger.warning(f"Sequence feature {name} not found in data")
                continue
            self._process_sequence_feature_fit(data[name], config)

        for name, config in self.target_features.items():
            if name not in data.columns:
                logger.warning(f"Target {name} not found in data")
                continue
            self._process_target_fit(data[name], config)

        self.is_fitted = True
        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
        return self
    def transform(
        self,
        data: Union[pd.DataFrame, Dict[str, Any]],
        return_dict: bool = True
    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
        logger = logging.getLogger()

        if not self.is_fitted:
            raise ValueError("DataProcessor must be fitted before transform")

        # Convert input to dict format for unified processing
        if isinstance(data, pd.DataFrame):
            data_dict = {col: data[col] for col in data.columns}
        elif isinstance(data, dict):
            data_dict = data
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

        result_dict = {}
        for key, value in data_dict.items():
            if isinstance(value, pd.Series):
                result_dict[key] = value.values
            elif isinstance(value, np.ndarray):
                result_dict[key] = value
            else:
                result_dict[key] = np.array(value)

        # process numeric features
        for name, config in self.numeric_features.items():
            if name not in data_dict:
                logger.warning(f"Numeric feature {name} not found in data")
                continue
            # Convert to Series for processing
            series_data = pd.Series(data_dict[name], name=name)
            processed = self._process_numeric_feature_transform(series_data, config)
            result_dict[name] = processed

        # process sparse features
        for name, config in self.sparse_features.items():
            if name not in data_dict:
                logger.warning(f"Sparse feature {name} not found in data")
                continue
            series_data = pd.Series(data_dict[name], name=name)
            processed = self._process_sparse_feature_transform(series_data, config)
            result_dict[name] = processed

        # process sequence features
        for name, config in self.sequence_features.items():
            if name not in data_dict:
                logger.warning(f"Sequence feature {name} not found in data")
                continue
            series_data = pd.Series(data_dict[name], name=name)
            processed = self._process_sequence_feature_transform(series_data, config)
            result_dict[name] = processed

        # process target features
        for name, config in self.target_features.items():
            if name not in data_dict:
                logger.warning(f"Target {name} not found in data")
                continue
            series_data = pd.Series(data_dict[name], name=name)
            processed = self._process_target_transform(series_data, config)
            result_dict[name] = processed

        if return_dict:
            return result_dict
        else:
            # Convert all arrays to Series/lists at once to avoid fragmentation
            columns_dict = {}
            for key, value in result_dict.items():
                if key in self.sequence_features:
                    columns_dict[key] = [list(seq) for seq in value]
                else:
                    columns_dict[key] = value

            result_df = pd.DataFrame(columns_dict)
            return result_df
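    # Note: result_dict is seeded with every input column, so columns without a
    # registered feature pass through unchanged, and registered features then
    # overwrite theirs with processed arrays. With return_dict=False, each
    # sequence feature becomes a DataFrame column of per-row Python lists.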
    def fit_transform(
        self,
        data: Union[pd.DataFrame, Dict[str, Any]],
        return_dict: bool = True
    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
        self.fit(data)
        return self.transform(data, return_dict=return_dict)
    def save(self, filepath: str):
        logger = logging.getLogger()

        if not self.is_fitted:
            logger.warning("Saving unfitted DataProcessor")

        dir_path = os.path.dirname(filepath)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
            logger.info(f"Created directory: {dir_path}")

        state = {
            'numeric_features': self.numeric_features,
            'sparse_features': self.sparse_features,
            'sequence_features': self.sequence_features,
            'target_features': self.target_features,
            'is_fitted': self.is_fitted,
            'scalers': self.scalers,
            'label_encoders': self.label_encoders,
            'target_encoders': self.target_encoders
        }

        with open(filepath, 'wb') as f:
            pickle.dump(state, f)

        logger.info(f"DataProcessor saved to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> 'DataProcessor':
        logger = logging.getLogger()

        with open(filepath, 'rb') as f:
            state = pickle.load(f)

        processor = cls()
        processor.numeric_features = state['numeric_features']
        processor.sparse_features = state['sparse_features']
        processor.sequence_features = state['sequence_features']
        processor.target_features = state['target_features']
        processor.is_fitted = state['is_fitted']
        processor.scalers = state['scalers']
        processor.label_encoders = state['label_encoders']
        processor.target_encoders = state['target_encoders']

        logger.info(f"DataProcessor loaded from {filepath}")
        return processor
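    # Note: persistence is plain pickle, so load() should only be pointed at
    # files you produced yourself (unpickling untrusted data can execute
    # arbitrary code), and the pickle embeds fitted sklearn objects, which may
    # not load cleanly under a very different scikit-learn version.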
    def get_vocab_sizes(self) -> Dict[str, int]:
        vocab_sizes = {}

        for name, config in self.sparse_features.items():
            vocab_sizes[name] = config.get('vocab_size', 0)

        for name, config in self.sequence_features.items():
            vocab_sizes[name] = config.get('vocab_size', 0)

        return vocab_sizes
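    # Note: these sizes map directly to embedding-table rows, e.g.
    # (illustrative, assuming PyTorch):
    #   emb = torch.nn.Embedding(vocab_sizes['user_id'], 16)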
    def summary(self):
        """Print a summary of the DataProcessor configuration."""
        logger = logging.getLogger()

        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
        logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
        logger.info(colorize("=" * 80, color="bright_blue", bold=True))

        logger.info("")
        logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
        logger.info(colorize("-" * 80, color="cyan"))

        if self.numeric_features:
            logger.info(f"Dense Features ({len(self.numeric_features)}):")

            max_name_len = max(len(name) for name in self.numeric_features.keys())
            name_width = max(max_name_len, 10) + 2

            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
            logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
            for i, (name, config) in enumerate(self.numeric_features.items(), 1):
                scaler = config['scaler']
                fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
                logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")

        if self.sparse_features:
            logger.info(f"Sparse Features ({len(self.sparse_features)}):")

            max_name_len = max(len(name) for name in self.sparse_features.keys())
            name_width = max(max_name_len, 10) + 2

            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
            for i, (name, config) in enumerate(self.sparse_features.items(), 1):
                method = config['encode_method']
                vocab_size = config.get('vocab_size', 'N/A')
                hash_size = config.get('hash_size', 'N/A')
                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")

        if self.sequence_features:
            logger.info(f"Sequence Features ({len(self.sequence_features)}):")

            max_name_len = max(len(name) for name in self.sequence_features.keys())
            name_width = max(max_name_len, 10) + 2

            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
            for i, (name, config) in enumerate(self.sequence_features.items(), 1):
                method = config['encode_method']
                vocab_size = config.get('vocab_size', 'N/A')
                hash_size = config.get('hash_size', 'N/A')
                max_len = config.get('max_len', 'N/A')
                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")

        logger.info("")
        logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
        logger.info(colorize("-" * 80, color="cyan"))

        if self.target_features:
            logger.info(f"Target Features ({len(self.target_features)}):")

            max_name_len = max(len(name) for name in self.target_features.keys())
            name_width = max(max_name_len, 10) + 2

            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
            logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
            for i, (name, config) in enumerate(self.target_features.items(), 1):
                target_type = config['target_type']
                logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
        else:
            logger.info("No target features configured")

        logger.info("")
        logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
        logger.info(colorize("-" * 80, color="cyan"))
        logger.info(f"Fitted: {self.is_fitted}")
        logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
        logger.info(f" Dense Features: {len(self.numeric_features)}")
        logger.info(f" Sparse Features: {len(self.sparse_features)}")
        logger.info(f" Sequence Features: {len(self.sequence_features)}")
        logger.info(f"Target Features: {len(self.target_features)}")

        logger.info("")
        logger.info("")
        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
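Taken end to end, the processor is used roughly as below. This is an illustrative sketch based on the docstring and method signatures above; the toy DataFrame and its column names are invented for the example:

    import pandas as pd
    from nextrec.data.preprocessor import DataProcessor

    df = pd.DataFrame({
        'age': [23, 31, None, 45],
        'user_id': ['u1', 'u2', 'u1', 'u3'],
        'item_history': ['i1,i2', 'i2', '', 'i3,i1,i2'],
        'label': [1, 0, 0, 1],
    })

    processor = DataProcessor()
    processor.add_numeric_feature('age', scaler='standard')
    processor.add_sparse_feature('user_id', encode_method='label')
    processor.add_sequence_feature('item_history', max_len=3, separator=',')
    processor.add_target('label', target_type='binary')

    arrays = processor.fit_transform(df)  # dict of numpy arrays
    # arrays['item_history'] has shape (4, 3); arrays['label'] is float32
    processor.save('artifacts/processor.pkl')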
nextrec/loss/__init__.py
ADDED
@@ -0,0 +1,35 @@

from nextrec.loss.match_losses import (
    BPRLoss,
    HingeLoss,
    TripletLoss,
    SampledSoftmaxLoss,
    CosineContrastiveLoss,
    InfoNCELoss,
    ListNetLoss,
    ListMLELoss,
    ApproxNDCGLoss,
)

from nextrec.loss.loss_utils import (
    get_loss_fn,
    validate_training_mode,
    VALID_TASK_TYPES,
)

__all__ = [
    # Match losses
    'BPRLoss',
    'HingeLoss',
    'TripletLoss',
    'SampledSoftmaxLoss',
    'CosineContrastiveLoss',
    'InfoNCELoss',
    # Listwise losses
    'ListNetLoss',
    'ListMLELoss',
    'ApproxNDCGLoss',
    # Utilities
    'get_loss_fn',
    'validate_training_mode',
    'VALID_TASK_TYPES',
]
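The constructor and call signatures of these losses are not visible in this diff, but the exported names follow standard definitions; BPR, for instance, maximizes the log-sigmoid of the positive/negative score gap. A minimal standalone sketch of that objective (not this package's implementation, which lives in nextrec/loss/match_losses.py):

    import torch
    import torch.nn.functional as F

    def bpr_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
        # BPR: -mean(log sigmoid(s_pos - s_neg)); logsigmoid is the numerically
        # stable form of log(sigmoid(x)).
        return -F.logsigmoid(pos_scores - neg_scores).mean()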