nextrec-0.3.2-py3-none-any.whl → nextrec-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +10 -23
  3. nextrec/basic/layers.py +18 -61
  4. nextrec/basic/loggers.py +71 -8
  5. nextrec/basic/metrics.py +55 -33
  6. nextrec/basic/model.py +287 -397
  7. nextrec/data/__init__.py +2 -2
  8. nextrec/data/data_utils.py +80 -4
  9. nextrec/data/dataloader.py +38 -59
  10. nextrec/data/preprocessor.py +38 -73
  11. nextrec/models/generative/hstu.py +1 -1
  12. nextrec/models/match/dssm.py +2 -2
  13. nextrec/models/match/dssm_v2.py +2 -2
  14. nextrec/models/match/mind.py +2 -2
  15. nextrec/models/match/sdm.py +2 -2
  16. nextrec/models/match/youtube_dnn.py +2 -2
  17. nextrec/models/multi_task/esmm.py +1 -1
  18. nextrec/models/multi_task/mmoe.py +1 -1
  19. nextrec/models/multi_task/ple.py +1 -1
  20. nextrec/models/multi_task/poso.py +1 -1
  21. nextrec/models/multi_task/share_bottom.py +1 -1
  22. nextrec/models/ranking/afm.py +1 -1
  23. nextrec/models/ranking/autoint.py +1 -1
  24. nextrec/models/ranking/dcn.py +1 -1
  25. nextrec/models/ranking/deepfm.py +1 -1
  26. nextrec/models/ranking/dien.py +1 -1
  27. nextrec/models/ranking/din.py +1 -1
  28. nextrec/models/ranking/fibinet.py +1 -1
  29. nextrec/models/ranking/fm.py +1 -1
  30. nextrec/models/ranking/masknet.py +2 -2
  31. nextrec/models/ranking/pnn.py +1 -1
  32. nextrec/models/ranking/widedeep.py +1 -1
  33. nextrec/models/ranking/xdeepfm.py +1 -1
  34. nextrec/utils/__init__.py +2 -1
  35. nextrec/utils/common.py +21 -2
  36. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/METADATA +3 -3
  37. nextrec-0.3.4.dist-info/RECORD +57 -0
  38. nextrec-0.3.2.dist-info/RECORD +0 -57
  39. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/WHEEL +0 -0
  40. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py CHANGED
@@ -2,6 +2,7 @@
  DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
  
  Date: create on 13/11/2025
+ Checkpoint: edit on 02/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
  from __future__ import annotations
@@ -32,31 +33,11 @@ from nextrec.data.data_utils import (
  default_output_dir,
  )
  from nextrec.basic.session import resolve_save_path
- from nextrec.basic.features import FeatureSpecMixin
+ from nextrec.basic.features import FeatureSet
  from nextrec.__version__ import __version__
  
  
- class DataProcessor(FeatureSpecMixin):
- """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
- Examples:
- >>> processor = DataProcessor()
- >>> processor.add_numeric_feature('age', scaler='standard')
- >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
- >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
- >>> processor.add_target('label', target_type='binary')
- >>>
- >>> # Fit and transform data
- >>> processor.fit(train_df)
- >>> processed_data = processor.transform(test_df) # Returns dict of numpy arrays
- >>>
- >>> # Save and load processor
- >>> processor.save('processor.pkl')
- >>> loaded_processor = DataProcessor.load('processor.pkl')
- >>>
- >>> # Get vocabulary sizes for embedding layers
- >>> vocab_sizes = processor.get_vocab_sizes()
- """
+ class DataProcessor(FeatureSet):
  def __init__(self):
  self.numeric_features: Dict[str, Dict[str, Any]] = {}
  self.sparse_features: Dict[str, Dict[str, Any]] = {}
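
The usage examples removed from the docstring above still describe the 0.3.4 workflow. A sketch reconstructed from that removed docstring; the import path is assumed from the file layout, and train_df/test_df stand in for your own pandas DataFrames:

    from nextrec.data.preprocessor import DataProcessor  # path assumed from this diff's file list

    processor = DataProcessor()
    processor.add_numeric_feature('age', scaler='standard')
    processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
    processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
    processor.add_target('label', target_type='binary')

    processor.fit(train_df)                            # fit on a training DataFrame
    processed = processor.transform(test_df)           # dict of numpy arrays by default

    processor.save('processor.pkl')                    # persist the fitted processor
    restored = DataProcessor.load('processor.pkl')
    vocab_sizes = processor.get_vocab_sizes()          # vocab sizes for embedding layers
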
@@ -70,7 +51,7 @@ class DataProcessor(FeatureSpecMixin):
  self.scalers: Dict[str, Any] = {}
  self.label_encoders: Dict[str, LabelEncoder] = {}
  self.target_encoders: Dict[str, Dict[str, int]] = {}
- self._set_target_id_config([], [])
+ self.set_target_id([], [])
  
  def add_numeric_feature(
  self,
@@ -129,12 +110,12 @@
  'target_type': target_type,
  'label_map': label_map
  }
- self._set_target_id_config(list(self.target_features.keys()), [])
+ self.set_target_id(list(self.target_features.keys()), [])
  
- def _hash_string(self, s: str, hash_size: int) -> int:
+ def hash_string(self, s: str, hash_size: int) -> int:
  return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
  
- def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  scaler_type = config['scaler']
  fill_na = config['fill_na']
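
The hashing helper is now public (`hash_string`); its unchanged body above buckets a string with MD5 modulo `hash_size`. A standalone sketch of the same computation:

    import hashlib

    def hash_string(s: str, hash_size: int) -> int:
        # MD5 digest interpreted as an integer, reduced modulo hash_size,
        # so any string maps deterministically to a bucket in [0, hash_size).
        return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size

    print(hash_string("user_42", 10000))
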
@@ -163,7 +144,7 @@
  scaler.fit(values)
  self.scalers[name] = scaler
  
- def _process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  logger = logging.getLogger()
  name = str(data.name)
  scaler_type = config['scaler']
@@ -183,7 +164,7 @@
  result = scaler.transform(values.reshape(-1, 1)).ravel()
  return result
  
- def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  encode_method = config['encode_method']
  fill_na = config['fill_na'] # <UNK>
@@ -196,7 +177,7 @@
  elif encode_method == 'hash':
  config['vocab_size'] = config['hash_size']
  
- def _process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  name = str(data.name)
  encode_method = config['encode_method']
  fill_na = config['fill_na']
@@ -214,11 +195,11 @@
  return encoded.to_numpy()
  if encode_method == 'hash':
  hash_size = config['hash_size']
- hash_fn = self._hash_string
+ hash_fn = self.hash_string
  return np.fromiter((hash_fn(v, hash_size) for v in sparse_series.to_numpy()), dtype=np.int64, count=sparse_series.size,)
  return np.array([], dtype=np.int64)
  
- def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  encode_method = config['encode_method']
  separator = config['separator']
@@ -251,7 +232,7 @@
  elif encode_method == 'hash':
  config['vocab_size'] = config['hash_size']
  
- def _process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  """Optimized sequence transform with preallocation and cached vocab map."""
  name = str(data.name)
  encode_method = config['encode_method']
@@ -275,7 +256,7 @@
  config['_class_to_idx'] = class_to_idx
  else:
  class_to_idx = None # type: ignore
- hash_fn = self._hash_string
+ hash_fn = self.hash_string
  hash_size = config.get('hash_size')
  for i, seq in enumerate(arr):
  # normalize sequence to a list of strings
@@ -300,11 +281,7 @@
  elif encode_method == 'hash':
  if hash_size is None:
  raise ValueError("hash_size must be set for hash encoding")
- encoded = [
- hash_fn(str(token), hash_size)
- for token in tokens
- if str(token).strip()
- ]
+ encoded = [hash_fn(str(token), hash_size) for token in tokens if str(token).strip()]
  else:
  encoded = []
  if not encoded:
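
For sequence features, hash encoding collapses to the one-line comprehension above, and the transform then writes each encoded sequence into a preallocated, padded output row (the `output[i, : len(encoded)] = encoded` line in the next hunk). A self-contained sketch of that pattern, not the package's exact code; the `max_len`, `pad_value` and `hash_size` defaults are illustrative:

    import hashlib
    import numpy as np

    def encode_sequences(seqs, max_len=5, pad_value=0, hash_size=1000):
        # Preallocate the padded output once, then fill each row with the
        # hashed tokens, truncated to max_len.
        output = np.full((len(seqs), max_len), pad_value, dtype=np.int64)
        for i, tokens in enumerate(seqs):
            encoded = [int(hashlib.md5(str(t).encode()).hexdigest(), 16) % hash_size
                       for t in tokens if str(t).strip()][:max_len]
            if encoded:
                output[i, : len(encoded)] = encoded
        return output

    print(encode_sequences([["a", "b", "c"], ["d"]]))  # shape (2, 5), zero-padded
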
@@ -314,7 +291,7 @@
  output[i, : len(encoded)] = encoded
  return output
  
- def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
+ def process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
  name = str(data.name)
  target_type = config['target_type']
  label_map = config.get('label_map')
@@ -333,7 +310,7 @@
  config['label_map'] = label_map
  self.target_encoders[name] = label_map
  
- def _process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+ def process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
  logger = logging.getLogger()
  name = str(data.name)
  target_type = config.get('target_type')
@@ -354,13 +331,13 @@
  result.append(0)
  return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
  
- def _load_dataframe_from_path(self, path: str) -> pd.DataFrame:
+ def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
  """Load all data from a file or directory path into a single DataFrame."""
  file_paths, file_type = resolve_file_paths(path)
  frames = load_dataframes(file_paths, file_type)
  return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
  
- def _extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
+ def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
  """Extract sequence tokens from a single value."""
  if value is None:
  return []
@@ -373,7 +350,7 @@
  return [str(v) for v in value]
  return [str(value)]
  
- def _fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
+ def fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
  """Fit processor statistics by streaming files to reduce memory usage."""
  logger = logging.getLogger()
  logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))
@@ -432,7 +409,7 @@
  series = chunk[name]
  tokens = []
  for val in series:
- tokens.extend(self._extract_sequence_tokens(val, separator))
+ tokens.extend(self.extract_sequence_tokens(val, separator))
  seq_vocab[name].update(tokens)
  
  # target features
@@ -547,7 +524,7 @@
  logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
  return self
  
- def _transform_in_memory(
+ def transform_in_memory(
  self,
  data: Union[pd.DataFrame, Dict[str, Any]],
  return_dict: bool,
@@ -580,7 +557,7 @@
  continue
  # Convert to Series for processing
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_numeric_feature_transform(series_data, config)
+ processed = self.process_numeric_feature_transform(series_data, config)
  result_dict[name] = processed
  
  # process sparse features
@@ -589,7 +566,7 @@
  logger.warning(f"Sparse feature {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_sparse_feature_transform(series_data, config)
+ processed = self.process_sparse_feature_transform(series_data, config)
  result_dict[name] = processed
  
  # process sequence features
@@ -598,7 +575,7 @@
  logger.warning(f"Sequence feature {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_sequence_feature_transform(series_data, config)
+ processed = self.process_sequence_feature_transform(series_data, config)
  result_dict[name] = processed
  
  # process target features
@@ -607,10 +584,10 @@
  logger.warning(f"Target {name} not found in data")
  continue
  series_data = pd.Series(data_dict[name], name=name)
- processed = self._process_target_transform(series_data, config)
+ processed = self.process_target_transform(series_data, config)
  result_dict[name] = processed
  
- def _dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
+ def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
  # Convert all arrays to Series/lists at once to avoid fragmentation
  columns_dict = {}
  for key, value in result.items():
@@ -628,7 +605,7 @@
  effective_format = save_format or "parquet"
  result_df = None
  if (not return_dict) or persist:
- result_df = _dict_to_dataframe(result_dict)
+ result_df = dict_to_dataframe(result_dict)
  if persist:
  if output_path is None:
  raise ValueError("output_path must be provided when persisting transformed data.")
@@ -648,7 +625,7 @@
  assert result_df is not None, "DataFrame is None after transform"
  return result_df
  
- def _transform_path(
+ def transform_path(
  self,
  input_path: str,
  output_path: Optional[str],
@@ -668,13 +645,7 @@
  saved_paths = []
  for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
  df = read_table(file_path, file_type)
- transformed_df = self._transform_in_memory(
- df,
- return_dict=False,
- persist=False,
- save_format=None,
- output_path=None,
- )
+ transformed_df = self.transform_in_memory(df, return_dict=False, persist=False, save_format=None, output_path=None)
  assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
  source_path = Path(file_path)
  target_file = output_root / f"{source_path.stem}.{target_format}"
@@ -694,9 +665,9 @@ class DataProcessor(FeatureSpecMixin):
  uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
  if uses_robust:
  logger.warning("Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited.")
- data = self._load_dataframe_from_path(path_str)
+ data = self.load_dataframe_from_path(path_str)
  else:
- return self._fit_from_path(path_str, chunk_size)
+ return self.fit_from_path(path_str, chunk_size)
  if isinstance(data, dict):
  data = pd.DataFrame(data)
  logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
@@ -704,22 +675,22 @@
  if name not in data.columns:
  logger.warning(f"Numeric feature {name} not found in data")
  continue
- self._process_numeric_feature_fit(data[name], config)
+ self.process_numeric_feature_fit(data[name], config)
  for name, config in self.sparse_features.items():
  if name not in data.columns:
  logger.warning(f"Sparse feature {name} not found in data")
  continue
- self._process_sparse_feature_fit(data[name], config)
+ self.process_sparse_feature_fit(data[name], config)
  for name, config in self.sequence_features.items():
  if name not in data.columns:
  logger.warning(f"Sequence feature {name} not found in data")
  continue
- self._process_sequence_feature_fit(data[name], config)
+ self.process_sequence_feature_fit(data[name], config)
  for name, config in self.target_features.items():
  if name not in data.columns:
  logger.warning(f"Target {name} not found in data")
  continue
- self._process_target_fit(data[name], config)
+ self.process_target_fit(data[name], config)
  self.is_fitted = True
  return self
  
@@ -735,14 +706,8 @@
  if isinstance(data, (str, os.PathLike)):
  if return_dict:
  raise ValueError("Path transform writes files only; set return_dict=False when passing a path.")
- return self._transform_path(str(data), output_path, save_format)
- return self._transform_in_memory(
- data=data,
- return_dict=return_dict,
- persist=output_path is not None,
- save_format=save_format,
- output_path=output_path,
- )
+ return self.transform_path(str(data), output_path, save_format)
+ return self.transform_in_memory(data=data, return_dict=return_dict, persist=output_path is not None, save_format=save_format, output_path=output_path)
  
  def fit_transform(
  self,
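
After this consolidation, `transform` dispatches on its input: a DataFrame or dict goes through `transform_in_memory`, while a file or directory path goes through `transform_path`, which writes transformed files rather than returning data. A hedged usage sketch; the keyword names come from the signatures visible in this diff, and the defaults are assumptions:

    # In-memory: dict of numpy arrays by default, or a DataFrame with return_dict=False.
    arrays = processor.transform(test_df)
    frame = processor.transform(test_df, return_dict=False)

    # Path mode: reads every file under the input path and writes transformed copies;
    # return_dict must be False here, since nothing is returned in memory.
    processor.transform("data/raw/", output_path="data/processed/", save_format="parquet", return_dict=False)
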
nextrec/models/generative/hstu.py CHANGED
@@ -344,7 +344,7 @@ class HSTU(BaseModel):
  loss_params.setdefault("ignore_index", self.ignore_index)
  
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
- self._register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
+ self.register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
  
  def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
  """
nextrec/models/match/dssm.py CHANGED
@@ -143,11 +143,11 @@ class DSSM(BaseMatchModel):
  activation=dnn_activation
  )
  
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
nextrec/models/match/dssm_v2.py CHANGED
@@ -134,11 +134,11 @@ class DSSM_v2(BaseMatchModel):
  activation=dnn_activation
  )
  
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
nextrec/models/match/mind.py CHANGED
@@ -258,11 +258,11 @@ class MIND(BaseMatchModel):
  else:
  self.item_dnn = None
  
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['capsule_network']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn'] if self.item_dnn else []
  )
nextrec/models/match/sdm.py CHANGED
@@ -176,11 +176,11 @@ class SDM(BaseMatchModel):
  else:
  self.item_dnn = None
  
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['rnn', 'user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn'] if self.item_dnn else []
  )
nextrec/models/match/youtube_dnn.py CHANGED
@@ -140,11 +140,11 @@ class YoutubeDNN(BaseMatchModel):
  activation=dnn_activation
  )
  
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
nextrec/models/multi_task/esmm.py CHANGED
@@ -128,7 +128,7 @@ class ESMM(BaseModel):
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/multi_task/mmoe.py CHANGED
@@ -146,7 +146,7 @@ class MMOE(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
  
  def forward(self, x):
nextrec/models/multi_task/ple.py CHANGED
@@ -249,7 +249,7 @@ class PLE(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/multi_task/poso.py CHANGED
@@ -389,7 +389,7 @@ class POSO(BaseModel):
  self.tower_heads = None
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks,)
  include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
- self._register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/multi_task/share_bottom.py CHANGED
@@ -122,7 +122,7 @@ class ShareBottom(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/ranking/afm.py CHANGED
@@ -81,7 +81,7 @@ class AFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'attention_linear', 'attention_p', 'output_projection']
  )
nextrec/models/ranking/autoint.py CHANGED
@@ -150,7 +150,7 @@ class AutoInt(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['projection_layers', 'attention_layers', 'fc']
  )
nextrec/models/ranking/dcn.py CHANGED
@@ -109,7 +109,7 @@ class DCN(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['cross_network', 'mlp', 'final_layer']
  )
nextrec/models/ranking/deepfm.py CHANGED
@@ -107,7 +107,7 @@ class DeepFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/ranking/dien.py CHANGED
@@ -237,7 +237,7 @@ class DIEN(BaseModel):
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/ranking/din.py CHANGED
@@ -108,7 +108,7 @@ class DIN(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['attention', 'mlp', 'candidate_attention_proj']
  )
nextrec/models/ranking/fibinet.py CHANGED
@@ -104,7 +104,7 @@ class FiBiNET(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'senet', 'bilinear_standard', 'bilinear_senet', 'mlp']
  )
nextrec/models/ranking/fm.py CHANGED
@@ -69,7 +69,7 @@ class FM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear']
  )
nextrec/models/ranking/masknet.py CHANGED
@@ -234,10 +234,10 @@ class MaskNet(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  if self.model_type == "serial":
- self._register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
  # serial
  else:
- self._register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
nextrec/models/ranking/pnn.py CHANGED
@@ -91,7 +91,7 @@ class PNN(BaseModel):
  modules = ['mlp']
  if self.product_type == "outer":
  modules.append('kernel')
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=modules
  )
nextrec/models/ranking/widedeep.py CHANGED
@@ -111,7 +111,7 @@ class WideDeep(BaseModel):
  self.mlp = MLP(input_dim=input_dim, **mlp_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
  
  def forward(self, x):
nextrec/models/ranking/xdeepfm.py CHANGED
@@ -121,7 +121,7 @@ class xDeepFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'cin', 'mlp']
  )
nextrec/utils/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from .optimizer import get_optimizer, get_scheduler
  from .initializer import get_initializer
  from .embedding import get_auto_embedding_dim
- from .common import resolve_device
+ from .common import resolve_device, to_tensor
  from . import optimizer, initializer, embedding, common
  
  __all__ = [
@@ -10,6 +10,7 @@ __all__ = [
  'get_initializer',
  'get_auto_embedding_dim',
  'resolve_device',
+ 'to_tensor',
  'optimizer',
  'initializer',
  'embedding',
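
0.3.4 also re-exports `to_tensor` from `nextrec.utils.common` at the package level alongside `resolve_device`; its signature is not shown in this diff, so only the import below is taken from the change:

    from nextrec.utils import resolve_device, to_tensor  # to_tensor is newly exported in 0.3.4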