nextrec 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/data/preprocessor.py
@@ -0,0 +1,662 @@
+ """
+ DataProcessor for data preprocessing, covering numeric, sparse, and sequence features as well as target processing.
+
+ Date: created on 13/11/2025
+ Author:
+     Yang Zhou, zyaztec@gmail.com
+ """
+
+ import os
+ import pandas as pd
+ import numpy as np
+ import pickle
+ import hashlib
+ import logging
+
+ from typing import Dict, Union, Optional, Literal, Any
+ from sklearn.preprocessing import (
+     StandardScaler,
+     MinMaxScaler,
+     RobustScaler,
+     MaxAbsScaler,
+     LabelEncoder
+ )
+
+ from nextrec.basic.loggers import setup_logger, colorize
+
+
+ class DataProcessor:
+     """DataProcessor for data preprocessing, covering numeric, sparse, and sequence features as well as target processing.
+
+     Examples:
+         >>> processor = DataProcessor()
+         >>> processor.add_numeric_feature('age', scaler='standard')
+         >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
+         >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
+         >>> processor.add_target('label', target_type='binary')
+         >>>
+         >>> # Fit and transform data
+         >>> processor.fit(train_df)
+         >>> processed_data = processor.transform(test_df)  # Returns dict of numpy arrays
+         >>>
+         >>> # Save and load processor
+         >>> processor.save('processor.pkl')
+         >>> loaded_processor = DataProcessor.load('processor.pkl')
+         >>>
+         >>> # Get vocabulary sizes for embedding layers
+         >>> vocab_sizes = processor.get_vocab_sizes()
+     """
+     def __init__(self):
+         self.numeric_features: Dict[str, Dict[str, Any]] = {}
+         self.sparse_features: Dict[str, Dict[str, Any]] = {}
+         self.sequence_features: Dict[str, Dict[str, Any]] = {}
+         self.target_features: Dict[str, Dict[str, Any]] = {}
+
+         self.is_fitted = False
+         self._transform_summary_printed = False  # Track whether the summary has been printed during transform
+
+         self.scalers: Dict[str, Any] = {}
+         self.label_encoders: Dict[str, LabelEncoder] = {}
+         self.target_encoders: Dict[str, Dict[str, int]] = {}
+
+         # Initialize logger if not already initialized
+         self._logger_initialized = False
+         if not logging.getLogger().hasHandlers():
+             setup_logger()
+             self._logger_initialized = True
+
+     def add_numeric_feature(
+         self,
+         name: str,
+         scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
+         fill_na: Optional[float] = None
+     ):
+         self.numeric_features[name] = {
+             'scaler': scaler,
+             'fill_na': fill_na
+         }
+
+     def add_sparse_feature(
+         self,
+         name: str,
+         encode_method: Literal['hash', 'label'] = 'label',
+         hash_size: Optional[int] = None,
+         fill_na: str = '<UNK>'
+     ):
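+         # 'label' learns a vocabulary at fit time; 'hash' needs no fitted vocabulary
+         # but can collide, so hash_size should comfortably exceed the expected
+         # cardinality (a rough rule of thumb, not a guarantee).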
+         if encode_method == 'hash' and hash_size is None:
+             raise ValueError("hash_size must be specified when encode_method='hash'")
+
+         self.sparse_features[name] = {
+             'encode_method': encode_method,
+             'hash_size': hash_size,
+             'fill_na': fill_na
+         }
+
+     def add_sequence_feature(
+         self,
+         name: str,
+         encode_method: Literal['hash', 'label'] = 'label',
+         hash_size: Optional[int] = None,
+         max_len: Optional[int] = 50,
+         pad_value: int = 0,
+         truncate: Literal['pre', 'post'] = 'pre',  # pre: keep the last max_len items; post: keep the first max_len items
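+         #   e.g. with max_len=3, [1, 2, 3, 4, 5] -> 'pre' keeps [3, 4, 5], 'post' keeps [1, 2, 3]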
+         separator: str = ','
+     ):
+         if encode_method == 'hash' and hash_size is None:
+             raise ValueError("hash_size must be specified when encode_method='hash'")
+
+         self.sequence_features[name] = {
+             'encode_method': encode_method,
+             'hash_size': hash_size,
+             'max_len': max_len,
+             'pad_value': pad_value,
+             'truncate': truncate,
+             'separator': separator
+         }
+
+     def add_target(
+         self,
+         name: str,  # example: 'click'
+         target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
+         label_map: Optional[Dict[str, int]] = None  # example: {'click': 1, 'no_click': 0}
+     ):
+         self.target_features[name] = {
+             'target_type': target_type,
+             'label_map': label_map
+         }
+
+     def _hash_string(self, s: str, hash_size: int) -> int:
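+         # MD5 is stable across processes and Python versions, unlike the built-in
+         # hash(), which is salted per interpreter run; the modulo buckets every
+         # string into [0, hash_size).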
+         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
+
+     def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+         name = str(data.name)
+         scaler_type = config['scaler']
+         fill_na = config['fill_na']
+
+         if data.isna().any():
+             if fill_na is None:
+                 # Default to the column mean when filling missing numeric values
+                 fill_na = data.mean()
+             config['fill_na_value'] = fill_na
+
+         if scaler_type == 'standard':
+             scaler = StandardScaler()
+         elif scaler_type == 'minmax':
+             scaler = MinMaxScaler()
+         elif scaler_type == 'robust':
+             scaler = RobustScaler()
+         elif scaler_type == 'maxabs':
+             scaler = MaxAbsScaler()
+         elif scaler_type == 'log':
+             scaler = None
+         elif scaler_type == 'none':
+             scaler = None
+         else:
+             raise ValueError(f"Unknown scaler type: {scaler_type}")
+
+         if scaler is not None:  # 'log' and 'none' have no scaler to fit
+             filled_data = data.fillna(config.get('fill_na_value', 0))
+             values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
+             scaler.fit(values)
+             self.scalers[name] = scaler
+
+     def _process_numeric_feature_transform(
+         self,
+         data: pd.Series,
+         config: Dict[str, Any]
+     ) -> np.ndarray:
+         logger = logging.getLogger()
+
+         name = str(data.name)
+         scaler_type = config['scaler']
+         fill_na_value = config.get('fill_na_value', 0)
+
+         filled_data = data.fillna(fill_na_value)
+         values = np.array(filled_data.values, dtype=np.float64)
+
+         if scaler_type == 'log':
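+             # Clamp negatives to zero first so log1p stays defined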
+             result = np.log1p(np.maximum(values, 0))
+         elif scaler_type == 'none':
+             result = values
+         else:
+             scaler = self.scalers.get(name)
+             if scaler is None:
+                 logger.warning(f"Scaler for {name} not fitted, returning original values")
+                 result = values
+             else:
+                 result = scaler.transform(values.reshape(-1, 1)).ravel()
+
+         return result
+
+     def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+         name = str(data.name)
+         encode_method = config['encode_method']
+         fill_na = config['fill_na']  # <UNK>
+
+         filled_data = data.fillna(fill_na).astype(str)
+
+         if encode_method == 'label':
+             le = LabelEncoder()
+             le.fit(filled_data)
+             self.label_encoders[name] = le
+             config['vocab_size'] = len(le.classes_)
+         elif encode_method == 'hash':
+             config['vocab_size'] = config['hash_size']
+
+     def _process_sparse_feature_transform(
+         self,
+         data: pd.Series,
+         config: Dict[str, Any]
+     ) -> np.ndarray:
+         name = str(data.name)
+         encode_method = config['encode_method']
+         fill_na = config['fill_na']
+
+         filled_data = data.fillna(fill_na).astype(str)
+
+         if encode_method == 'label':
+             le = self.label_encoders.get(name)
+             if le is None:
+                 raise ValueError(f"LabelEncoder for {name} not fitted")
+
+             # Look up indices via a dict instead of calling le.transform per value;
+             # values unseen at fit time fall back to index 0
+             class_to_index = {cls: idx for idx, cls in enumerate(le.classes_)}
+             result = [class_to_index.get(val, 0) for val in filled_data]
+             return np.array(result, dtype=np.int64)
+
+         elif encode_method == 'hash':
+             hash_size = config['hash_size']
+             return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)
+
+         return np.array([], dtype=np.int64)
+
+     def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+         name = str(data.name)
+         encode_method = config['encode_method']
+         separator = config['separator']
+
+         if encode_method == 'label':
+             all_tokens = set()
+             for seq in data:
+                 # Skip None, np.nan, and empty strings
+                 if seq is None:
+                     continue
+                 if isinstance(seq, (float, np.floating)) and np.isnan(seq):
+                     continue
+                 if isinstance(seq, str) and seq.strip() == '':
+                     continue
+
+                 if isinstance(seq, str):
+                     # Strip whitespace so fit tokenizes the same way transform does
+                     tokens = [t.strip() for t in seq.split(separator)]
+                 elif isinstance(seq, (list, tuple)):
+                     tokens = [str(t) for t in seq]
+                 elif isinstance(seq, np.ndarray):
+                     tokens = [str(t) for t in seq.tolist()]
+                 else:
+                     continue
+
+                 all_tokens.update(tokens)
+
+             if len(all_tokens) == 0:
+                 all_tokens.add('<PAD>')
+
+             le = LabelEncoder()
+             le.fit(list(all_tokens))
+             self.label_encoders[name] = le
+             config['vocab_size'] = len(le.classes_)
+         elif encode_method == 'hash':
+             config['vocab_size'] = config['hash_size']
+
+     def _process_sequence_feature_transform(
+         self,
+         data: pd.Series,
+         config: Dict[str, Any]
+     ) -> np.ndarray:
+         name = str(data.name)
+         encode_method = config['encode_method']
+         max_len = config['max_len']
+         pad_value = config['pad_value']
+         truncate = config['truncate']
+         separator = config['separator']
+
+         class_to_index = None
+         if encode_method == 'label':
+             le = self.label_encoders.get(name)
+             if le is None:
+                 raise ValueError(f"LabelEncoder for {name} not fitted")
+             # Build the token -> index lookup once instead of calling le.transform per token
+             class_to_index = {cls: idx for idx, cls in enumerate(le.classes_)}
+
+         result = []
+         for seq in data:
+             if seq is None:
+                 tokens = []
+             elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
+                 tokens = []
+             elif isinstance(seq, str):
+                 tokens = seq.split(separator) if seq.strip() else []
+             elif isinstance(seq, (list, tuple)):
+                 tokens = [str(t) for t in seq]
+             elif isinstance(seq, np.ndarray):
+                 tokens = [str(t) for t in seq.tolist()]
+             else:
+                 tokens = []
+
+             if encode_method == 'label':
+                 encoded = []
+                 for token in tokens:
+                     token_str = str(token).strip()
+                     if token_str and token_str in class_to_index:
+                         encoded.append(class_to_index[token_str])
+                     else:
+                         encoded.append(0)  # UNK
+             elif encode_method == 'hash':
+                 hash_size = config['hash_size']
+                 # Strip tokens before hashing so ' b' and 'b' share a bucket
+                 encoded = [self._hash_string(str(token).strip(), hash_size) for token in tokens if str(token).strip()]
+             else:
+                 encoded = []
+
+             if len(encoded) > max_len:
+                 if truncate == 'pre':  # keep the last max_len items
+                     encoded = encoded[-max_len:]
+                 else:  # keep the first max_len items
+                     encoded = encoded[:max_len]
+             elif len(encoded) < max_len:
+                 encoded = encoded + [pad_value] * (max_len - len(encoded))
+
+             result.append(encoded)
+
+         return np.array(result, dtype=np.int64)
+
+     def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
+         name = str(data.name)
+         target_type = config['target_type']
+         label_map = config['label_map']
+
+         if target_type in ['binary', 'multiclass']:
+             if label_map is None:
+                 unique_values = data.dropna().unique()
+                 sorted_values = sorted(unique_values)
+
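+                 # If the labels are already the integers 0..K-1, keep them as-is;
+                 # otherwise assign indices by sorted order,
+                 # e.g. ['no_click', 'click'] -> {'click': 0, 'no_click': 1}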
+                 try:
+                     int_values = [int(v) for v in sorted_values]
+                     if int_values == list(range(len(int_values))):
+                         label_map = {str(val): int(val) for val in sorted_values}
+                     else:
+                         label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                 except (ValueError, TypeError):
+                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+
+                 config['label_map'] = label_map
+
+             self.target_encoders[name] = label_map
+
+     def _process_target_transform(
+         self,
+         data: pd.Series,
+         config: Dict[str, Any]
+     ) -> np.ndarray:
+         logger = logging.getLogger()
+
+         name = str(data.name)
+         target_type = config['target_type']
+
+         if target_type == 'regression':
+             values = np.array(data.values, dtype=np.float32)
+             return values
+         else:
+             label_map = self.target_encoders.get(name)
+             if label_map is None:
+                 raise ValueError(f"Target encoder for {name} not fitted")
+
+             result = []
+             for val in data:
+                 str_val = str(val)
+                 if str_val in label_map:
+                     result.append(label_map[str_val])
+                 else:
+                     logger.warning(f"Unknown target value: {val}, mapping to 0")
+                     result.append(0)
+
+             return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
+
+     # Fitting only registers per-feature statistics (scalers, vocabularies,
+     # label maps) so that the data can be transformed later
+     def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
+         logger = logging.getLogger()
+
+         if isinstance(data, dict):
+             data = pd.DataFrame(data)
+
+         logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
+
+         for name, config in self.numeric_features.items():
+             if name not in data.columns:
+                 logger.warning(f"Numeric feature {name} not found in data")
+                 continue
+             self._process_numeric_feature_fit(data[name], config)
+
+         for name, config in self.sparse_features.items():
+             if name not in data.columns:
+                 logger.warning(f"Sparse feature {name} not found in data")
+                 continue
+             self._process_sparse_feature_fit(data[name], config)
+
+         for name, config in self.sequence_features.items():
+             if name not in data.columns:
+                 logger.warning(f"Sequence feature {name} not found in data")
+                 continue
+             self._process_sequence_feature_fit(data[name], config)
+
+         for name, config in self.target_features.items():
+             if name not in data.columns:
+                 logger.warning(f"Target {name} not found in data")
+                 continue
+             self._process_target_fit(data[name], config)
+
+         self.is_fitted = True
+         logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+         return self
+
+     def transform(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: bool = True
+     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
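+         # return_dict=True yields {column -> np.ndarray}, ready for a dataloader;
+         # return_dict=False packs everything back into a DataFrame, storing each
+         # sequence feature as a list per row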
+         logger = logging.getLogger()
+
+         if not self.is_fitted:
+             raise ValueError("DataProcessor must be fitted before transform")
+
+         # Convert input to dict format for unified processing
+         if isinstance(data, pd.DataFrame):
+             data_dict = {col: data[col] for col in data.columns}
+         elif isinstance(data, dict):
+             data_dict = data
+         else:
+             raise ValueError(f"Unsupported data type: {type(data)}")
+
+         result_dict = {}
+         for key, value in data_dict.items():
+             if isinstance(value, pd.Series):
+                 result_dict[key] = value.values
+             elif isinstance(value, np.ndarray):
+                 result_dict[key] = value
+             else:
+                 result_dict[key] = np.array(value)
+
+         # Process numeric features
+         for name, config in self.numeric_features.items():
+             if name not in data_dict:
+                 logger.warning(f"Numeric feature {name} not found in data")
+                 continue
+             # Convert to Series for processing
+             series_data = pd.Series(data_dict[name], name=name)
+             processed = self._process_numeric_feature_transform(series_data, config)
+             result_dict[name] = processed
+
+         # Process sparse features
+         for name, config in self.sparse_features.items():
+             if name not in data_dict:
+                 logger.warning(f"Sparse feature {name} not found in data")
+                 continue
+             series_data = pd.Series(data_dict[name], name=name)
+             processed = self._process_sparse_feature_transform(series_data, config)
+             result_dict[name] = processed
+
+         # Process sequence features
+         for name, config in self.sequence_features.items():
+             if name not in data_dict:
+                 logger.warning(f"Sequence feature {name} not found in data")
+                 continue
+             series_data = pd.Series(data_dict[name], name=name)
+             processed = self._process_sequence_feature_transform(series_data, config)
+             result_dict[name] = processed
+
+         # Process target features
+         for name, config in self.target_features.items():
+             if name not in data_dict:
+                 logger.warning(f"Target {name} not found in data")
+                 continue
+             series_data = pd.Series(data_dict[name], name=name)
+             processed = self._process_target_transform(series_data, config)
+             result_dict[name] = processed
+
+         if return_dict:
+             return result_dict
+         else:
+             # Convert all arrays to Series/lists at once to avoid fragmentation
+             columns_dict = {}
+             for key, value in result_dict.items():
+                 if key in self.sequence_features:
+                     columns_dict[key] = [list(seq) for seq in value]
+                 else:
+                     columns_dict[key] = value
+
+             result_df = pd.DataFrame(columns_dict)
+             return result_df
+
+     def fit_transform(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: bool = True
+     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
+         self.fit(data)
+         return self.transform(data, return_dict=return_dict)
+
+     def save(self, filepath: str):
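+         # The full fitted state (feature configs, scalers, encoders) is pickled,
+         # so load() can restore the processor without refitting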
+         logger = logging.getLogger()
+
+         if not self.is_fitted:
+             logger.warning("Saving unfitted DataProcessor")
+
+         dir_path = os.path.dirname(filepath)
+         if dir_path and not os.path.exists(dir_path):
+             os.makedirs(dir_path, exist_ok=True)
+             logger.info(f"Created directory: {dir_path}")
+
+         state = {
+             'numeric_features': self.numeric_features,
+             'sparse_features': self.sparse_features,
+             'sequence_features': self.sequence_features,
+             'target_features': self.target_features,
+             'is_fitted': self.is_fitted,
+             'scalers': self.scalers,
+             'label_encoders': self.label_encoders,
+             'target_encoders': self.target_encoders
+         }
+
+         with open(filepath, 'wb') as f:
+             pickle.dump(state, f)
+
+         logger.info(f"DataProcessor saved to {filepath}")
+
+     @classmethod
+     def load(cls, filepath: str) -> 'DataProcessor':
+         logger = logging.getLogger()
+
+         with open(filepath, 'rb') as f:
+             state = pickle.load(f)
+
+         processor = cls()
+         processor.numeric_features = state['numeric_features']
+         processor.sparse_features = state['sparse_features']
+         processor.sequence_features = state['sequence_features']
+         processor.target_features = state['target_features']
+         processor.is_fitted = state['is_fitted']
+         processor.scalers = state['scalers']
+         processor.label_encoders = state['label_encoders']
+         processor.target_encoders = state['target_encoders']
+
+         logger.info(f"DataProcessor loaded from {filepath}")
+         return processor
+
+     def get_vocab_sizes(self) -> Dict[str, int]:
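+         # Maps feature name -> number of embedding rows the model should allocate,
+         # e.g. {'user_id': 10000, 'item_history': 5321} (values are illustrative)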
+         vocab_sizes = {}
+
+         for name, config in self.sparse_features.items():
+             vocab_sizes[name] = config.get('vocab_size', 0)
+
+         for name, config in self.sequence_features.items():
+             vocab_sizes[name] = config.get('vocab_size', 0)
+
+         return vocab_sizes
+
+     def summary(self):
+         """Print a summary of the DataProcessor configuration."""
+         logger = logging.getLogger()
+
+         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+         logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
+         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+
+         logger.info("")
+         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+
+         if self.numeric_features:
+             logger.info(f"Dense Features ({len(self.numeric_features)}):")
+
+             max_name_len = max(len(name) for name in self.numeric_features.keys())
+             name_width = max(max_name_len, 10) + 2
+
+             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
+             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
+             for i, (name, config) in enumerate(self.numeric_features.items(), 1):
+                 scaler = config['scaler']
+                 fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
+                 logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")
+
+         if self.sparse_features:
+             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
+
+             max_name_len = max(len(name) for name in self.sparse_features.keys())
+             name_width = max(max_name_len, 10) + 2
+
+             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
+             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
+             for i, (name, config) in enumerate(self.sparse_features.items(), 1):
+                 method = config['encode_method']
+                 vocab_size = config.get('vocab_size', 'N/A')
+                 hash_size = config.get('hash_size', 'N/A')
+                 logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")
+
+         if self.sequence_features:
+             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
+
+             max_name_len = max(len(name) for name in self.sequence_features.keys())
+             name_width = max(max_name_len, 10) + 2
+
+             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
+             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
+             for i, (name, config) in enumerate(self.sequence_features.items(), 1):
+                 method = config['encode_method']
+                 vocab_size = config.get('vocab_size', 'N/A')
+                 hash_size = config.get('hash_size', 'N/A')
+                 max_len = config.get('max_len', 'N/A')
+                 logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")
+
+         logger.info("")
+         logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+
+         if self.target_features:
+             logger.info(f"Target Features ({len(self.target_features)}):")
+
+             max_name_len = max(len(name) for name in self.target_features.keys())
+             name_width = max(max_name_len, 10) + 2
+
+             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
+             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
+             for i, (name, config) in enumerate(self.target_features.items(), 1):
+                 target_type = config['target_type']
+                 logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
+         else:
+             logger.info("No target features configured")
+
+         logger.info("")
+         logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+         logger.info(f"Fitted: {self.is_fitted}")
+         logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
+         logger.info(f"  Dense Features: {len(self.numeric_features)}")
+         logger.info(f"  Sparse Features: {len(self.sparse_features)}")
+         logger.info(f"  Sequence Features: {len(self.sequence_features)}")
+         logger.info(f"Target Features: {len(self.target_features)}")
+
+         logger.info("")
+         logger.info("")
+         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
nextrec/loss/__init__.py
@@ -0,0 +1,35 @@
+ from nextrec.loss.match_losses import (
+     BPRLoss,
+     HingeLoss,
+     TripletLoss,
+     SampledSoftmaxLoss,
+     CosineContrastiveLoss,
+     InfoNCELoss,
+     ListNetLoss,
+     ListMLELoss,
+     ApproxNDCGLoss,
+ )
+
+ from nextrec.loss.loss_utils import (
+     get_loss_fn,
+     validate_training_mode,
+     VALID_TASK_TYPES,
+ )
+
+ __all__ = [
+     # Match losses
+     'BPRLoss',
+     'HingeLoss',
+     'TripletLoss',
+     'SampledSoftmaxLoss',
+     'CosineContrastiveLoss',
+     'InfoNCELoss',
+     # Listwise losses
+     'ListNetLoss',
+     'ListMLELoss',
+     'ApproxNDCGLoss',
+     # Utilities
+     'get_loss_fn',
+     'validate_training_mode',
+     'VALID_TASK_TYPES',
+ ]