nextrec-0.1.4-py3-none-any.whl → nextrec-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
@@ -15,11 +15,11 @@ import logging

 from typing import Dict, Union, Optional, Literal, Any
 from sklearn.preprocessing import (
-    StandardScaler,
-    MinMaxScaler,
-    RobustScaler,
+    StandardScaler,
+    MinMaxScaler,
+    RobustScaler,
     MaxAbsScaler,
-    LabelEncoder,
+    LabelEncoder
 )

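(Aside: the import hunk above is mostly a quote-style and trailing-comma reformat; the scaler classes themselves are unchanged. As a rough, hypothetical illustration of how the `scaler` option of `add_numeric_feature` maps onto these sklearn transformers, mirroring the dispatch in `_process_numeric_feature_fit` further down — this is a sketch, not code from the package:)

import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, MaxAbsScaler

def build_scaler(scaler_type):
    # 'log' and 'none' need no fitted transformer; the diff handles them
    # inline with np.log1p / identity instead.
    mapping = {
        'standard': StandardScaler,
        'minmax': MinMaxScaler,
        'robust': RobustScaler,
        'maxabs': MaxAbsScaler,
    }
    return None if scaler_type in ('log', 'none') else mapping[scaler_type]()

values = np.array([[1.0], [2.0], [3.0]])  # sklearn scalers expect a 2D column
print(build_scaler('standard').fit_transform(values).ravel())  # zero mean, unit variance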
@@ -28,202 +28,198 @@ from nextrec.basic.loggers import setup_logger, colorize

 class DataProcessor:
     """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
+
     Examples:
         >>> processor = DataProcessor()
         >>> processor.add_numeric_feature('age', scaler='standard')
         >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
         >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
         >>> processor.add_target('label', target_type='binary')
-        >>>
+        >>>
         >>> # Fit and transform data
         >>> processor.fit(train_df)
         >>> processed_data = processor.transform(test_df)  # Returns dict of numpy arrays
-        >>>
+        >>>
         >>> # Save and load processor
         >>> processor.save('processor.pkl')
         >>> loaded_processor = DataProcessor.load('processor.pkl')
-        >>>
+        >>>
         >>> # Get vocabulary sizes for embedding layers
         >>> vocab_sizes = processor.get_vocab_sizes()
     """
-
     def __init__(self):
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
         self.target_features: Dict[str, Dict[str, Any]] = {}
-
+
         self.is_fitted = False
-        self._transform_summary_printed = (
-            False  # Track if summary has been printed during transform
-        )
-
+        self._transform_summary_printed = False  # Track if summary has been printed during transform
+
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
-
+
         # Initialize logger if not already initialized
         self._logger_initialized = False
         if not logging.getLogger().hasHandlers():
             setup_logger()
             self._logger_initialized = True
-
+
     def add_numeric_feature(
-        self,
-        name: str,
-        scaler: Optional[
-            Literal["standard", "minmax", "robust", "maxabs", "log", "none"]
-        ] = "standard",
-        fill_na: Optional[float] = None,
+        self,
+        name: str,
+        scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
+        fill_na: Optional[float] = None
     ):
-        self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
-
+        self.numeric_features[name] = {
+            'scaler': scaler,
+            'fill_na': fill_na
+        }
+
     def add_sparse_feature(
-        self,
-        name: str,
-        encode_method: Literal["hash", "label"] = "label",
+        self,
+        name: str,
+        encode_method: Literal['hash', 'label'] = 'label',
         hash_size: Optional[int] = None,
-        fill_na: str = "<UNK>",
+        fill_na: str = '<UNK>'
     ):
-        if encode_method == "hash" and hash_size is None:
+        if encode_method == 'hash' and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sparse_features[name] = {
-            "encode_method": encode_method,
-            "hash_size": hash_size,
-            "fill_na": fill_na,
+            'encode_method': encode_method,
+            'hash_size': hash_size,
+            'fill_na': fill_na
         }
-
+
     def add_sequence_feature(
-        self,
+        self,
         name: str,
-        encode_method: Literal["hash", "label"] = "label",
+        encode_method: Literal['hash', 'label'] = 'label',
         hash_size: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
-        truncate: Literal[
-            "pre", "post"
-        ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
-        separator: str = ",",
+        truncate: Literal['pre', 'post'] = 'pre',  # pre: keep last max_len items, post: keep first max_len items
+        separator: str = ','
     ):

-        if encode_method == "hash" and hash_size is None:
+        if encode_method == 'hash' and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sequence_features[name] = {
-            "encode_method": encode_method,
-            "hash_size": hash_size,
-            "max_len": max_len,
-            "pad_value": pad_value,
-            "truncate": truncate,
-            "separator": separator,
+            'encode_method': encode_method,
+            'hash_size': hash_size,
+            'max_len': max_len,
+            'pad_value': pad_value,
+            'truncate': truncate,
+            'separator': separator
         }
-
+
     def add_target(
-        self,
-        name: str,  # example: 'click'
-        target_type: Literal["binary", "multiclass", "regression"] = "binary",
-        label_map: Optional[
-            Dict[str, int]
-        ] = None,  # example: {'click': 1, 'no_click': 0}
+        self,
+        name: str,  # example: 'click'
+        target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
+        label_map: Optional[Dict[str, int]] = None  # example: {'click': 1, 'no_click': 0}
     ):
         self.target_features[name] = {
-            "target_type": target_type,
-            "label_map": label_map,
+            'target_type': target_type,
+            'label_map': label_map
         }
-
+
     def _hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
-
+
     def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        scaler_type = config["scaler"]
-        fill_na = config["fill_na"]
-
+        scaler_type = config['scaler']
+        fill_na = config['fill_na']
+
         if data.isna().any():
             if fill_na is None:
                 # Default use mean value to fill missing values for numeric features
                 fill_na = data.mean()
-            config["fill_na_value"] = fill_na
-
-        if scaler_type == "standard":
+            config['fill_na_value'] = fill_na
+
+        if scaler_type == 'standard':
             scaler = StandardScaler()
-        elif scaler_type == "minmax":
+        elif scaler_type == 'minmax':
             scaler = MinMaxScaler()
-        elif scaler_type == "robust":
+        elif scaler_type == 'robust':
             scaler = RobustScaler()
-        elif scaler_type == "maxabs":
+        elif scaler_type == 'maxabs':
             scaler = MaxAbsScaler()
-        elif scaler_type == "log":
-            scaler = None
-        elif scaler_type == "none":
+        elif scaler_type == 'log':
+            scaler = None
+        elif scaler_type == 'none':
             scaler = None
         else:
             raise ValueError(f"Unknown scaler type: {scaler_type}")
-
-        if scaler is not None and scaler_type != "log":
-            filled_data = data.fillna(config.get("fill_na_value", 0))
+
+        if scaler is not None and scaler_type != 'log':
+            filled_data = data.fillna(config.get('fill_na_value', 0))
             values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
             scaler.fit(values)
             self.scalers[name] = scaler
-
+
     def _process_numeric_feature_transform(
-        self, data: pd.Series, config: Dict[str, Any]
+        self,
+        data: pd.Series,
+        config: Dict[str, Any]
     ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        scaler_type = config["scaler"]
-        fill_na_value = config.get("fill_na_value", 0)
+        scaler_type = config['scaler']
+        fill_na_value = config.get('fill_na_value', 0)

         filled_data = data.fillna(fill_na_value)
         values = np.array(filled_data.values, dtype=np.float64)

-        if scaler_type == "log":
+        if scaler_type == 'log':
             result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == "none":
+        elif scaler_type == 'none':
             result = values
         else:
             scaler = self.scalers.get(name)
             if scaler is None:
-                logger.warning(
-                    f"Scaler for {name} not fitted, returning original values"
-                )
+                logger.warning(f"Scaler for {name} not fitted, returning original values")
                 result = values
             else:
                 result = scaler.transform(values.reshape(-1, 1)).ravel()
-
+
         return result
-
+
     def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        encode_method = config["encode_method"]
-        fill_na = config["fill_na"]  # <UNK>
-
+        encode_method = config['encode_method']
+        fill_na = config['fill_na']  # <UNK>
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == "label":
+
+        if encode_method == 'label':
             le = LabelEncoder()
             le.fit(filled_data)
             self.label_encoders[name] = le
-            config["vocab_size"] = len(le.classes_)
-        elif encode_method == "hash":
-            config["vocab_size"] = config["hash_size"]
-
+            config['vocab_size'] = len(le.classes_)
+        elif encode_method == 'hash':
+            config['vocab_size'] = config['hash_size']
+
     def _process_sparse_feature_transform(
-        self, data: pd.Series, config: Dict[str, Any]
+        self,
+        data: pd.Series,
+        config: Dict[str, Any]
     ) -> np.ndarray:
-
+
         name = str(data.name)
-        encode_method = config["encode_method"]
-        fill_na = config["fill_na"]
-
+        encode_method = config['encode_method']
+        fill_na = config['fill_na']
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == "label":
+
+        if encode_method == 'label':
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
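(A note on the `hash` path registered above: `_hash_string` buckets a value with md5 rather than Python's builtin `hash`, which is salted per process and therefore not reproducible across runs. A minimal standalone sketch of the same scheme — not package code:)

import hashlib

def hash_string(s, hash_size):
    # md5 digest read as a big integer, folded into [0, hash_size) buckets;
    # the same expression as DataProcessor._hash_string in the diff above.
    return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size

print(hash_string('user_42', 10000))  # deterministic across runs and machines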
@@ -236,23 +232,20 @@ class DataProcessor:
             else:
                 result.append(0)
             return np.array(result, dtype=np.int64)
-
-        elif encode_method == "hash":
-            hash_size = config["hash_size"]
-            return np.array(
-                [self._hash_string(val, hash_size) for val in filled_data],
-                dtype=np.int64,
-            )
-
+
+        elif encode_method == 'hash':
+            hash_size = config['hash_size']
+            return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)
+
         return np.array([], dtype=np.int64)
-
+
     def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        encode_method = config["encode_method"]
-        separator = config["separator"]
-
-        if encode_method == "label":
+        encode_method = config['encode_method']
+        separator = config['separator']
+
+        if encode_method == 'label':
             all_tokens = set()
             for seq in data:
                 # Skip None, np.nan, and empty strings
@@ -260,9 +253,9 @@ class DataProcessor:
                     continue
                 if isinstance(seq, (float, np.floating)) and np.isnan(seq):
                     continue
-                if isinstance(seq, str) and seq.strip() == "":
+                if isinstance(seq, str) and seq.strip() == '':
                     continue
-
+
                 if isinstance(seq, str):
                     tokens = seq.split(separator)
                 elif isinstance(seq, (list, tuple)):
@@ -271,39 +264,41 @@ class DataProcessor:
                     tokens = [str(t) for t in seq.tolist()]
                 else:
                     continue
-
+
                 all_tokens.update(tokens)
-
+
             if len(all_tokens) == 0:
-                all_tokens.add("<PAD>")
-
+                all_tokens.add('<PAD>')
+
             le = LabelEncoder()
             le.fit(list(all_tokens))
             self.label_encoders[name] = le
-            config["vocab_size"] = len(le.classes_)
-        elif encode_method == "hash":
-            config["vocab_size"] = config["hash_size"]
-
+            config['vocab_size'] = len(le.classes_)
+        elif encode_method == 'hash':
+            config['vocab_size'] = config['hash_size']
+
     def _process_sequence_feature_transform(
-        self, data: pd.Series, config: Dict[str, Any]
+        self,
+        data: pd.Series,
+        config: Dict[str, Any]
     ) -> np.ndarray:
         name = str(data.name)
-        encode_method = config["encode_method"]
-        max_len = config["max_len"]
-        pad_value = config["pad_value"]
-        truncate = config["truncate"]
-        separator = config["separator"]
-
+        encode_method = config['encode_method']
+        max_len = config['max_len']
+        pad_value = config['pad_value']
+        truncate = config['truncate']
+        separator = config['separator']
+
         result = []
         for seq in data:
             tokens = []
-
+
             if seq is None:
                 tokens = []
             elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
                 tokens = []
             elif isinstance(seq, str):
-                if seq.strip() == "":
+                if seq.strip() == '':
                     tokens = []
                 else:
                     tokens = seq.split(separator)
@@ -313,12 +308,12 @@ class DataProcessor:
                 tokens = [str(t) for t in seq.tolist()]
             else:
                 tokens = []
-
-            if encode_method == "label":
+
+            if encode_method == 'label':
                 le = self.label_encoders.get(name)
                 if le is None:
                     raise ValueError(f"LabelEncoder for {name} not fitted")
-
+
                 encoded = []
                 for token in tokens:
                     token_str = str(token).strip()
@@ -327,70 +322,66 @@ class DataProcessor:
                         encoded.append(int(encoded_val[0]))
                     else:
                         encoded.append(0)  # UNK
-            elif encode_method == "hash":
-                hash_size = config["hash_size"]
-                encoded = [
-                    self._hash_string(str(token), hash_size)
-                    for token in tokens
-                    if str(token).strip()
-                ]
+            elif encode_method == 'hash':
+                hash_size = config['hash_size']
+                encoded = [self._hash_string(str(token), hash_size) for token in tokens if str(token).strip()]
             else:
                 encoded = []
-
+
             if len(encoded) > max_len:
-                if truncate == "pre":  # keep last max_len items
+                if truncate == 'pre':  # keep last max_len items
                     encoded = encoded[-max_len:]
-                else:  # keep first max_len items
+                else:  # keep first max_len items
                     encoded = encoded[:max_len]
             elif len(encoded) < max_len:
                 padding = [pad_value] * (max_len - len(encoded))
                 encoded = encoded + padding
-
+
             result.append(encoded)
-
+
         return np.array(result, dtype=np.int64)
-
+
     def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        target_type = config["target_type"]
-        label_map = config["label_map"]
-
-        if target_type in ["binary", "multiclass"]:
+        target_type = config['target_type']
+        label_map = config['label_map']
+
+        if target_type in ['binary', 'multiclass']:
             if label_map is None:
                 unique_values = data.dropna().unique()
                 sorted_values = sorted(unique_values)
-
+
                 try:
                     int_values = [int(v) for v in sorted_values]
                     if int_values == list(range(len(int_values))):
                         label_map = {str(val): int(val) for val in sorted_values}
                     else:
-                        label_map = {
-                            str(val): idx for idx, val in enumerate(sorted_values)
-                        }
+                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
                 except (ValueError, TypeError):
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
-
-            config["label_map"] = label_map
-
+
+            config['label_map'] = label_map
+
         self.target_encoders[name] = label_map
-
+
     def _process_target_transform(
-        self, data: pd.Series, config: Dict[str, Any]
+        self,
+        data: pd.Series,
+        config: Dict[str, Any]
     ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        target_type = config["target_type"]
-
-        if target_type == "regression":
+        target_type = config['target_type']
+
+        if target_type == 'regression':
             values = np.array(data.values, dtype=np.float32)
             return values
         else:
             label_map = self.target_encoders.get(name)
             if label_map is None:
                 raise ValueError(f"Target encoder for {name} not fitted")
-
+
             result = []
             for val in data:
                 str_val = str(val)
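(The truncation and padding branch in `_process_sequence_feature_transform` above is the core of the fixed-length encoding: `truncate='pre'` keeps the last `max_len` items, i.e. the most recent history, while `'post'` keeps the first. Extracted into a standalone sketch for clarity — a hypothetical helper, not package code:)

def pad_or_truncate(encoded, max_len=5, pad_value=0, truncate='pre'):
    # Same branch structure as the diff above: truncate long sequences,
    # right-pad short ones with pad_value.
    if len(encoded) > max_len:
        return encoded[-max_len:] if truncate == 'pre' else encoded[:max_len]
    return encoded + [pad_value] * (max_len - len(encoded))

print(pad_or_truncate([1, 2, 3]))        # [1, 2, 3, 0, 0]
print(pad_or_truncate(list(range(8))))   # [3, 4, 5, 6, 7]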
@@ -399,18 +390,16 @@ class DataProcessor:
                 else:
                     logger.warning(f"Unknown target value: {val}, mapping to 0")
                     result.append(0)
-
-            return np.array(
-                result, dtype=np.int64 if target_type == "multiclass" else np.float32
-            )
-
+
+            return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
+
     # fit is nothing but registering the statistics from data so that we can transform the data later
     def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
         logger = logging.getLogger()
-
+
         if isinstance(data, dict):
             data = pd.DataFrame(data)
-
+
         logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))

         for name, config in self.numeric_features.items():
@@ -418,13 +407,13 @@ class DataProcessor:
                 logger.warning(f"Numeric feature {name} not found in data")
                 continue
             self._process_numeric_feature_fit(data[name], config)
-
+
         for name, config in self.sparse_features.items():
             if name not in data.columns:
                 logger.warning(f"Sparse feature {name} not found in data")
                 continue
             self._process_sparse_feature_fit(data[name], config)
-
+
         for name, config in self.sequence_features.items():
             if name not in data.columns:
                 logger.warning(f"Sequence feature {name} not found in data")
@@ -436,21 +425,22 @@ class DataProcessor:
                 logger.warning(f"Target {name} not found in data")
                 continue
             self._process_target_fit(data[name], config)
-
+
         self.is_fitted = True
-        logger.info(
-            colorize("DataProcessor fitted successfully", color="green", bold=True)
-        )
+        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
         return self
-
+
     def transform(
-        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any]],
+        return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         logger = logging.getLogger()
+

         if not self.is_fitted:
             raise ValueError("DataProcessor must be fitted before transform")
-
+
         # Convert input to dict format for unified processing
         if isinstance(data, pd.DataFrame):
             data_dict = {col: data[col] for col in data.columns}
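(Per the hunk above, `transform` first normalizes its input into a dict of columns, so a DataFrame and an equivalent dict of arrays should be interchangeable. A small illustration of that equivalence; the column names here are made up for the example:)

import pandas as pd

df = pd.DataFrame({'age': [23, 31], 'user_id': ['u1', 'u2'], 'label': [1, 0]})
as_dict = {col: df[col] for col in df.columns}  # same conversion transform() applies internally
# Given a fitted processor, both calls should yield the same arrays:
#   processor.transform(df)  and  processor.transform(as_dict)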
@@ -458,7 +448,7 @@ class DataProcessor:
             data_dict = data
         else:
             raise ValueError(f"Unsupported data type: {type(data)}")
-
+
         result_dict = {}
         for key, value in data_dict.items():
             if isinstance(value, pd.Series):
@@ -504,7 +494,7 @@ class DataProcessor:
                 series_data = pd.Series(data_dict[name], name=name)
                 processed = self._process_target_transform(series_data, config)
                 result_dict[name] = processed
-
+
         if return_dict:
             return result_dict
         else:
@@ -515,19 +505,21 @@ class DataProcessor:
                     columns_dict[key] = [list(seq) for seq in value]
                 else:
                     columns_dict[key] = value
-
+
             result_df = pd.DataFrame(columns_dict)
             return result_df
-
+
     def fit_transform(
-        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any]],
+        return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         self.fit(data)
         return self.transform(data, return_dict=return_dict)
-
+
     def save(self, filepath: str):
         logger = logging.getLogger()
-
+
         if not self.is_fitted:
             logger.warning("Saving unfitted DataProcessor")

@@ -535,152 +527,136 @@ class DataProcessor:
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path, exist_ok=True)
             logger.info(f"Created directory: {dir_path}")
-
+
         state = {
-            "numeric_features": self.numeric_features,
-            "sparse_features": self.sparse_features,
-            "sequence_features": self.sequence_features,
-            "target_features": self.target_features,
-            "is_fitted": self.is_fitted,
-            "scalers": self.scalers,
-            "label_encoders": self.label_encoders,
-            "target_encoders": self.target_encoders,
+            'numeric_features': self.numeric_features,
+            'sparse_features': self.sparse_features,
+            'sequence_features': self.sequence_features,
+            'target_features': self.target_features,
+            'is_fitted': self.is_fitted,
+            'scalers': self.scalers,
+            'label_encoders': self.label_encoders,
+            'target_encoders': self.target_encoders
         }
-
-        with open(filepath, "wb") as f:
+
+        with open(filepath, 'wb') as f:
             pickle.dump(state, f)
-
+
         logger.info(f"DataProcessor saved to {filepath}")
-
+
     @classmethod
-    def load(cls, filepath: str) -> "DataProcessor":
+    def load(cls, filepath: str) -> 'DataProcessor':
         logger = logging.getLogger()
-
-        with open(filepath, "rb") as f:
+
+        with open(filepath, 'rb') as f:
             state = pickle.load(f)
-
+
         processor = cls()
-        processor.numeric_features = state["numeric_features"]
-        processor.sparse_features = state["sparse_features"]
-        processor.sequence_features = state["sequence_features"]
-        processor.target_features = state["target_features"]
-        processor.is_fitted = state["is_fitted"]
-        processor.scalers = state["scalers"]
-        processor.label_encoders = state["label_encoders"]
-        processor.target_encoders = state["target_encoders"]
-
+        processor.numeric_features = state['numeric_features']
+        processor.sparse_features = state['sparse_features']
+        processor.sequence_features = state['sequence_features']
+        processor.target_features = state['target_features']
+        processor.is_fitted = state['is_fitted']
+        processor.scalers = state['scalers']
+        processor.label_encoders = state['label_encoders']
+        processor.target_encoders = state['target_encoders']
+
         logger.info(f"DataProcessor loaded from {filepath}")
         return processor
-
+
     def get_vocab_sizes(self) -> Dict[str, int]:
         vocab_sizes = {}
-
+
         for name, config in self.sparse_features.items():
-            vocab_sizes[name] = config.get("vocab_size", 0)
-
+            vocab_sizes[name] = config.get('vocab_size', 0)
+
         for name, config in self.sequence_features.items():
-            vocab_sizes[name] = config.get("vocab_size", 0)
-
+            vocab_sizes[name] = config.get('vocab_size', 0)
+
         return vocab_sizes
-
+
     def summary(self):
         """Print a summary of the DataProcessor configuration."""
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
         logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.numeric_features:
             logger.info(f"Dense Features ({len(self.numeric_features)}):")
-
+
             max_name_len = max(len(name) for name in self.numeric_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}"
-            )
+
+            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
             for i, (name, config) in enumerate(self.numeric_features.items(), 1):
-                scaler = config["scaler"]
-                fill_na = config.get("fill_na_value", config.get("fill_na", "N/A"))
-                logger.info(
-                    f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}"
-                )
-
+                scaler = config['scaler']
+                fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
+                logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")
+
         if self.sparse_features:
             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sparse_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}"
-            )
+
+            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
             for i, (name, config) in enumerate(self.sparse_features.items(), 1):
-                method = config["encode_method"]
-                vocab_size = config.get("vocab_size", "N/A")
-                hash_size = config.get("hash_size", "N/A")
-                logger.info(
-                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}"
-                )
-
+                method = config['encode_method']
+                vocab_size = config.get('vocab_size', 'N/A')
+                hash_size = config.get('hash_size', 'N/A')
+                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")
+
         if self.sequence_features:
             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sequence_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}"
-            )
-            logger.info(
-                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}"
-            )
+
+            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
+            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
             for i, (name, config) in enumerate(self.sequence_features.items(), 1):
-                method = config["encode_method"]
-                vocab_size = config.get("vocab_size", "N/A")
-                hash_size = config.get("hash_size", "N/A")
-                max_len = config.get("max_len", "N/A")
-                logger.info(
-                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}"
-                )
-
+                method = config['encode_method']
+                vocab_size = config.get('vocab_size', 'N/A')
+                hash_size = config.get('hash_size', 'N/A')
+                max_len = config.get('max_len', 'N/A')
+                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")
+
         logger.info("")
         logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.target_features:
             logger.info(f"Target Features ({len(self.target_features)}):")
-
+
             max_name_len = max(len(name) for name in self.target_features.keys())
             name_width = max(max_name_len, 10) + 2
-
+
             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
             for i, (name, config) in enumerate(self.target_features.items(), 1):
-                target_type = config["target_type"]
+                target_type = config['target_type']
                 logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
         else:
             logger.info("No target features configured")
-
+
         logger.info("")
         logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
         logger.info(f"Fitted: {self.is_fitted}")
-        logger.info(
-            f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}"
-        )
+        logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
         logger.info(f"  Dense Features: {len(self.numeric_features)}")
         logger.info(f"  Sparse Features: {len(self.sparse_features)}")
         logger.info(f"  Sequence Features: {len(self.sequence_features)}")
         logger.info(f"Target Features: {len(self.target_features)}")
-
+
         logger.info("")
         logger.info("")
-        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
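(Putting the 0.1.7 surface together, a minimal end-to-end sketch matching the docstring examples at the top of this diff; the import path is inferred from the changed-files list and the data is made up:)

import pandas as pd
from nextrec.data.preprocessor import DataProcessor  # path inferred from file 13 in the list above

train_df = pd.DataFrame({
    'age': [23, 31, 45],
    'user_id': ['u1', 'u2', 'u1'],
    'item_history': ['1,2,3', '4,5', '2'],
    'label': [1, 0, 1],
})

processor = DataProcessor()
processor.add_numeric_feature('age', scaler='standard')
processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
processor.add_target('label', target_type='binary')

arrays = processor.fit_transform(train_df)  # dict of numpy arrays, one per feature/target
processor.save('processor.pkl')             # pickles config, scalers, and encoders (see save() above)
vocab_sizes = processor.get_vocab_sizes()   # vocab size per sparse/sequence feature, for embeddings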