nextrec-0.1.1-py3-none-any.whl → nextrec-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py

@@ -15,11 +15,11 @@ import logging
 
 from typing import Dict, Union, Optional, Literal, Any
 from sklearn.preprocessing import (
-    StandardScaler,
-    MinMaxScaler,
-    RobustScaler,
+    StandardScaler,
+    MinMaxScaler,
+    RobustScaler,
     MaxAbsScaler,
-    LabelEncoder
+    LabelEncoder,
 )
 
 
@@ -28,198 +28,202 @@ from nextrec.basic.loggers import setup_logger, colorize
 
 class DataProcessor:
     """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
+
     Examples:
         >>> processor = DataProcessor()
         >>> processor.add_numeric_feature('age', scaler='standard')
         >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
         >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
         >>> processor.add_target('label', target_type='binary')
-        >>>
+        >>>
         >>> # Fit and transform data
         >>> processor.fit(train_df)
         >>> processed_data = processor.transform(test_df) # Returns dict of numpy arrays
-        >>>
+        >>>
         >>> # Save and load processor
        >>> processor.save('processor.pkl')
         >>> loaded_processor = DataProcessor.load('processor.pkl')
-        >>>
+        >>>
         >>> # Get vocabulary sizes for embedding layers
         >>> vocab_sizes = processor.get_vocab_sizes()
     """
+
     def __init__(self):
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
         self.target_features: Dict[str, Dict[str, Any]] = {}
-
+
         self.is_fitted = False
-        self._transform_summary_printed = False # Track if summary has been printed during transform
-
+        self._transform_summary_printed = (
+            False  # Track if summary has been printed during transform
+        )
+
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
-
+
         # Initialize logger if not already initialized
         self._logger_initialized = False
         if not logging.getLogger().hasHandlers():
             setup_logger()
             self._logger_initialized = True
-
+
     def add_numeric_feature(
-        self,
-        name: str,
-        scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
-        fill_na: Optional[float] = None
+        self,
+        name: str,
+        scaler: Optional[
+            Literal["standard", "minmax", "robust", "maxabs", "log", "none"]
+        ] = "standard",
+        fill_na: Optional[float] = None,
     ):
-        self.numeric_features[name] = {
-            'scaler': scaler,
-            'fill_na': fill_na
-        }
-
+        self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
+
     def add_sparse_feature(
-        self,
-        name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        self,
+        name: str,
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
-        fill_na: str = '<UNK>'
+        fill_na: str = "<UNK>",
     ):
-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sparse_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'fill_na': fill_na
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "fill_na": fill_na,
         }
-
+
     def add_sequence_feature(
-        self,
+        self,
         name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
-        truncate: Literal['pre', 'post'] = 'pre', # pre: keep last max_len items, post: keep first max_len items
-        separator: str = ','
+        truncate: Literal[
+            "pre", "post"
+        ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
+        separator: str = ",",
     ):
 
-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sequence_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'max_len': max_len,
-            'pad_value': pad_value,
-            'truncate': truncate,
-            'separator': separator
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "max_len": max_len,
+            "pad_value": pad_value,
+            "truncate": truncate,
+            "separator": separator,
         }
-
+
     def add_target(
-        self,
-        name: str, # example: 'click'
-        target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
-        label_map: Optional[Dict[str, int]] = None # example: {'click': 1, 'no_click': 0}
+        self,
+        name: str,  # example: 'click'
+        target_type: Literal["binary", "multiclass", "regression"] = "binary",
+        label_map: Optional[
+            Dict[str, int]
+        ] = None,  # example: {'click': 1, 'no_click': 0}
     ):
         self.target_features[name] = {
-            'target_type': target_type,
-            'label_map': label_map
+            "target_type": target_type,
+            "label_map": label_map,
         }
-
+
     def _hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
-
+
     def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
 
         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na = config['fill_na']
-
+        scaler_type = config["scaler"]
+        fill_na = config["fill_na"]
+
         if data.isna().any():
             if fill_na is None:
                 # Default use mean value to fill missing values for numeric features
                 fill_na = data.mean()
-            config['fill_na_value'] = fill_na
-
-        if scaler_type == 'standard':
+            config["fill_na_value"] = fill_na
+
+        if scaler_type == "standard":
             scaler = StandardScaler()
-        elif scaler_type == 'minmax':
+        elif scaler_type == "minmax":
             scaler = MinMaxScaler()
-        elif scaler_type == 'robust':
+        elif scaler_type == "robust":
             scaler = RobustScaler()
-        elif scaler_type == 'maxabs':
+        elif scaler_type == "maxabs":
             scaler = MaxAbsScaler()
-        elif scaler_type == 'log':
-            scaler = None
-        elif scaler_type == 'none':
+        elif scaler_type == "log":
+            scaler = None
+        elif scaler_type == "none":
             scaler = None
         else:
             raise ValueError(f"Unknown scaler type: {scaler_type}")
-
-        if scaler is not None and scaler_type != 'log':
-            filled_data = data.fillna(config.get('fill_na_value', 0))
+
+        if scaler is not None and scaler_type != "log":
+            filled_data = data.fillna(config.get("fill_na_value", 0))
             values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
             scaler.fit(values)
             self.scalers[name] = scaler
-
+
     def _process_numeric_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na_value = config.get('fill_na_value', 0)
+        scaler_type = config["scaler"]
+        fill_na_value = config.get("fill_na_value", 0)
 
         filled_data = data.fillna(fill_na_value)
         values = np.array(filled_data.values, dtype=np.float64)
 
-        if scaler_type == 'log':
+        if scaler_type == "log":
             result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == 'none':
+        elif scaler_type == "none":
             result = values
         else:
             scaler = self.scalers.get(name)
             if scaler is None:
-                logger.warning(f"Scaler for {name} not fitted, returning original values")
+                logger.warning(
+                    f"Scaler for {name} not fitted, returning original values"
+                )
                 result = values
             else:
                 result = scaler.transform(values.reshape(-1, 1)).ravel()
-
+
         return result
-
+
     def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
 
         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na'] # <UNK>
-
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]  # <UNK>
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == 'label':
+
+        if encode_method == "label":
             le = LabelEncoder()
             le.fit(filled_data)
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
     def _process_sparse_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
-
+
         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na']
-
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == 'label':
+
+        if encode_method == "label":
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
@@ -232,20 +236,23 @@ class DataProcessor:
                 else:
                     result.append(0)
             return np.array(result, dtype=np.int64)
-
-        elif encode_method == 'hash':
-            hash_size = config['hash_size']
-            return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)
-
+
+        elif encode_method == "hash":
+            hash_size = config["hash_size"]
+            return np.array(
+                [self._hash_string(val, hash_size) for val in filled_data],
+                dtype=np.int64,
+            )
+
         return np.array([], dtype=np.int64)
-
+
     def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
 
         name = str(data.name)
-        encode_method = config['encode_method']
-        separator = config['separator']
-
-        if encode_method == 'label':
+        encode_method = config["encode_method"]
+        separator = config["separator"]
+
+        if encode_method == "label":
             all_tokens = set()
             for seq in data:
                 # Skip None, np.nan, and empty strings
@@ -253,9 +260,9 @@ class DataProcessor:
                     continue
                 if isinstance(seq, (float, np.floating)) and np.isnan(seq):
                     continue
-                if isinstance(seq, str) and seq.strip() == '':
+                if isinstance(seq, str) and seq.strip() == "":
                     continue
-
+
                 if isinstance(seq, str):
                     tokens = seq.split(separator)
                 elif isinstance(seq, (list, tuple)):
@@ -264,41 +271,39 @@ class DataProcessor:
                     tokens = [str(t) for t in seq.tolist()]
                 else:
                     continue
-
+
                 all_tokens.update(tokens)
-
+
             if len(all_tokens) == 0:
-                all_tokens.add('<PAD>')
-
+                all_tokens.add("<PAD>")
+
             le = LabelEncoder()
             le.fit(list(all_tokens))
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
     def _process_sequence_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
         name = str(data.name)
-        encode_method = config['encode_method']
-        max_len = config['max_len']
-        pad_value = config['pad_value']
-        truncate = config['truncate']
-        separator = config['separator']
-
+        encode_method = config["encode_method"]
+        max_len = config["max_len"]
+        pad_value = config["pad_value"]
+        truncate = config["truncate"]
+        separator = config["separator"]
+
         result = []
         for seq in data:
             tokens = []
-
+
             if seq is None:
                 tokens = []
             elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
                 tokens = []
             elif isinstance(seq, str):
-                if seq.strip() == '':
+                if seq.strip() == "":
                     tokens = []
                 else:
                     tokens = seq.split(separator)
@@ -308,12 +313,12 @@ class DataProcessor:
                 tokens = [str(t) for t in seq.tolist()]
             else:
                 tokens = []
-
-            if encode_method == 'label':
+
+            if encode_method == "label":
                 le = self.label_encoders.get(name)
                 if le is None:
                     raise ValueError(f"LabelEncoder for {name} not fitted")
-
+
                 encoded = []
                 for token in tokens:
                     token_str = str(token).strip()
@@ -322,66 +327,70 @@ class DataProcessor:
                         encoded.append(int(encoded_val[0]))
                     else:
                         encoded.append(0)  # UNK
-            elif encode_method == 'hash':
-                hash_size = config['hash_size']
-                encoded = [self._hash_string(str(token), hash_size) for token in tokens if str(token).strip()]
+            elif encode_method == "hash":
+                hash_size = config["hash_size"]
+                encoded = [
+                    self._hash_string(str(token), hash_size)
+                    for token in tokens
+                    if str(token).strip()
+                ]
             else:
                 encoded = []
-
+
             if len(encoded) > max_len:
-                if truncate == 'pre': # keep last max_len items
+                if truncate == "pre":  # keep last max_len items
                     encoded = encoded[-max_len:]
-                else: # keep first max_len items
+                else:  # keep first max_len items
                     encoded = encoded[:max_len]
             elif len(encoded) < max_len:
                 padding = [pad_value] * (max_len - len(encoded))
                 encoded = encoded + padding
-
+
             result.append(encoded)
-
+
         return np.array(result, dtype=np.int64)
-
+
     def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        target_type = config['target_type']
-        label_map = config['label_map']
-
-        if target_type in ['binary', 'multiclass']:
+        target_type = config["target_type"]
+        label_map = config["label_map"]
+
+        if target_type in ["binary", "multiclass"]:
            if label_map is None:
                 unique_values = data.dropna().unique()
                 sorted_values = sorted(unique_values)
-
+
                 try:
                     int_values = [int(v) for v in sorted_values]
                     if int_values == list(range(len(int_values))):
                         label_map = {str(val): int(val) for val in sorted_values}
                     else:
-                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                        label_map = {
+                            str(val): idx for idx, val in enumerate(sorted_values)
+                        }
                 except (ValueError, TypeError):
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
-
-            config['label_map'] = label_map
-
+
+            config["label_map"] = label_map
+
         self.target_encoders[name] = label_map
-
+
     def _process_target_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
    ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        target_type = config['target_type']
-
-        if target_type == 'regression':
+        target_type = config["target_type"]
+
+        if target_type == "regression":
             values = np.array(data.values, dtype=np.float32)
             return values
         else:
             label_map = self.target_encoders.get(name)
             if label_map is None:
                 raise ValueError(f"Target encoder for {name} not fitted")
-
+
             result = []
             for val in data:
                 str_val = str(val)
@@ -390,16 +399,18 @@ class DataProcessor:
                 else:
                     logger.warning(f"Unknown target value: {val}, mapping to 0")
                     result.append(0)
-
-            return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
-
+
+            return np.array(
+                result, dtype=np.int64 if target_type == "multiclass" else np.float32
+            )
+
     # fit is nothing but registering the statistics from data so that we can transform the data later
     def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
         logger = logging.getLogger()
-
+
         if isinstance(data, dict):
             data = pd.DataFrame(data)
-
+
         logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
 
         for name, config in self.numeric_features.items():
@@ -407,13 +418,13 @@ class DataProcessor:
                 logger.warning(f"Numeric feature {name} not found in data")
                 continue
             self._process_numeric_feature_fit(data[name], config)
-
+
         for name, config in self.sparse_features.items():
             if name not in data.columns:
                 logger.warning(f"Sparse feature {name} not found in data")
                 continue
             self._process_sparse_feature_fit(data[name], config)
-
+
         for name, config in self.sequence_features.items():
             if name not in data.columns:
                 logger.warning(f"Sequence feature {name} not found in data")
@@ -425,22 +436,21 @@ class DataProcessor:
                 logger.warning(f"Target {name} not found in data")
                 continue
             self._process_target_fit(data[name], config)
-
+
         self.is_fitted = True
-        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        logger.info(
+            colorize("DataProcessor fitted successfully", color="green", bold=True)
+        )
         return self
-
+
     def transform(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
+        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         logger = logging.getLogger()
-
 
         if not self.is_fitted:
             raise ValueError("DataProcessor must be fitted before transform")
-
+
         # Convert input to dict format for unified processing
         if isinstance(data, pd.DataFrame):
             data_dict = {col: data[col] for col in data.columns}
@@ -448,7 +458,7 @@ class DataProcessor:
             data_dict = data
         else:
             raise ValueError(f"Unsupported data type: {type(data)}")
-
+
         result_dict = {}
         for key, value in data_dict.items():
             if isinstance(value, pd.Series):
@@ -494,7 +504,7 @@ class DataProcessor:
             series_data = pd.Series(data_dict[name], name=name)
             processed = self._process_target_transform(series_data, config)
             result_dict[name] = processed
-
+
         if return_dict:
             return result_dict
         else:
@@ -505,21 +515,19 @@ class DataProcessor:
                    columns_dict[key] = [list(seq) for seq in value]
                 else:
                    columns_dict[key] = value
-
+
             result_df = pd.DataFrame(columns_dict)
             return result_df
-
+
     def fit_transform(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
+        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         self.fit(data)
         return self.transform(data, return_dict=return_dict)
-
+
     def save(self, filepath: str):
         logger = logging.getLogger()
-
+
         if not self.is_fitted:
             logger.warning("Saving unfitted DataProcessor")
 
@@ -527,136 +535,152 @@ class DataProcessor:
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path, exist_ok=True)
             logger.info(f"Created directory: {dir_path}")
-
+
         state = {
-            'numeric_features': self.numeric_features,
-            'sparse_features': self.sparse_features,
-            'sequence_features': self.sequence_features,
-            'target_features': self.target_features,
-            'is_fitted': self.is_fitted,
-            'scalers': self.scalers,
-            'label_encoders': self.label_encoders,
-            'target_encoders': self.target_encoders
+            "numeric_features": self.numeric_features,
+            "sparse_features": self.sparse_features,
+            "sequence_features": self.sequence_features,
+            "target_features": self.target_features,
+            "is_fitted": self.is_fitted,
+            "scalers": self.scalers,
+            "label_encoders": self.label_encoders,
+            "target_encoders": self.target_encoders,
        }
-
-        with open(filepath, 'wb') as f:
+
+        with open(filepath, "wb") as f:
             pickle.dump(state, f)
-
+
         logger.info(f"DataProcessor saved to {filepath}")
-
+
     @classmethod
-    def load(cls, filepath: str) -> 'DataProcessor':
+    def load(cls, filepath: str) -> "DataProcessor":
         logger = logging.getLogger()
-
-        with open(filepath, 'rb') as f:
+
+        with open(filepath, "rb") as f:
            state = pickle.load(f)
-
+
         processor = cls()
-        processor.numeric_features = state['numeric_features']
-        processor.sparse_features = state['sparse_features']
-        processor.sequence_features = state['sequence_features']
-        processor.target_features = state['target_features']
-        processor.is_fitted = state['is_fitted']
-        processor.scalers = state['scalers']
-        processor.label_encoders = state['label_encoders']
-        processor.target_encoders = state['target_encoders']
-
+        processor.numeric_features = state["numeric_features"]
+        processor.sparse_features = state["sparse_features"]
+        processor.sequence_features = state["sequence_features"]
+        processor.target_features = state["target_features"]
+        processor.is_fitted = state["is_fitted"]
+        processor.scalers = state["scalers"]
+        processor.label_encoders = state["label_encoders"]
+        processor.target_encoders = state["target_encoders"]
+
         logger.info(f"DataProcessor loaded from {filepath}")
         return processor
-
+
     def get_vocab_sizes(self) -> Dict[str, int]:
         vocab_sizes = {}
-
+
         for name, config in self.sparse_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
-
+            vocab_sizes[name] = config.get("vocab_size", 0)
+
         for name, config in self.sequence_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
-
+            vocab_sizes[name] = config.get("vocab_size", 0)
+
         return vocab_sizes
-
+
     def summary(self):
         """Print a summary of the DataProcessor configuration."""
        logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
         logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
        logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.numeric_features:
             logger.info(f"Dense Features ({len(self.numeric_features)}):")
-
+
             max_name_len = max(len(name) for name in self.numeric_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
             for i, (name, config) in enumerate(self.numeric_features.items(), 1):
-                scaler = config['scaler']
-                fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
-                logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")
-
+                scaler = config["scaler"]
+                fill_na = config.get("fill_na_value", config.get("fill_na", "N/A"))
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}"
+                )
+
         if self.sparse_features:
             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sparse_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
             for i, (name, config) in enumerate(self.sparse_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}"
+                )
+
         if self.sequence_features:
             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sequence_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
-            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}"
+            )
+            logger.info(
+                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}"
+            )
             for i, (name, config) in enumerate(self.sequence_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                max_len = config.get('max_len', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                max_len = config.get("max_len", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}"
+                )
+
         logger.info("")
         logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.target_features:
             logger.info(f"Target Features ({len(self.target_features)}):")
-
+
             max_name_len = max(len(name) for name in self.target_features.keys())
             name_width = max(max_name_len, 10) + 2
-
+
             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
             for i, (name, config) in enumerate(self.target_features.items(), 1):
-                target_type = config['target_type']
+                target_type = config["target_type"]
                 logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
         else:
             logger.info("No target features configured")
-
+
         logger.info("")
         logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
         logger.info(f"Fitted: {self.is_fitted}")
-        logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
+        logger.info(
+            f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}"
+        )
         logger.info(f" Dense Features: {len(self.numeric_features)}")
         logger.info(f" Sparse Features: {len(self.sparse_features)}")
         logger.info(f" Sequence Features: {len(self.sequence_features)}")
         logger.info(f"Target Features: {len(self.target_features)}")
-
+
         logger.info("")
         logger.info("")
-        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
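
Only one file's hunks are reproduced above (nextrec/data/preprocessor.py). The changes are almost entirely mechanical reformatting: single quotes become double quotes, trailing commas are added, long signatures and f-strings are reflowed, and trailing whitespace is stripped; the DataProcessor API is unchanged between 0.1.1 and 0.1.2. For orientation, here is a minimal usage sketch assembled from the docstring and method signatures visible in this diff. The import path is inferred from the file listing above, and the toy column names are illustrative only, not part of the package.

# Minimal usage sketch for DataProcessor, assembled from the docstring and
# method signatures shown in this diff. Assumption: the class is importable
# from nextrec/data/preprocessor.py as listed above; the columns are toy data.
import pandas as pd

from nextrec.data.preprocessor import DataProcessor

train_df = pd.DataFrame(
    {
        "age": [23.0, 35.0, None, 41.0],                  # numeric; NaN filled with column mean at fit
        "user_id": ["u1", "u2", "u3", "u1"],              # sparse categorical
        "item_history": ["i1,i2", "i2,i3,i4", "", "i5"],  # comma-separated sequence
        "label": [1, 0, 0, 1],                            # binary target
    }
)

processor = DataProcessor()
processor.add_numeric_feature("age", scaler="standard")
processor.add_sparse_feature("user_id", encode_method="hash", hash_size=10000)
processor.add_sequence_feature("item_history", encode_method="label", max_len=50, pad_value=0)
processor.add_target("label", target_type="binary")

arrays = processor.fit_transform(train_df)  # dict of numpy arrays (return_dict=True)
vocab_sizes = processor.get_vocab_sizes()   # vocab sizes for embedding layers

processor.save("processor.pkl")             # pickles feature configs, scalers, encoders
restored = DataProcessor.load("processor.pkl")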