nextrec-0.1.1-py3-none-any.whl → nextrec-0.1.2-py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
Note: most of this file's changes are mechanical formatter output (single quotes normalized to double quotes, long lines wrapped, trailing whitespace stripped); -/+ pairs below with identical visible text differ only in whitespace.

@@ -15,11 +15,11 @@ import logging

 from typing import Dict, Union, Optional, Literal, Any
 from sklearn.preprocessing import (
-    StandardScaler,
-    MinMaxScaler,
-    RobustScaler,
+    StandardScaler,
+    MinMaxScaler,
+    RobustScaler,
     MaxAbsScaler,
-    LabelEncoder
+    LabelEncoder,
 )


@@ -28,198 +28,202 @@ from nextrec.basic.loggers import setup_logger, colorize

 class DataProcessor:
     """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
-
+
     Examples:
         >>> processor = DataProcessor()
         >>> processor.add_numeric_feature('age', scaler='standard')
         >>> processor.add_sparse_feature('user_id', encode_method='hash', hash_size=10000)
         >>> processor.add_sequence_feature('item_history', encode_method='label', max_len=50, pad_value=0)
         >>> processor.add_target('label', target_type='binary')
-        >>>
+        >>>
         >>> # Fit and transform data
         >>> processor.fit(train_df)
         >>> processed_data = processor.transform(test_df)  # Returns dict of numpy arrays
-        >>>
+        >>>
         >>> # Save and load processor
         >>> processor.save('processor.pkl')
         >>> loaded_processor = DataProcessor.load('processor.pkl')
-        >>>
+        >>>
         >>> # Get vocabulary sizes for embedding layers
         >>> vocab_sizes = processor.get_vocab_sizes()
     """
+
     def __init__(self):
         self.numeric_features: Dict[str, Dict[str, Any]] = {}
         self.sparse_features: Dict[str, Dict[str, Any]] = {}
         self.sequence_features: Dict[str, Dict[str, Any]] = {}
         self.target_features: Dict[str, Dict[str, Any]] = {}
-
+
         self.is_fitted = False
-        self._transform_summary_printed = False  # Track if summary has been printed during transform
-
+        self._transform_summary_printed = (
+            False  # Track if summary has been printed during transform
+        )
+
         self.scalers: Dict[str, Any] = {}
         self.label_encoders: Dict[str, LabelEncoder] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
-
+
         # Initialize logger if not already initialized
         self._logger_initialized = False
         if not logging.getLogger().hasHandlers():
             setup_logger()
             self._logger_initialized = True
-
+
     def add_numeric_feature(
-        self,
-        name: str,
-        scaler: Optional[Literal['standard', 'minmax', 'robust', 'maxabs', 'log', 'none']] = 'standard',
-        fill_na: Optional[float] = None,
+        self,
+        name: str,
+        scaler: Optional[
+            Literal["standard", "minmax", "robust", "maxabs", "log", "none"]
+        ] = "standard",
+        fill_na: Optional[float] = None,
     ):
-        self.numeric_features[name] = {
-            'scaler': scaler,
-            'fill_na': fill_na
-        }
-
+        self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
+
     def add_sparse_feature(
-        self,
-        name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        self,
+        name: str,
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
-        fill_na: str = '<UNK>',
+        fill_na: str = "<UNK>",
     ):
-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sparse_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'fill_na': fill_na,
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "fill_na": fill_na,
         }
-
+
     def add_sequence_feature(
-        self,
+        self,
         name: str,
-        encode_method: Literal['hash', 'label'] = 'label',
+        encode_method: Literal["hash", "label"] = "label",
         hash_size: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
-        truncate: Literal['pre', 'post'] = 'pre',
-        separator: str = ',',
+        truncate: Literal[
+            "pre", "post"
+        ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
+        separator: str = ",",
     ):

-        if encode_method == 'hash' and hash_size is None:
+        if encode_method == "hash" and hash_size is None:
             raise ValueError("hash_size must be specified when encode_method='hash'")
-
+
         self.sequence_features[name] = {
-            'encode_method': encode_method,
-            'hash_size': hash_size,
-            'max_len': max_len,
-            'pad_value': pad_value,
-            'truncate': truncate,
-            'separator': separator,
+            "encode_method": encode_method,
+            "hash_size": hash_size,
+            "max_len": max_len,
+            "pad_value": pad_value,
+            "truncate": truncate,
+            "separator": separator,
         }
-
+
     def add_target(
-        self,
-        name: str,
-        target_type: Literal['binary', 'multiclass', 'regression'] = 'binary',
-        label_map: Optional[Dict[str, int]] = None,
+        self,
+        name: str,  # example: 'click'
+        target_type: Literal["binary", "multiclass", "regression"] = "binary",
+        label_map: Optional[
+            Dict[str, int]
+        ] = None,  # example: {'click': 1, 'no_click': 0}
     ):
         self.target_features[name] = {
-            'target_type': target_type,
-            'label_map': label_map,
+            "target_type": target_type,
+            "label_map": label_map,
         }
-
+
     def _hash_string(self, s: str, hash_size: int) -> int:
         return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
-
+
     def _process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na = config['fill_na']
-
+        scaler_type = config["scaler"]
+        fill_na = config["fill_na"]
+
         if data.isna().any():
             if fill_na is None:
                 # Default use mean value to fill missing values for numeric features
                 fill_na = data.mean()
-            config['fill_na_value'] = fill_na
-
-        if scaler_type == 'standard':
+            config["fill_na_value"] = fill_na
+
+        if scaler_type == "standard":
             scaler = StandardScaler()
-        elif scaler_type == 'minmax':
+        elif scaler_type == "minmax":
             scaler = MinMaxScaler()
-        elif scaler_type == 'robust':
+        elif scaler_type == "robust":
             scaler = RobustScaler()
-        elif scaler_type == 'maxabs':
+        elif scaler_type == "maxabs":
             scaler = MaxAbsScaler()
-        elif scaler_type == 'log':
-            scaler = None
-        elif scaler_type == 'none':
+        elif scaler_type == "log":
+            scaler = None
+        elif scaler_type == "none":
             scaler = None
         else:
             raise ValueError(f"Unknown scaler type: {scaler_type}")
-
-        if scaler is not None and scaler_type != 'log':
-            filled_data = data.fillna(config.get('fill_na_value', 0))
+
+        if scaler is not None and scaler_type != "log":
+            filled_data = data.fillna(config.get("fill_na_value", 0))
             values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
             scaler.fit(values)
             self.scalers[name] = scaler
-
+
     def _process_numeric_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        scaler_type = config['scaler']
-        fill_na_value = config.get('fill_na_value', 0)
+        scaler_type = config["scaler"]
+        fill_na_value = config.get("fill_na_value", 0)

         filled_data = data.fillna(fill_na_value)
         values = np.array(filled_data.values, dtype=np.float64)

-        if scaler_type == 'log':
+        if scaler_type == "log":
             result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == 'none':
+        elif scaler_type == "none":
             result = values
         else:
             scaler = self.scalers.get(name)
             if scaler is None:
-                logger.warning(f"Scaler for {name} not fitted, returning original values")
+                logger.warning(
+                    f"Scaler for {name} not fitted, returning original values"
+                )
                 result = values
             else:
                 result = scaler.transform(values.reshape(-1, 1)).ravel()
-
+
         return result
-
+
     def _process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na']
-
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]  # <UNK>
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == 'label':
+
+        if encode_method == "label":
             le = LabelEncoder()
             le.fit(filled_data)
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
     def _process_sparse_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
-
+
         name = str(data.name)
-        encode_method = config['encode_method']
-        fill_na = config['fill_na']
-
+        encode_method = config["encode_method"]
+        fill_na = config["fill_na"]
+
         filled_data = data.fillna(fill_na).astype(str)
-
-        if encode_method == 'label':
+
+        if encode_method == "label":
             le = self.label_encoders.get(name)
             if le is None:
                 raise ValueError(f"LabelEncoder for {name} not fitted")
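For context, the registration API in the hunk above matches the class docstring. A minimal end-to-end sketch (column names and values are hypothetical; the import path assumes the module layout listed at the top of this diff):

    import pandas as pd
    from nextrec.data.preprocessor import DataProcessor

    train_df = pd.DataFrame(
        {
            "age": [23.0, 31.0, None],               # numeric; NaN filled with the column mean at fit time
            "user_id": ["u1", "u2", "u3"],           # sparse; hashed into 10,000 buckets
            "item_history": ["i1,i2", "i2,i3", ""],  # comma-separated sequence
            "click": [1, 0, 1],                      # binary target
        }
    )

    processor = DataProcessor()
    processor.add_numeric_feature("age", scaler="standard")
    processor.add_sparse_feature("user_id", encode_method="hash", hash_size=10000)
    processor.add_sequence_feature("item_history", encode_method="label", max_len=5, pad_value=0)
    processor.add_target("click", target_type="binary")

    arrays = processor.fit_transform(train_df)  # dict of numpy arrays keyed by column name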
@@ -232,20 +236,23 @@ class DataProcessor:
             else:
                 result.append(0)
             return np.array(result, dtype=np.int64)
-
-        elif encode_method == 'hash':
-            hash_size = config['hash_size']
-            return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)
-
+
+        elif encode_method == "hash":
+            hash_size = config["hash_size"]
+            return np.array(
+                [self._hash_string(val, hash_size) for val in filled_data],
+                dtype=np.int64,
+            )
+
         return np.array([], dtype=np.int64)
-
+
     def _process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):

         name = str(data.name)
-        encode_method = config['encode_method']
-        separator = config['separator']
-
-        if encode_method == 'label':
+        encode_method = config["encode_method"]
+        separator = config["separator"]
+
+        if encode_method == "label":
             all_tokens = set()
             for seq in data:
                 # Skip None, np.nan, and empty strings
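The hash path above is the standard hashing trick: MD5 is used instead of Python's built-in hash() so bucket assignments stay stable across processes and runs, and no fitted vocabulary is needed. A self-contained sketch of the same scheme:

    import hashlib

    def hash_bucket(value: str, hash_size: int) -> int:
        # Same scheme as DataProcessor._hash_string: a stable MD5 digest
        # reduced into [0, hash_size).
        return int(hashlib.md5(str(value).encode()).hexdigest(), 16) % hash_size

    assert hash_bucket("user_42", 10_000) == hash_bucket("user_42", 10_000)

Collisions are possible by design; hash_size trades embedding-table size against collision rate.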
@@ -253,9 +260,9 @@ class DataProcessor:
                     continue
                 if isinstance(seq, (float, np.floating)) and np.isnan(seq):
                     continue
-                if isinstance(seq, str) and seq.strip() == '':
+                if isinstance(seq, str) and seq.strip() == "":
                     continue
-
+
                 if isinstance(seq, str):
                     tokens = seq.split(separator)
                 elif isinstance(seq, (list, tuple)):
@@ -264,41 +271,39 @@ class DataProcessor:
                     tokens = [str(t) for t in seq.tolist()]
                 else:
                     continue
-
+
                 all_tokens.update(tokens)
-
+
             if len(all_tokens) == 0:
-                all_tokens.add('<PAD>')
-
+                all_tokens.add("<PAD>")
+
             le = LabelEncoder()
             le.fit(list(all_tokens))
             self.label_encoders[name] = le
-            config['vocab_size'] = len(le.classes_)
-        elif encode_method == 'hash':
-            config['vocab_size'] = config['hash_size']
-
+            config["vocab_size"] = len(le.classes_)
+        elif encode_method == "hash":
+            config["vocab_size"] = config["hash_size"]
+
     def _process_sequence_feature_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
         name = str(data.name)
-        encode_method = config['encode_method']
-        max_len = config['max_len']
-        pad_value = config['pad_value']
-        truncate = config['truncate']
-        separator = config['separator']
-
+        encode_method = config["encode_method"]
+        max_len = config["max_len"]
+        pad_value = config["pad_value"]
+        truncate = config["truncate"]
+        separator = config["separator"]
+
         result = []
         for seq in data:
             tokens = []
-
+
             if seq is None:
                 tokens = []
             elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
                 tokens = []
             elif isinstance(seq, str):
-                if seq.strip() == '':
+                if seq.strip() == "":
                     tokens = []
                 else:
                     tokens = seq.split(separator)
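The fit pass above builds the label vocabulary by splitting string sequences on the configured separator and skipping None/NaN/empty rows. A small sketch (feature name and values hypothetical):

    import pandas as pd
    from nextrec.data.preprocessor import DataProcessor

    p = DataProcessor()
    p.add_sequence_feature("tags", encode_method="label", max_len=4, separator="|")
    p.fit(pd.DataFrame({"tags": ["a|b", "b|c", None, ""]}))
    print(p.get_vocab_sizes())  # {'tags': 3} -- tokens {"a", "b", "c"}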
@@ -308,12 +313,12 @@ class DataProcessor:
                 tokens = [str(t) for t in seq.tolist()]
             else:
                 tokens = []
-
-            if encode_method == 'label':
+
+            if encode_method == "label":
                 le = self.label_encoders.get(name)
                 if le is None:
                     raise ValueError(f"LabelEncoder for {name} not fitted")
-
+
                 encoded = []
                 for token in tokens:
                     token_str = str(token).strip()
@@ -322,66 +327,70 @@ class DataProcessor:
                         encoded.append(int(encoded_val[0]))
                     else:
                         encoded.append(0)  # UNK
-            elif encode_method == 'hash':
-                hash_size = config['hash_size']
-                encoded = [self._hash_string(str(token), hash_size) for token in tokens if str(token).strip()]
+            elif encode_method == "hash":
+                hash_size = config["hash_size"]
+                encoded = [
+                    self._hash_string(str(token), hash_size)
+                    for token in tokens
+                    if str(token).strip()
+                ]
             else:
                 encoded = []
-
+
             if len(encoded) > max_len:
-                if truncate == 'pre':
+                if truncate == "pre":  # keep last max_len items
                     encoded = encoded[-max_len:]
-                else:
+                else:  # keep first max_len items
                     encoded = encoded[:max_len]
             elif len(encoded) < max_len:
                 padding = [pad_value] * (max_len - len(encoded))
                 encoded = encoded + padding
-
+
             result.append(encoded)
-
+
         return np.array(result, dtype=np.int64)
-
+
     def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
         name = str(data.name)
-        target_type = config['target_type']
-        label_map = config['label_map']
-
-        if target_type in ['binary', 'multiclass']:
+        target_type = config["target_type"]
+        label_map = config["label_map"]
+
+        if target_type in ["binary", "multiclass"]:
             if label_map is None:
                 unique_values = data.dropna().unique()
                 sorted_values = sorted(unique_values)
-
+
                 try:
                     int_values = [int(v) for v in sorted_values]
                     if int_values == list(range(len(int_values))):
                         label_map = {str(val): int(val) for val in sorted_values}
                     else:
-                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                        label_map = {
+                            str(val): idx for idx, val in enumerate(sorted_values)
+                        }
                 except (ValueError, TypeError):
                     label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
-
-            config['label_map'] = label_map
-
+
+            config["label_map"] = label_map
+
         self.target_encoders[name] = label_map
-
+
     def _process_target_transform(
-        self,
-        data: pd.Series,
-        config: Dict[str, Any]
+        self, data: pd.Series, config: Dict[str, Any]
     ) -> np.ndarray:
         logger = logging.getLogger()
-
+
         name = str(data.name)
-        target_type = config['target_type']
-
-        if target_type == 'regression':
+        target_type = config["target_type"]
+
+        if target_type == "regression":
             values = np.array(data.values, dtype=np.float32)
             return values
         else:
             label_map = self.target_encoders.get(name)
             if label_map is None:
                 raise ValueError(f"Target encoder for {name} not fitted")
-
+
             result = []
             for val in data:
                 str_val = str(val)
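The truncation branch above implements the usual sequence policy: "pre" keeps the last max_len items (the most recent history), "post" keeps the first, and short sequences are right-padded with pad_value. A standalone sketch of the same policy:

    def pad_truncate(encoded, max_len, truncate="pre", pad_value=0):
        # Mirrors the truncate/pad logic of _process_sequence_feature_transform.
        if len(encoded) > max_len:
            return encoded[-max_len:] if truncate == "pre" else encoded[:max_len]
        return encoded + [pad_value] * (max_len - len(encoded))

    assert pad_truncate([1, 2, 3, 4], 3, "pre") == [2, 3, 4]
    assert pad_truncate([1, 2, 3, 4], 3, "post") == [1, 2, 3]
    assert pad_truncate([7], 3) == [7, 0, 0]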
@@ -390,16 +399,18 @@ class DataProcessor:
                 else:
                     logger.warning(f"Unknown target value: {val}, mapping to 0")
                     result.append(0)
-
-            return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
-
+
+            return np.array(
+                result, dtype=np.int64 if target_type == "multiclass" else np.float32
+            )
+
     # fit is nothing but registering the statistics from data so that we can transform the data later
     def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
         logger = logging.getLogger()
-
+
         if isinstance(data, dict):
             data = pd.DataFrame(data)
-
+
         logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))

         for name, config in self.numeric_features.items():
@@ -407,13 +418,13 @@ class DataProcessor:
                 logger.warning(f"Numeric feature {name} not found in data")
                 continue
             self._process_numeric_feature_fit(data[name], config)
-
+
         for name, config in self.sparse_features.items():
             if name not in data.columns:
                 logger.warning(f"Sparse feature {name} not found in data")
                 continue
             self._process_sparse_feature_fit(data[name], config)
-
+
         for name, config in self.sequence_features.items():
             if name not in data.columns:
                 logger.warning(f"Sequence feature {name} not found in data")
@@ -425,22 +436,21 @@ class DataProcessor:
                 logger.warning(f"Target {name} not found in data")
                 continue
             self._process_target_fit(data[name], config)
-
+
         self.is_fitted = True
-        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        logger.info(
+            colorize("DataProcessor fitted successfully", color="green", bold=True)
+        )
         return self
-
+
     def transform(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
+        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         logger = logging.getLogger()
-

         if not self.is_fitted:
             raise ValueError("DataProcessor must be fitted before transform")
-
+
         # Convert input to dict format for unified processing
         if isinstance(data, pd.DataFrame):
             data_dict = {col: data[col] for col in data.columns}
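As the hunk above shows, transform() accepts either a DataFrame or a plain dict of columns and normalizes both to pandas Series internally. Continuing the earlier sketch (batch values hypothetical):

    batch = {
        "age": [27.0, 44.0],
        "user_id": ["u7", "u9"],
        "item_history": ["i1,i3", None],
        "click": [0, 1],
    }
    arrays = processor.transform(batch)                    # dict of np.ndarray
    frame = processor.transform(batch, return_dict=False)  # same data as a DataFrame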
@@ -448,7 +458,7 @@ class DataProcessor:
             data_dict = data
         else:
             raise ValueError(f"Unsupported data type: {type(data)}")
-
+
         result_dict = {}
         for key, value in data_dict.items():
             if isinstance(value, pd.Series):
@@ -494,7 +504,7 @@ class DataProcessor:
                 series_data = pd.Series(data_dict[name], name=name)
                 processed = self._process_target_transform(series_data, config)
                 result_dict[name] = processed
-
+
         if return_dict:
             return result_dict
         else:
@@ -505,21 +515,19 @@ class DataProcessor:
                     columns_dict[key] = [list(seq) for seq in value]
                 else:
                     columns_dict[key] = value
-
+
             result_df = pd.DataFrame(columns_dict)
             return result_df
-
+
     def fit_transform(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
+        self, data: Union[pd.DataFrame, Dict[str, Any]], return_dict: bool = True
     ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
         self.fit(data)
         return self.transform(data, return_dict=return_dict)
-
+
     def save(self, filepath: str):
         logger = logging.getLogger()
-
+
         if not self.is_fitted:
             logger.warning("Saving unfitted DataProcessor")

@@ -527,136 +535,152 @@ class DataProcessor:
         if dir_path and not os.path.exists(dir_path):
             os.makedirs(dir_path, exist_ok=True)
             logger.info(f"Created directory: {dir_path}")
-
+
         state = {
-            'numeric_features': self.numeric_features,
-            'sparse_features': self.sparse_features,
-            'sequence_features': self.sequence_features,
-            'target_features': self.target_features,
-            'is_fitted': self.is_fitted,
-            'scalers': self.scalers,
-            'label_encoders': self.label_encoders,
-            'target_encoders': self.target_encoders,
+            "numeric_features": self.numeric_features,
+            "sparse_features": self.sparse_features,
+            "sequence_features": self.sequence_features,
+            "target_features": self.target_features,
+            "is_fitted": self.is_fitted,
+            "scalers": self.scalers,
+            "label_encoders": self.label_encoders,
+            "target_encoders": self.target_encoders,
         }
-
-        with open(filepath, 'wb') as f:
+
+        with open(filepath, "wb") as f:
             pickle.dump(state, f)
-
+
         logger.info(f"DataProcessor saved to {filepath}")
-
+
     @classmethod
-    def load(cls, filepath: str) -> 'DataProcessor':
+    def load(cls, filepath: str) -> "DataProcessor":
         logger = logging.getLogger()
-
-        with open(filepath, 'rb') as f:
+
+        with open(filepath, "rb") as f:
             state = pickle.load(f)
-
+
         processor = cls()
-        processor.numeric_features = state['numeric_features']
-        processor.sparse_features = state['sparse_features']
-        processor.sequence_features = state['sequence_features']
-        processor.target_features = state['target_features']
-        processor.is_fitted = state['is_fitted']
-        processor.scalers = state['scalers']
-        processor.label_encoders = state['label_encoders']
-        processor.target_encoders = state['target_encoders']
-
+        processor.numeric_features = state["numeric_features"]
+        processor.sparse_features = state["sparse_features"]
+        processor.sequence_features = state["sequence_features"]
+        processor.target_features = state["target_features"]
+        processor.is_fitted = state["is_fitted"]
+        processor.scalers = state["scalers"]
+        processor.label_encoders = state["label_encoders"]
+        processor.target_encoders = state["target_encoders"]
+
         logger.info(f"DataProcessor loaded from {filepath}")
         return processor
-
+
     def get_vocab_sizes(self) -> Dict[str, int]:
         vocab_sizes = {}
-
+
         for name, config in self.sparse_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
-
+            vocab_sizes[name] = config.get("vocab_size", 0)
+
         for name, config in self.sequence_features.items():
-            vocab_sizes[name] = config.get('vocab_size', 0)
-
+            vocab_sizes[name] = config.get("vocab_size", 0)
+
         return vocab_sizes
-
+
     def summary(self):
         """Print a summary of the DataProcessor configuration."""
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
         logger.info(colorize("DataProcessor Summary", color="bright_blue", bold=True))
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.numeric_features:
             logger.info(f"Dense Features ({len(self.numeric_features)}):")
-
+
             max_name_len = max(len(name) for name in self.numeric_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Scaler':>15} {'Fill NA':>10}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15} {'-'*10}")
             for i, (name, config) in enumerate(self.numeric_features.items(), 1):
-                scaler = config['scaler']
-                fill_na = config.get('fill_na_value', config.get('fill_na', 'N/A'))
-                logger.info(f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}")
-
+                scaler = config["scaler"]
+                fill_na = config.get("fill_na_value", config.get("fill_na", "N/A"))
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(scaler):>15} {str(fill_na):>10}"
+                )
+
         if self.sparse_features:
             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sparse_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12}"
+            )
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12}")
             for i, (name, config) in enumerate(self.sparse_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12}"
+                )
+
         if self.sequence_features:
             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
-
+
             max_name_len = max(len(name) for name in self.sequence_features.keys())
             name_width = max(max_name_len, 10) + 2
-
-            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}")
-            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}")
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Method':>12} {'Vocab Size':>12} {'Hash Size':>12} {'Max Len':>10}"
+            )
+            logger.info(
+                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*12} {'-'*12} {'-'*10}"
+            )
             for i, (name, config) in enumerate(self.sequence_features.items(), 1):
-                method = config['encode_method']
-                vocab_size = config.get('vocab_size', 'N/A')
-                hash_size = config.get('hash_size', 'N/A')
-                max_len = config.get('max_len', 'N/A')
-                logger.info(f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}")
-
+                method = config["encode_method"]
+                vocab_size = config.get("vocab_size", "N/A")
+                hash_size = config.get("hash_size", "N/A")
+                max_len = config.get("max_len", "N/A")
+                logger.info(
+                    f" {i:<4} {name:<{name_width}} {str(method):>12} {str(vocab_size):>12} {str(hash_size):>12} {str(max_len):>10}"
+                )
+
         logger.info("")
         logger.info(colorize("[2] Target Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.target_features:
             logger.info(f"Target Features ({len(self.target_features)}):")
-
+
             max_name_len = max(len(name) for name in self.target_features.keys())
             name_width = max(max_name_len, 10) + 2
-
+
             logger.info(f" {'#':<4} {'Name':<{name_width}} {'Type':>15}")
             logger.info(f" {'-'*4} {'-'*name_width} {'-'*15}")
             for i, (name, config) in enumerate(self.target_features.items(), 1):
-                target_type = config['target_type']
+                target_type = config["target_type"]
                 logger.info(f" {i:<4} {name:<{name_width}} {str(target_type):>15}")
         else:
             logger.info("No target features configured")
-
+
         logger.info("")
         logger.info(colorize("[3] Processor Status", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
         logger.info(f"Fitted: {self.is_fitted}")
-        logger.info(f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}")
+        logger.info(
+            f"Total Features: {len(self.numeric_features) + len(self.sparse_features) + len(self.sequence_features)}"
+        )
         logger.info(f"  Dense Features: {len(self.numeric_features)}")
         logger.info(f"  Sparse Features: {len(self.sparse_features)}")
         logger.info(f"  Sequence Features: {len(self.sequence_features)}")
         logger.info(f"Target Features: {len(self.target_features)}")
-
+
         logger.info("")
         logger.info("")
-        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
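The save/load pair above pickles the full processor state (feature configs, scalers, encoders), and get_vocab_sizes() reports the fitted vocabulary per sparse/sequence feature. A closing sketch of wiring those sizes into embedding tables (PyTorch usage is an assumption here, not part of this diff):

    import torch.nn as nn

    processor.save("artifacts/processor.pkl")  # parent directories are created as needed
    restored = DataProcessor.load("artifacts/processor.pkl")

    embeddings = nn.ModuleDict(
        {
            name: nn.Embedding(num_embeddings=size, embedding_dim=16)
            for name, size in restored.get_vocab_sizes().items()
        }
    )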