nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
@@ -5,6 +5,7 @@ Date: create on 13/11/2025
 Checkpoint: edit on 02/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
+
 from __future__ import annotations
 import os
 import pickle
@@ -16,13 +17,25 @@ import pandas as pd
 import tqdm
 from pathlib import Path
 from typing import Dict, Union, Optional, Literal, Any
-from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, MaxAbsScaler, LabelEncoder
+from sklearn.preprocessing import (
+    StandardScaler,
+    MinMaxScaler,
+    RobustScaler,
+    MaxAbsScaler,
+    LabelEncoder,
+)
 
 
 from nextrec.basic.features import FeatureSet
 from nextrec.basic.loggers import colorize
 from nextrec.basic.session import resolve_save_path
-from nextrec.utils.file import resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+from nextrec.utils.file import (
+    resolve_file_paths,
+    iter_file_chunks,
+    read_table,
+    load_dataframes,
+    default_output_dir,
+)
 
 from nextrec.__version__ import __version__
 
@@ -36,164 +49,179 @@ class DataProcessor(FeatureSet):
         self.version = __version__
 
         self.is_fitted = False
-        self._transform_summary_printed = False # Track if summary has been printed during transform
-
+        self._transform_summary_printed = (
+            False  # Track if summary has been printed during transform
+        )
+
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
         self.set_target_id([], [])
 
     def add_numeric_feature(
-        self,
-        name: str,
-        scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
-        fill_na: Optional[float] = None
+        self,
+        name: str,
+        scaler: Optional[
+            Literal["standard", "minmax", "robust", "maxabs", "log", "none"]
+        ] = "standard",
+        fill_na: Optional[float] = None,
     ):
-        self.numeric_features[name] = {
-            'scaler': scaler,
-            'fill_na': fill_na
-        }
-
+        self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
+
     def add_sparse_feature(
-        self,
-        name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        self,
+        name: str,
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
-        fill_na: str = '<UNK>'
+        fill_na: str = "<UNK>",
     ):
-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
         self.sparse_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'fill_na': fill_na
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "fill_na": fill_na,
         }
-
+
     def add_sequence_feature(
-        self,
+        self,
         name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
-        truncate: Literal['pre', 'post'] = 'pre', # pre: keep last max_len items, post: keep first max_len items
-        separator: str = ','
+        truncate: Literal[
+            "pre", "post"
+        ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
+        separator: str = ",",
     ):
-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
         self.sequence_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'max_len': max_len,
-            'pad_value': pad_value,
-            'truncate': truncate,
-            'separator': separator
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "max_len": max_len,
+            "pad_value": pad_value,
+            "truncate": truncate,
+            "separator": separator,
         }
-
+
     def add_target(
-        self,
-        name: str, # example: 'click'
-        target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
-        label_map: Optional[Dict[str, int]] = None # example: {'click': 1, 'no_click': 0}
+        self,
+        name: str,  # example: 'click'
+        target_type: Literal["binary", "multiclass", "regression"] = "binary",
+        label_map: Optional[
+            Dict[str, int]
+        ] = None,  # example: {'click': 1, 'no_click': 0}
     ):
         self.target_features[name] = {
-            'target_type': target_type,
-            'label_map': label_map
+            "target_type": target_type,
+            "label_map": label_map,
         }
         self.set_target_id(list(self.target_features.keys()), [])
-
+
     def hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
-
+
     def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na = config['fill_na']
+        scaler_type = config["scaler"]
+        fill_na = config["fill_na"]
         if data.isna().any():
             if fill_na is None:
                 # Default use mean value to fill missing values for numeric features
                 fill_na = data.mean()
-            config['fill_na_value'] = fill_na
-        if scaler_type == 'standard':
+            config["fill_na_value"] = fill_na
+        if scaler_type == "standard":
             scaler = StandardScaler()
-        elif scaler_type == 'minmax':
+        elif scaler_type == "minmax":
             scaler = MinMaxScaler()
-        elif scaler_type == 'robust':
+        elif scaler_type == "robust":
             scaler = RobustScaler()
-        elif scaler_type == 'maxabs':
+        elif scaler_type == "maxabs":
             scaler = MaxAbsScaler()
-        elif scaler_type == 'log':
-            scaler = None
-        elif scaler_type == 'none':
+        elif scaler_type == "log":
+            scaler = None
+        elif scaler_type == "none":
             scaler = None
         else:
             raise ValueError(f"Unknown scaler type: {scaler_type}")
-        if scaler is not None and scaler_type != 'log':
-            filled_data = data.fillna(config.get('fill_na_value', 0))
+        if scaler is not None and scaler_type != "log":
+            filled_data = data.fillna(config.get("fill_na_value", 0))
             values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
             scaler.fit(values)
             self.scalers[name] = scaler
-
-    def process_numeric_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+
+    def process_numeric_feature_transform(
+        self, data: pd.Series, config: Dict[str, Any]
+    ) -> np.ndarray:
         logger = logging.getLogger()
         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na_value = config.get('fill_na_value', 0)
+        scaler_type = config["scaler"]
+        fill_na_value = config.get("fill_na_value", 0)
         filled_data = data.fillna(fill_na_value)
         values = np.array(filled_data.values, dtype=np.float64)
-        if scaler_type == 'log':
+        if scaler_type == "log":
             result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == 'none':
+        elif scaler_type == "none":
             result = values
         else:
             scaler = self.scalers.get(name)
             if scaler is None:
-                logger.warning(f"Scaler for {name} not fitted, returning original values")
+                logger.warning(
+                    f"Scaler for {name} not fitted, returning original values"
+                )
                 result = values
             else:
                 result = scaler.transform(values.reshape(-1, 1)).ravel()
         return result
-
+
     def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na'] # <UNK>
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]  # <UNK>
         filled_data = data.fillna(fill_na).astype(str)
-        if encode_method == 'label':
+        if encode_method == "label":
             le = LabelEncoder()
             le.fit(filled_data)
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
-    def process_sparse_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
+    def process_sparse_feature_transform(
+        self, data: pd.Series, config: Dict[str, Any]
+    ) -> np.ndarray:
         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na']
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]
         sparse_series = pd.Series(data, name=name).fillna(fill_na).astype(str)
-        if encode_method == 'label':
+        if encode_method == "label":
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
-            class_to_idx = config.get('_class_to_idx')
+            class_to_idx = config.get("_class_to_idx")
             if class_to_idx is None:
                 class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
-                config['_class_to_idx'] = class_to_idx
+                config["_class_to_idx"] = class_to_idx
             encoded = sparse_series.map(class_to_idx)
             encoded = encoded.fillna(0).astype(np.int64)
             return encoded.to_numpy()
-        if encode_method == 'hash':
-            hash_size = config['hash_size']
+        if encode_method == "hash":
+            hash_size = config["hash_size"]
             hash_fn = self.hash_string
-            return np.fromiter((hash_fn(v, hash_size) for v in sparse_series.to_numpy()), dtype=np.int64, count=sparse_series.size,)
+            return np.fromiter(
+                (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+                dtype=np.int64,
+                count=sparse_series.size,
+            )
         return np.array([], dtype=np.int64)
-
+
     def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        encode_method = config['encode_method']
-        separator = config['separator']
-        if encode_method == 'label':
+        encode_method = config["encode_method"]
+        separator = config["separator"]
+        if encode_method == "label":
             all_tokens = set()
             for seq in data:
                 # Skip None, np.nan, and empty strings
@@ -201,9 +229,9 @@ class DataProcessor(FeatureSet):
                     continue
                 if isinstance(seq, (float, np.floating)) and np.isnan(seq):
                     continue
-                if isinstance(seq, str) and seq.strip() == '':
+                if isinstance(seq, str) and seq.strip() == "":
                     continue
-
+
                 if isinstance(seq, str):
                     tokens = seq.split(separator)
                 elif isinstance(seq, (list, tuple)):
@@ -214,40 +242,42 @@ class DataProcessor(FeatureSet):
                     continue
                 all_tokens.update(tokens)
             if len(all_tokens) == 0:
-                all_tokens.add('<PAD>')
+                all_tokens.add("<PAD>")
             le = LabelEncoder()
             le.fit(list(all_tokens))
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
-    def process_sequence_feature_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
+    def process_sequence_feature_transform(
+        self, data: pd.Series, config: Dict[str, Any]
+    ) -> np.ndarray:
         """Optimized sequence transform with preallocation and cached vocab map."""
         name = str(data.name)
-        encode_method = config['encode_method']
-        max_len = config['max_len']
-        pad_value = config['pad_value']
-        truncate = config['truncate']
-        separator = config['separator']
+        encode_method = config["encode_method"]
+        max_len = config["max_len"]
+        pad_value = config["pad_value"]
+        truncate = config["truncate"]
+        separator = config["separator"]
         arr = np.asarray(data, dtype=object)
         n = arr.shape[0]
         output = np.full((n, max_len), pad_value, dtype=np.int64)
         # Shared helpers cached locally for speed and cross-platform consistency
         split_fn = str.split
         is_nan = np.isnan
-        if encode_method == 'label':
+        if encode_method == "label":
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
-            class_to_idx = config.get('_class_to_idx')
+            class_to_idx = config.get("_class_to_idx")
             if class_to_idx is None:
                 class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
-                config['_class_to_idx'] = class_to_idx
+                config["_class_to_idx"] = class_to_idx
         else:
             class_to_idx = None  # type: ignore
         hash_fn = self.hash_string
-        hash_size = config.get('hash_size')
+        hash_size = config.get("hash_size")
         for i, seq in enumerate(arr):
             # normalize sequence to a list of strings
             tokens = []
@@ -262,30 +292,34 @@ class DataProcessor(FeatureSet):
                 tokens = [str(t) for t in seq]
             else:
                 tokens = []
-            if encode_method == 'label':
+            if encode_method == "label":
                 encoded = [
                     class_to_idx.get(token.strip(), 0)  # type: ignore[union-attr]
                     for token in tokens
-                    if token is not None and token != ''
+                    if token is not None and token != ""
                 ]
-            elif encode_method == 'hash':
+            elif encode_method == "hash":
                 if hash_size is None:
                     raise ValueError("hash_size must be set for hash encoding")
-                encoded = [hash_fn(str(token), hash_size) for token in tokens if str(token).strip()]
+                encoded = [
+                    hash_fn(str(token), hash_size)
+                    for token in tokens
+                    if str(token).strip()
+                ]
             else:
                 encoded = []
             if not encoded:
                 continue
             if len(encoded) > max_len:
-                encoded = encoded[-max_len:] if truncate == 'pre' else encoded[:max_len]
+                encoded = encoded[-max_len:] if truncate == "pre" else encoded[:max_len]
             output[i, : len(encoded)] = encoded
         return output
-
+
     def process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        target_type = config['target_type']
-        label_map = config.get('label_map')
-        if target_type in ['binary', 'multiclass']:
+        target_type = config["target_type"]
+        label_map = config.get("label_map")
+        if target_type in ["binary", "multiclass"]:
             if label_map is None:
                 unique_values = data.dropna().unique()
                 sorted_values = sorted(unique_values)
@@ -294,23 +328,27 @@
                     if int_values == list(range(len(int_values))):
                         label_map = {str(val): int(val) for val in sorted_values}
                     else:
-                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                        label_map = {
+                            str(val): idx for idx, val in enumerate(sorted_values)
+                        }
                 except (ValueError, TypeError):
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
-            config['label_map'] = label_map
+            config["label_map"] = label_map
         self.target_encoders[name] = label_map
-
-    def process_target_transform(self, data: pd.Series, config: Dict[str, Any]) -> np.ndarray:
+
+    def process_target_transform(
+        self, data: pd.Series, config: Dict[str, Any]
+    ) -> np.ndarray:
         logger = logging.getLogger()
         name = str(data.name)
-        target_type = config.get('target_type')
-        if target_type == 'regression':
+        target_type = config.get("target_type")
+        if target_type == "regression":
             values = np.array(data.values, dtype=np.float32)
             return values
         else:
             label_map = self.target_encoders.get(name)
             if label_map is None:
-                raise ValueError(f"Target encoder for {name} not fitted")
+                raise ValueError(f"Target encoder for {name} not fitted")
             result = []
             for val in data:
                 str_val = str(val)
@@ -319,8 +357,10 @@
                 else:
                     logger.warning(f"Unknown target value: {val}, mapping to 0")
                     result.append(0)
-            return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
-
+            return np.array(
+                result, dtype=np.int64 if target_type == "multiclass" else np.float32
+            )
+
     def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
         """Load all data from a file or directory path into a single DataFrame."""
         file_paths, file_type = resolve_file_paths(path)
@@ -340,10 +380,16 @@
             return [str(v) for v in value]
         return [str(value)]
 
-    def fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
+    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
         """Fit processor statistics by streaming files to reduce memory usage."""
         logger = logging.getLogger()
-        logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))
+        logger.info(
+            colorize(
+                "Fitting DataProcessor (streaming path mode)...",
+                color="cyan",
+                bold=True,
+            )
+        )
         file_paths, file_type = resolve_file_paths(path)
 
         numeric_acc: Dict[str, Dict[str, float]] = {}
@@ -356,9 +402,15 @@
                 "max": -np.inf,
                 "max_abs": 0.0,
             }
-        sparse_vocab: Dict[str, set[str]] = {name: set() for name in self.sparse_features.keys()}
-        seq_vocab: Dict[str, set[str]] = {name: set() for name in self.sequence_features.keys()}
-        target_values: Dict[str, set[Any]] = {name: set() for name in self.target_features.keys()}
+        sparse_vocab: Dict[str, set[str]] = {
+            name: set() for name in self.sparse_features.keys()
+        }
+        seq_vocab: Dict[str, set[str]] = {
+            name: set() for name in self.sequence_features.keys()
+        }
+        target_values: Dict[str, set[Any]] = {
+            name: set() for name in self.target_features.keys()
+        }
         missing_features = set()
         for file_path in file_paths:
             for chunk in iter_file_chunks(file_path, file_type, chunk_size):
@@ -410,12 +462,16 @@
                     vals = chunk[name].dropna().tolist()
                     target_values[name].update(vals)
         if missing_features:
-            logger.warning(f"The following configured features were not found in provided files: {sorted(missing_features)}")
+            logger.warning(
+                f"The following configured features were not found in provided files: {sorted(missing_features)}"
+            )
         # finalize numeric scalers
         for name, config in self.numeric_features.items():
             acc = numeric_acc[name]
             if acc["count"] == 0:
-                logger.warning(f"Numeric feature {name} has no valid values in provided files")
+                logger.warning(
+                    f"Numeric feature {name} has no valid values in provided files"
+                )
                 continue
             mean_val = acc["sum"] / acc["count"]
             if config["fill_na"] is not None:
@@ -428,7 +484,9 @@
                 scaler = StandardScaler()
                 scaler.mean_ = np.array([mean_val], dtype=np.float64)
                 scaler.var_ = np.array([var], dtype=np.float64)
-                scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0], dtype=np.float64)
+                scaler.scale_ = np.array(
+                    [np.sqrt(var) if var > 0 else 1.0], dtype=np.float64
+                )
                 scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
                 self.scalers[name] = scaler
             elif scaler_type == "minmax":
@@ -503,15 +561,25 @@
                     if int_values == list(range(len(int_values))):
                         label_map = {str(val): int(val) for val in sorted_values}
                     else:
-                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                        label_map = {
+                            str(val): idx for idx, val in enumerate(sorted_values)
+                        }
                 except (ValueError, TypeError):
-                    label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                    label_map = {
+                        str(val): idx for idx, val in enumerate(sorted_values)
+                    }
             config["label_map"] = label_map
 
             self.target_encoders[name] = label_map
 
         self.is_fitted = True
-        logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
+        logger.info(
+            colorize(
+                "DataProcessor fitted successfully (streaming path mode)",
+                color="green",
+                bold=True,
+            )
+        )
         return self
 
     def transform_in_memory(
@@ -522,7 +590,7 @@
         save_format: Optional[Literal["csv", "parquet"]],
         output_path: Optional[str],
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
-        logger = logging.getLogger()
+        logger = logging.getLogger()
         # Convert input to dict format for unified processing
         if isinstance(data, pd.DataFrame):
             data_dict = {col: data[col] for col in data.columns}
@@ -530,7 +598,7 @@
             data_dict = data
         else:
             raise ValueError(f"Unsupported data type: {type(data)}")
-
+
         result_dict = {}
         for key, value in data_dict.items():
             if isinstance(value, pd.Series):
@@ -587,7 +655,7 @@
                 else:
                     columns_dict[key] = value
             return pd.DataFrame(columns_dict)
-
+
         if save_format not in [None, "csv", "parquet"]:
             raise ValueError("save_format must be either 'csv', 'parquet', or None")
         effective_format = save_format
@@ -598,7 +666,9 @@
         result_df = dict_to_dataframe(result_dict)
         if persist:
             if output_path is None:
-                raise ValueError("output_path must be provided when persisting transformed data.")
+                raise ValueError(
+                    "output_path must be provided when persisting transformed data."
+                )
             output_dir = Path(output_path)
             if output_dir.suffix:
                 output_dir = output_dir.parent
@@ -609,7 +679,11 @@
                 result_df.to_parquet(save_path, index=False)
             else:
                 result_df.to_csv(save_path, index=False)
-            logger.info(colorize(f"Transformed data saved to: {save_path.resolve()}", color="green"))
+            logger.info(
+                colorize(
+                    f"Transformed data saved to: {save_path.resolve()}", color="green"
+                )
+            )
         if return_dict:
             return result_dict
         assert result_df is not None, "DataFrame is None after transform"
@@ -627,7 +701,9 @@
         target_format = save_format or file_type
         if target_format not in ["csv", "parquet"]:
             raise ValueError("save_format must be either 'csv' or 'parquet'")
-        base_output_dir = Path(output_path) if output_path else default_output_dir(input_path)
+        base_output_dir = (
+            Path(output_path) if output_path else default_output_dir(input_path)
+        )
         if base_output_dir.suffix:
             base_output_dir = base_output_dir.parent
         output_root = base_output_dir / "transformed_data"
@@ -635,8 +711,12 @@
         saved_paths = []
         for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
             df = read_table(file_path, file_type)
-            transformed_df = self.transform_in_memory(df, return_dict=False, persist=False, save_format=None, output_path=None)
-            assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
+            transformed_df = self.transform_in_memory(
+                df, return_dict=False, persist=False, save_format=None, output_path=None
+            )
+            assert isinstance(
+                transformed_df, pd.DataFrame
+            ), "Expected DataFrame when return_dict=False"
             source_path = Path(file_path)
             target_file = output_root / f"{source_path.stem}.{target_format}"
             if target_format == "csv":
@@ -644,17 +724,30 @@
             else:
                 transformed_df.to_parquet(target_file, index=False)
             saved_paths.append(str(target_file.resolve()))
-        logger.info(colorize(f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}", color="green",))
+        logger.info(
+            colorize(
+                f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
+                color="green",
+            )
+        )
         return saved_paths
 
     # fit is nothing but registering the statistics from data so that we can transform the data later
-    def fit(self, data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],chunk_size: int = 200000,):
+    def fit(
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        chunk_size: int = 200000,
+    ):
         logger = logging.getLogger()
         if isinstance(data, (str, os.PathLike)):
             path_str = str(data)
-            uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
+            uses_robust = any(
+                cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
+            )
             if uses_robust:
-                logger.warning("Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited.")
+                logger.warning(
+                    "Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited."
+                )
                 data = self.load_dataframe_from_path(path_str)
             else:
                 return self.fit_from_path(path_str, chunk_size)
@@ -683,9 +776,9 @@
             self.process_target_fit(data[name], config)
         self.is_fitted = True
         return self
-
+
     def transform(
-        self,
+        self,
         data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
         return_dict: bool = True,
         save_format: Optional[Literal["csv", "parquet"]] = None,
@@ -695,12 +788,20 @@
             raise ValueError("DataProcessor must be fitted before transform")
         if isinstance(data, (str, os.PathLike)):
             if return_dict:
-                raise ValueError("Path transform writes files only; set return_dict=False when passing a path.")
+                raise ValueError(
+                    "Path transform writes files only; set return_dict=False when passing a path."
+                )
             return self.transform_path(str(data), output_path, save_format)
-        return self.transform_in_memory(data=data, return_dict=return_dict, persist=output_path is not None, save_format=save_format, output_path=output_path)
-
+        return self.transform_in_memory(
+            data=data,
+            return_dict=return_dict,
+            persist=output_path is not None,
+            save_format=save_format,
+            output_path=output_path,
+        )
+
     def fit_transform(
-        self,
+        self,
         data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
         return_dict: bool = True,
         save_format: Optional[Literal["csv", "parquet"]] = None,
@@ -726,7 +827,7 @@
             default_dir=Path(os.getcwd()),
             default_name="fitted_processor",
             suffix=".pkl",
-            add_timestamp=False
+            add_timestamp=False,
         )
         state = {
             "numeric_features": self.numeric_features,
@@ -741,117 +842,137 @@
         }
         with open(target_path, "wb") as f:
             pickle.dump(state, f)
-        logger.info(f"DataProcessor saved to: {target_path}, NextRec version: {self.version}")
-
+        logger.info(
+            f"DataProcessor saved to: {target_path}, NextRec version: {self.version}"
+        )
+
     @classmethod
-    def load(cls, load_path: str | Path) -> 'DataProcessor':
+    def load(cls, load_path: str | Path) -> "DataProcessor":
         logger = logging.getLogger()
         load_path = Path(load_path)
-        with open(load_path, 'rb') as f:
+        with open(load_path, "rb") as f:
             state = pickle.load(f)
         processor = cls()
-        processor.numeric_features = state.get('numeric_features', {})
-        processor.sparse_features = state.get('sparse_features', {})
-        processor.sequence_features = state.get('sequence_features', {})
-        processor.target_features = state.get('target_features', {})
-        processor.is_fitted = state.get('is_fitted', False)
-        processor.scalers = state.get('scalers', {})
-        processor.label_encoders = state.get('label_encoders', {})
-        processor.target_encoders = state.get('target_encoders', {})
+        processor.numeric_features = state.get("numeric_features", {})
+        processor.sparse_features = state.get("sparse_features", {})
+        processor.sequence_features = state.get("sequence_features", {})
+        processor.target_features = state.get("target_features", {})
+        processor.is_fitted = state.get("is_fitted", False)
+        processor.scalers = state.get("scalers", {})
+        processor.label_encoders = state.get("label_encoders", {})
+        processor.target_encoders = state.get("target_encoders", {})
         processor.version = state.get("processor_version", "unknown")
-        logger.info(f"DataProcessor loaded from {load_path}, NextRec version: {processor.version}")
+        logger.info(
+            f"DataProcessor loaded from {load_path}, NextRec version: {processor.version}"
+        )
         return processor
-
+
     def get_vocab_sizes(self) -> Dict[str, int]:
         vocab_sizes = {}
         for name, config in self.sparse_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
+            vocab_sizes[name] = config.get("vocab_size", 0)
         for name, config in self.sequence_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
+            vocab_sizes[name] = config.get("vocab_size", 0)
         return vocab_sizes
-
+
     def summary(self):
         """Print a summary of the DataProcessor configuration."""
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
         logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
        logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.numeric_features:
             logger.info(f"Dense Features ({len(self.numeric_features)}):")
-
+
             max_name_len = max(len(name) for name in self.numeric_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
             for i, (name, config) in enumerate(self.numeric_features.items(), 1):
-                scaler = config['scaler']
-                fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
-                logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")
-
+                scaler = config["scaler"]
+                fill_na = config.get("fill_na_value", config.get("fill_na", "N/A"))
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}"
+                )
+
         if self.sparse_features:
             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sparse_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
             for i, (name, config) in enumerate(self.sparse_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}"
+                )
+
         if self.sequence_features:
             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sequence_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
-            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}"
+            )
+            logger.info(
+                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}"
+            )
             for i, (name, config) in enumerate(self.sequence_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                max_len = config.get('max_len', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                max_len = config.get("max_len", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}"
+                )
+
         logger.info("")
         logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.target_features:
             logger.info(f"Target Features ({len(self.target_features)}):")
-
+
             max_name_len = max(len(name) for name in self.target_features.keys())
             name_width = max(max_name_len, 10) + 2
-
+
             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
             for i, (name, config) in enumerate(self.target_features.items(), 1):
-                target_type = config['target_type']
+                target_type = config["target_type"]
                 logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
         else:
             logger.info("No target features configured")
-
+
         logger.info("")
         logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
         logger.info(f"Fitted: {self.is_fitted}")
-        logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
+        logger.info(
+            f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}"
+        )
         logger.info(f" Dense Features: {len(self.numeric_features)}")
         logger.info(f" Sparse Features: {len(self.sparse_features)}")
         logger.info(f" Sequence Features: {len(self.sequence_features)}")
         logger.info(f"Target Features: {len(self.target_features)}")
-
+
         logger.info("")
         logger.info("")
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
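For orientation, the diff above covers nextrec/data/preprocessor.py, whose DataProcessor class is the largest data-side change in 0.4.2 (mostly Black-style reformatting plus the streaming fit_from_path/transform_path paths). Below is a minimal usage sketch of that API as it reads after the diff. The DataFrame, its column names, and the pickle path are hypothetical illustrations; only the class, method names, and the parameters visible in the hunks above are taken from the package, and the exact save() signature is assumed since it is not fully shown.

# Hypothetical usage sketch of nextrec.data.preprocessor.DataProcessor (0.4.2 API as shown above).
# The DataFrame, column names, and save path are illustrative only.
import pandas as pd
from nextrec.data.preprocessor import DataProcessor

df = pd.DataFrame(
    {
        "price": [9.9, 19.5, None, 4.2],             # numeric; missing value filled with the column mean
        "user_id": ["u1", "u2", "u1", "u3"],         # sparse; label-encoded
        "hist_items": ["i1,i2", "i3", "", "i2,i4"],  # sequence; comma-separated, hashed and padded
        "click": [1, 0, 0, 1],                       # binary target
    }
)

processor = DataProcessor()
processor.add_numeric_feature("price", scaler="standard")
processor.add_sparse_feature("user_id", encode_method="label")
processor.add_sequence_feature("hist_items", encode_method="hash", hash_size=1000, max_len=5)
processor.add_target("click", target_type="binary")

features = processor.fit_transform(df, return_dict=True)  # dict of numpy arrays, one per feature/target
print(processor.get_vocab_sizes())                        # e.g. {'user_id': 3, 'hist_items': 1000}

processor.save("fitted_processor.pkl")                    # path argument assumed; defaults come from resolve_save_path
reloaded = DataProcessor.load("fitted_processor.pkl")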