nextrec 0.4.30-py3-none-any.whl → 0.4.32-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +48 -4
- nextrec/cli.py +18 -10
- nextrec/data/batch_utils.py +2 -2
- nextrec/data/preprocessor.py +53 -1
- nextrec/models/multi_task/[pre]aitm.py +3 -3
- nextrec/models/multi_task/[pre]snr_trans.py +3 -3
- nextrec/models/multi_task/[pre]star.py +3 -3
- nextrec/models/multi_task/apg.py +3 -3
- nextrec/models/multi_task/cross_stitch.py +3 -3
- nextrec/models/multi_task/escm.py +3 -3
- nextrec/models/multi_task/esmm.py +3 -3
- nextrec/models/multi_task/hmoe.py +3 -3
- nextrec/models/multi_task/mmoe.py +3 -3
- nextrec/models/multi_task/pepnet.py +4 -4
- nextrec/models/multi_task/ple.py +3 -3
- nextrec/models/multi_task/poso.py +3 -3
- nextrec/models/multi_task/share_bottom.py +3 -3
- nextrec/models/ranking/afm.py +3 -2
- nextrec/models/ranking/autoint.py +3 -2
- nextrec/models/ranking/dcn.py +3 -2
- nextrec/models/ranking/dcn_v2.py +3 -2
- nextrec/models/ranking/deepfm.py +3 -2
- nextrec/models/ranking/dien.py +3 -2
- nextrec/models/ranking/din.py +3 -2
- nextrec/models/ranking/eulernet.py +3 -2
- nextrec/models/ranking/ffm.py +3 -2
- nextrec/models/ranking/fibinet.py +3 -2
- nextrec/models/ranking/fm.py +3 -2
- nextrec/models/ranking/lr.py +3 -2
- nextrec/models/ranking/masknet.py +3 -2
- nextrec/models/ranking/pnn.py +3 -2
- nextrec/models/ranking/widedeep.py +3 -2
- nextrec/models/ranking/xdeepfm.py +3 -2
- nextrec/models/tree_base/__init__.py +15 -0
- nextrec/models/tree_base/base.py +693 -0
- nextrec/models/tree_base/catboost.py +97 -0
- nextrec/models/tree_base/lightgbm.py +69 -0
- nextrec/models/tree_base/xgboost.py +61 -0
- nextrec/utils/config.py +1 -0
- nextrec/utils/types.py +2 -0
- {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/METADATA +5 -5
- {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/RECORD +46 -41
- {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/licenses/LICENSE +1 -1
- {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/WHEEL +0 -0
- {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/entry_points.txt +0 -0
nextrec/models/tree_base/base.py
@@ -0,0 +1,693 @@
+"""
+Tree-based model base for NextRec.
+
+This module provides a lightweight adapter to plug tree models (xgboost/lightgbm/catboost)
+into the NextRec training/prediction pipeline.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import pickle
+from pathlib import Path
+from typing import Any, Iterable, Literal, overload
+
+import numpy as np
+import pandas as pd
+
+from nextrec.basic.features import (
+    DenseFeature,
+    FeatureSet,
+    SequenceFeature,
+    SparseFeature,
+)
+from nextrec.basic.loggers import colorize, format_kv, setup_logger
+from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
+from nextrec.basic.session import create_session, get_save_path
+from nextrec.data.dataloader import RecDataLoader
+from nextrec.data.data_processing import get_column_data
+from nextrec.utils.console import display_metrics_table
+from nextrec.utils.data import FILE_FORMAT_CONFIG, check_streaming_support
+from nextrec.utils.feature import to_list
+from nextrec.utils.torch_utils import to_numpy
+
+
+class TreeBaseModel(FeatureSet):
+    model_file_suffix = "bin"
+
+    @property
+    def model_name(self) -> str:
+        return self.__class__.__name__.lower()
+
+    @property
+    def default_task(self) -> str:
+        return "binary"
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        target: list[str] | str | None = None,
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        device: str = "cpu",
+        session_id: str | None = None,
+        model_params: dict[str, Any] | None = None,
+        sequence_pooling: str = "mean",
+        **kwargs: Any,
+    ):
+        self.device = device
+        self.model_params = dict(model_params or {})  # tree model parameters
+        if kwargs:
+            self.model_params.update(kwargs)
+        self.sequence_pooling = sequence_pooling
+
+        self.set_all_features(
+            dense_features, sparse_features, sequence_features, target, id_columns
+        )
+        self.task = task or self.default_task
+
+        self.session_id = session_id
+        self.session = create_session(session_id)
+        self.session_path = self.session.root
+        self.checkpoint_path = os.path.join(
+            self.session_path,
+            f"{self.model_name.upper()}_checkpoint.{self.model_file_suffix}",
+        )
+        self.best_path = os.path.join(
+            self.session_path,
+            f"{self.model_name.upper()}_best.{self.model_file_suffix}",
+        )
+        self.features_config_path = os.path.join(
+            self.session_path, "features_config.pkl"
+        )
+
+        self.model: Any | None = None
+        self._cat_feature_indices: list[int] = []
+
+    def assert_task(self) -> None:
+        if self.target_columns and len(self.target_columns) > 1:
+            raise ValueError(
+                f"[{self.model_name}-init Error] tree models only support a single target column."
+            )
+        if isinstance(self.task, list) and len(self.task) > 1:
+            raise ValueError(
+                f"[{self.model_name}-init Error] tree models only support a single task type."
+            )
+
+    def pool_sequence(self, arr: np.ndarray, feature: SequenceFeature) -> np.ndarray:
+        if arr.ndim == 1:
+            return arr.reshape(-1, 1)
+        padding_value = feature.padding_idx
+        mask = arr != padding_value
+        if self.sequence_pooling == "sum":
+            pooled = (arr * mask).sum(axis=1)
+        elif self.sequence_pooling == "max":
+            masked = np.where(mask, arr, -np.inf)
+            pooled = np.max(masked, axis=1)
+            pooled = np.where(np.isfinite(pooled), pooled, 0.0)
+        elif self.sequence_pooling == "last":
+            idx = np.where(mask, np.arange(arr.shape[1]), -1)
+            last_idx = idx.max(axis=1)
+            pooled = np.array(
+                [arr[row, col] if col >= 0 else 0.0 for row, col in enumerate(last_idx)]
+            )
+        else:
+            counts = np.maximum(mask.sum(axis=1), 1)
+            pooled = (arr * mask).sum(axis=1) / counts
+        return pooled.reshape(-1, 1).astype(np.float32)
+
+    def features_to_matrix(self, features: dict[str, Any]) -> np.ndarray:
+        columns: list[np.ndarray] = []
+        cat_indices: list[int] = []
+        feature_offset = 0
+        for feat in self.all_features:
+            if feat.name not in features:
+                raise KeyError(
+                    f"[{self.model_name}-data Error] Missing feature '{feat.name}'."
+                )
+            arr = to_numpy(features[feat.name])
+            if isinstance(feat, SequenceFeature):
+                arr = self.pool_sequence(arr, feat)
+            if arr.ndim == 1:
+                arr = arr.reshape(-1, 1)
+            if isinstance(feat, SparseFeature):
+                for col_idx in range(arr.shape[1]):
+                    cat_indices.append(feature_offset + col_idx)
+            feature_offset += arr.shape[1]
+            columns.append(arr.astype(np.float32))
+        if columns:
+            self._cat_feature_indices = cat_indices
+            return np.concatenate(columns, axis=1)
+        return np.empty((0, 0), dtype=np.float32)
+
+    def extract_labels(self, labels: dict[str, Any] | None) -> np.ndarray | None:
+        if labels is None:
+            return None
+        if self.target_columns:
+            target = self.target_columns[0]
+            if target not in labels:
+                return None
+            return to_numpy(labels[target]).reshape(-1)
+        first_key = next(iter(labels.keys()), None)
+        if first_key is None:
+            return None
+        return to_numpy(labels[first_key]).reshape(-1)
+
+    def extract_ids(
+        self, ids: dict[str, Any] | None, id_column: str | None
+    ) -> np.ndarray | None:
+        if ids is None or id_column is None:
+            return None
+        if id_column not in ids:
+            return None
+        return np.asarray(ids[id_column]).reshape(-1)
+
+    def collect_from_dataloader(
+        self,
+        data_loader: Iterable,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        feature_chunks: list[np.ndarray] = []
+        label_chunks: list[np.ndarray] = []
+        id_chunks: list[np.ndarray] = []
+        for batch in data_loader:
+            if not isinstance(batch, dict) or "features" not in batch:
+                raise TypeError(
+                    f"[{self.model_name}-data Error] Expected batches with a 'features' dict."
+                )
+            features = batch.get("features", {})
+            labels = batch.get("labels")
+            ids = batch.get("ids")
+            X_batch = self.features_to_matrix(features)
+            feature_chunks.append(X_batch)
+            y_batch = self.extract_labels(labels)
+            if require_labels and y_batch is None:
+                raise ValueError(
+                    f"[{self.model_name}-data Error] Labels are required but missing."
+                )
+            if y_batch is not None:
+                label_chunks.append(y_batch)
+            if include_ids and id_column:
+                id_batch = self.extract_ids(ids, id_column)
+                if id_batch is not None:
+                    id_chunks.append(id_batch)
+        X_all = (
+            np.concatenate(feature_chunks, axis=0)
+            if feature_chunks
+            else np.empty((0, 0))
+        )
+        y_all = np.concatenate(label_chunks, axis=0) if label_chunks else None
+        ids_all = np.concatenate(id_chunks, axis=0) if id_chunks else None
+        return X_all, y_all, ids_all
+
+    def collect_from_table(
+        self,
+        data: dict | pd.DataFrame,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        features: dict[str, Any] = {}
+        for feat in self.all_features:
+            column = get_column_data(data, feat.name)
+            if column is None:
+                raise KeyError(
+                    f"[{self.model_name}-data Error] Missing feature '{feat.name}'."
+                )
+            features[feat.name] = column
+        X_all = self.features_to_matrix(features)
+        y_all = None
+        if require_labels:
+            label_payload: dict[str, Any] = {}
+            for name in self.target_columns:
+                column = get_column_data(data, name)
+                if column is not None:
+                    label_payload[name] = column
+            y_all = self.extract_labels(label_payload or None)
+            if y_all is None:
+                raise ValueError(
+                    f"[{self.model_name}-data Error] Labels are required but missing."
+                )
+        ids_all = None
+        if include_ids and id_column:
+            id_col = get_column_data(data, id_column)
+            if id_col is not None:
+                ids_all = np.asarray(id_col).reshape(-1)
+        return X_all, y_all, ids_all
+
+    def prepare_arrays(
+        self,
+        data: Any,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        if isinstance(data, (str, os.PathLike)):
+            raise TypeError(
+                f"[{self.model_name}-data Error] File paths are not supported here. "
+                "Use RecDataLoader to create a DataLoader for training."
+            )
+        if isinstance(data, (pd.DataFrame, dict)):
+            return self.collect_from_table(data, require_labels, include_ids, id_column)
+        if isinstance(data, Iterable) and hasattr(data, "__iter__"):
+            return self.collect_from_dataloader(
+                data, require_labels, include_ids, id_column
+            )
+        raise TypeError(
+            f"[{self.model_name}-data Error] Unsupported data type: {type(data)}"
+        )
+
+    def build_estimator(self, model_params: dict[str, Any], epochs: int | None):
+        raise NotImplementedError
+
+    def fit_estimator(
+        self,
+        model: Any,
+        X_train: np.ndarray,
+        y_train: np.ndarray,
+        X_valid: np.ndarray | None,
+        y_valid: np.ndarray | None,
+        cat_features: list[int],
+        **kwargs: Any,
+    ) -> Any:
+        raise NotImplementedError
+
+    def predict_scores(self, model: Any, X: np.ndarray) -> np.ndarray:
+        if hasattr(model, "predict_proba"):
+            proba = model.predict_proba(X)
+            if isinstance(proba, list):
+                proba = np.asarray(proba)
+            if proba.ndim == 2 and proba.shape[1] > 1:
+                return proba[:, 1]
+        return np.asarray(model.predict(X)).reshape(-1)
+
+    def compile(self, optimizer=None, loss=None, loss_params=None, **kwargs) -> None:
+        del optimizer, loss, loss_params, kwargs  # not used for tree models
+
+    def fit(
+        self,
+        train_data: Any,
+        valid_data: Any | None = None,
+        metrics: list[str] | dict[str, list[str]] | None = None,
+        epochs: int = 1,
+        batch_size: int = 512,
+        shuffle: bool = True,
+        num_workers: int = 0,
+        user_id_column: str | None = None,
+        ignore_label: int | float | None = None,
+        **kwargs: Any,
+    ) -> None:
+        del batch_size, shuffle, num_workers  # not used for tree models
+        self.assert_task()
+        if train_data is None:
+            raise ValueError(f"[{self.model_name}-fit Error] train_data is required.")
+
+        setup_logger(session_id=self.session_path)
+        logging.info("")
+        logging.info(colorize("[Tree Model]", color="bright_blue", bold=True))
+        logging.info(colorize("-" * 80, color="bright_blue"))
+        logging.info(format_kv("Model", self.__class__.__name__))
+        logging.info(format_kv("Session ID", self.session_id))
+        logging.info(format_kv("Device", self.device))
+
+        target_names = self.target_columns or ["label"]
+        metrics_list, task_specific_metrics, _ = configure_metrics(
+            self.task, metrics, target_names
+        )
+        need_user_id = check_user_id(metrics_list, task_specific_metrics)
+        id_column = user_id_column or (self.id_columns[0] if self.id_columns else None)
+        include_ids = need_user_id and id_column is not None
+
+        X_train, y_train, train_ids = self.prepare_arrays(
+            train_data,
+            require_labels=True,
+            include_ids=include_ids,
+            id_column=id_column,
+        )
+        X_valid = y_valid = valid_ids = None
+        if valid_data is not None:
+            X_valid, y_valid, valid_ids = self.prepare_arrays(
+                valid_data,
+                require_labels=True,
+                include_ids=include_ids,
+                id_column=id_column,
+            )
+
+        logging.info("")
+        logging.info(colorize("[Features]", color="bright_blue", bold=True))
+        logging.info(colorize("-" * 80, color="bright_blue"))
+        logging.info(format_kv("Dense features", len(self.dense_features)))
+        logging.info(format_kv("Sparse features", len(self.sparse_features)))
+        logging.info(format_kv("Sequence features", len(self.sequence_features)))
+        logging.info(format_kv("Targets", len(target_names)))
+        logging.info(format_kv("Train rows", X_train.shape[0]))
+        if X_valid is not None:
+            logging.info(format_kv("Valid rows", X_valid.shape[0]))
+
+        model = self.build_estimator(dict(self.model_params), epochs)
+        self.model = self.fit_estimator(
+            model,
+            X_train,
+            y_train,
+            X_valid,
+            y_valid,
+            self._cat_feature_indices,
+            **kwargs,
+        )
+
+        if metrics_list and y_valid is not None and X_valid is not None:
+            y_pred = self.predict_scores(self.model, X_valid)
+            metrics_dict = evaluate_metrics(
+                y_valid,
+                y_pred,
+                metrics_list,
+                self.task,
+                target_names,
+                task_specific_metrics=task_specific_metrics,
+                user_ids=valid_ids,
+                ignore_label=ignore_label,
+            )
+            display_metrics_table(
+                epoch=1,
+                epochs=1,
+                split="valid",
+                loss=None,
+                metrics=metrics_dict,
+                target_names=target_names,
+                base_metrics=metrics_list,
+            )
+
+        self.save_model()
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[True] = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> np.ndarray: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        *,
+        save_path: str | os.PathLike,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[True] = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        *,
+        save_path: str | os.PathLike,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> Path: ...
+
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame | np.ndarray | Path | None:
+        del batch_size, num_workers  # not used for tree models
+
+        if self.model is None:
+            raise ValueError(f"[{self.model_name}-predict Error] Model is not loaded.")
+
+        predict_id_columns = to_list(id_columns) or self.id_columns
+        if include_ids is None:
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)
+
+        if save_path is not None and not return_dataframe:
+            return self.predict_streaming(
+                data=data,
+                save_path=save_path,
+                save_format=save_format,
+                include_ids=include_ids,
+                stream_chunk_size=stream_chunk_size,
+                id_columns=predict_id_columns,
+            )
+
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=None,
+                id_columns=predict_id_columns,
+            )
+            data = rec_loader.create_dataloader(
+                data=data,
+                batch_size=stream_chunk_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=stream_chunk_size,
+            )
+
+        X_all, _, ids_all = self.prepare_arrays(
+            data,
+            require_labels=False,
+            include_ids=include_ids,
+            id_column=predict_id_columns[0] if predict_id_columns else None,
+        )
+        y_pred = self.predict_scores(self.model, X_all)
+        y_pred = y_pred.reshape(-1, 1)
+
+        pred_columns = self.target_columns or ["pred"]
+        pred_df = pd.DataFrame(y_pred, columns=pred_columns[:1])
+        if include_ids and ids_all is not None:
+            id_df = pd.DataFrame({predict_id_columns[0]: ids_all})
+            output = pd.concat([id_df, pred_df], axis=1)
+        else:
+            output = pred_df if return_dataframe else y_pred
+
+        if save_path is not None:
+            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+            target_path = get_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=suffix,
+                add_timestamp=True if save_path is None else False,
+            )
+            if isinstance(output, pd.DataFrame):
+                df_to_save = output
+            else:
+                df_to_save = pd.DataFrame(y_pred, columns=pred_columns[:1])
+                if include_ids and ids_all is not None and predict_id_columns:
+                    id_df = pd.DataFrame({predict_id_columns[0]: ids_all})
+                    df_to_save = pd.concat([id_df, df_to_save], axis=1)
+            if save_format == "csv":
+                df_to_save.to_csv(target_path, index=False)
+            elif save_format == "parquet":
+                df_to_save.to_parquet(target_path, index=False)
+            elif save_format == "feather":
+                df_to_save.to_feather(target_path)
+            elif save_format == "excel":
+                df_to_save.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                df_to_save.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+            logging.info(f"Predictions saved to: {target_path}")
+        return output
+
+    def predict_streaming(
+        self,
+        data: Any,
+        save_path: str | os.PathLike,
+        save_format: str,
+        include_ids: bool,
+        stream_chunk_size: int,
+        id_columns: list[str] | None,
+    ) -> Path:
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=None,
+                id_columns=id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=stream_chunk_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=stream_chunk_size,
+            )
+        else:
+            data_loader = data
+
+        if not check_streaming_support(save_format):
+            logging.warning(
+                f"[{self.model_name}-predict Warning] Format '{save_format}' does not support streaming writes."
+            )
+
+        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session.predictions_dir,
+            default_name="predictions",
+            suffix=suffix,
+            add_timestamp=True if save_path is None else False,
+        )
+
+        header_written = False
+        parquet_writer = None
+        collected_frames: list[pd.DataFrame] = []
+        id_column = id_columns[0] if id_columns else None
+        for batch in data_loader:
+            X_batch = self.features_to_matrix(batch.get("features", {}))
+            y_pred = self.predict_scores(self.model, X_batch).reshape(-1, 1)
+            pred_df = pd.DataFrame(y_pred, columns=self.target_columns or ["pred"])
+            if include_ids and id_column:
+                ids = self.extract_ids(batch.get("ids"), id_column)
+                if ids is not None:
+                    pred_df.insert(0, id_column, ids)
+            if save_format == "csv":
+                pred_df.to_csv(
+                    target_path, mode="a", header=not header_written, index=False
+                )
+            elif save_format == "parquet":
+                try:
+                    import pyarrow as pa
+                    import pyarrow.parquet as pq
+                except ImportError as exc:  # pragma: no cover
+                    raise ImportError(
+                        f"[{self.model_name}-predict Error] Parquet streaming save requires pyarrow."
+                    ) from exc
+                table = pa.Table.from_pandas(pred_df, preserve_index=False)
+                if parquet_writer is None:
+                    parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                parquet_writer.write_table(table)
+            else:
+                collected_frames.append(pred_df)
+            header_written = True
+        if parquet_writer is not None:
+            parquet_writer.close()
+        if collected_frames:
+            combined_df = pd.concat(collected_frames, ignore_index=True)
+            if save_format == "feather":
+                combined_df.to_feather(target_path)
+            elif save_format == "excel":
+                combined_df.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                combined_df.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+        return target_path
+
+    def save_model(self, save_path: str | os.PathLike | None = None) -> Path:
+        if self.model is None:
+            raise ValueError(f"[{self.model_name}-save Error] Model is not fitted.")
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session_path,
+            default_name=self.model_name.upper(),
+            suffix=self.model_file_suffix,
+            add_timestamp=True if save_path is None else False,
+        )
+        self.save_model_file(self.model, target_path)
+        with open(self.features_config_path, "wb") as handle:
+            pickle.dump(
+                {
+                    "all_features": self.all_features,
+                    "target": self.target_columns,
+                    "id_columns": self.id_columns,
+                },
+                handle,
+            )
+        return target_path
+
+    def save_model_file(self, model: Any, path: Path) -> None:
+        raise NotImplementedError
+
+    def load_model(
+        self,
+        save_path: str | os.PathLike,
+        map_location: str | None = None,
+        verbose: bool = True,
+    ) -> None:
+        del map_location
+        model_path = Path(save_path)
+        if model_path.is_dir():
+            candidates = sorted(model_path.glob(f"*.{self.model_file_suffix}"))
+            if not candidates:
+                raise FileNotFoundError(
+                    f"[{self.model_name}-load Error] No model file found in {model_path}"
+                )
+            model_path = candidates[-1]
+        if not model_path.exists():
+            raise FileNotFoundError(
+                f"[{self.model_name}-load Error] Model file does not exist: {model_path}"
+            )
+        self.model = self.load_model_file(model_path)
+        config_path = model_path.parent / "features_config.pkl"
+        if config_path.exists():
+            with open(config_path, "rb") as handle:
+                cfg = pickle.load(handle)
+            all_features = cfg.get("all_features", [])
+            dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+            sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+            sequence_features = [
+                f for f in all_features if isinstance(f, SequenceFeature)
+            ]
+            self.set_all_features(
+                dense_features=dense_features,
+                sparse_features=sparse_features,
+                sequence_features=sequence_features,
+                target=cfg.get("target"),
+                id_columns=cfg.get("id_columns"),
+            )
+        if verbose:
+            logging.info(f"Model loaded from: {model_path}")
+
+    def load_model_file(self, path: Path) -> Any:
+        raise NotImplementedError
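
For context on how this base class is meant to be used: `TreeBaseModel` is an abstract adapter, and a concrete backend only has to supply `build_estimator`, `fit_estimator`, `save_model_file`, and `load_model_file`; the data collection, metric evaluation, and prediction plumbing are inherited. The sketch below shows the rough shape of such a subclass against xgboost's scikit-learn API. It is a minimal illustration only: the class name `XGBoostModel`, the `json` suffix, and the `n_estimators` mapping are assumptions, not the contents of the wheel's actual `nextrec/models/tree_base/xgboost.py` (which this diff summary lists but does not display).

```python
# Minimal sketch of a TreeBaseModel backend (assumed shape, not the shipped
# nextrec/models/tree_base/xgboost.py).
from pathlib import Path
from typing import Any

import numpy as np
import xgboost as xgb

from nextrec.models.tree_base.base import TreeBaseModel


class XGBoostModel(TreeBaseModel):  # hypothetical class name
    model_file_suffix = "json"  # assume xgboost's native JSON format

    def build_estimator(self, model_params: dict[str, Any], epochs: int | None):
        # Map the pipeline's `epochs` onto boosting rounds unless overridden.
        model_params.setdefault("n_estimators", epochs or 100)
        return xgb.XGBClassifier(**model_params)

    def fit_estimator(
        self,
        model: Any,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_valid: np.ndarray | None,
        y_valid: np.ndarray | None,
        cat_features: list[int],
        **kwargs: Any,
    ) -> Any:
        # cat_features is ignored here: this sketch feeds xgboost a purely
        # numeric matrix, as produced by features_to_matrix above.
        eval_set = [(X_valid, y_valid)] if X_valid is not None else None
        model.fit(X_train, y_train, eval_set=eval_set, **kwargs)
        return model

    def save_model_file(self, model: Any, path: Path) -> None:
        model.save_model(str(path))

    def load_model_file(self, path: Path) -> Any:
        model = xgb.XGBClassifier()
        model.load_model(str(path))
        return model
```

With the hooks in place, training and inference go through the inherited pipeline, e.g. `model.fit(train_df, valid_df, metrics=["auc"])` followed by `model.predict(test_df)`; `predict_scores` already handles `predict_proba` for binary classifiers.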