nextrec 0.4.30__py3-none-any.whl → 0.4.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/model.py +48 -4
  3. nextrec/cli.py +18 -10
  4. nextrec/data/batch_utils.py +2 -2
  5. nextrec/data/preprocessor.py +53 -1
  6. nextrec/models/multi_task/[pre]aitm.py +3 -3
  7. nextrec/models/multi_task/[pre]snr_trans.py +3 -3
  8. nextrec/models/multi_task/[pre]star.py +3 -3
  9. nextrec/models/multi_task/apg.py +3 -3
  10. nextrec/models/multi_task/cross_stitch.py +3 -3
  11. nextrec/models/multi_task/escm.py +3 -3
  12. nextrec/models/multi_task/esmm.py +3 -3
  13. nextrec/models/multi_task/hmoe.py +3 -3
  14. nextrec/models/multi_task/mmoe.py +3 -3
  15. nextrec/models/multi_task/pepnet.py +4 -4
  16. nextrec/models/multi_task/ple.py +3 -3
  17. nextrec/models/multi_task/poso.py +3 -3
  18. nextrec/models/multi_task/share_bottom.py +3 -3
  19. nextrec/models/ranking/afm.py +3 -2
  20. nextrec/models/ranking/autoint.py +3 -2
  21. nextrec/models/ranking/dcn.py +3 -2
  22. nextrec/models/ranking/dcn_v2.py +3 -2
  23. nextrec/models/ranking/deepfm.py +3 -2
  24. nextrec/models/ranking/dien.py +3 -2
  25. nextrec/models/ranking/din.py +3 -2
  26. nextrec/models/ranking/eulernet.py +3 -2
  27. nextrec/models/ranking/ffm.py +3 -2
  28. nextrec/models/ranking/fibinet.py +3 -2
  29. nextrec/models/ranking/fm.py +3 -2
  30. nextrec/models/ranking/lr.py +3 -2
  31. nextrec/models/ranking/masknet.py +3 -2
  32. nextrec/models/ranking/pnn.py +3 -2
  33. nextrec/models/ranking/widedeep.py +3 -2
  34. nextrec/models/ranking/xdeepfm.py +3 -2
  35. nextrec/models/tree_base/__init__.py +15 -0
  36. nextrec/models/tree_base/base.py +693 -0
  37. nextrec/models/tree_base/catboost.py +97 -0
  38. nextrec/models/tree_base/lightgbm.py +69 -0
  39. nextrec/models/tree_base/xgboost.py +61 -0
  40. nextrec/utils/config.py +1 -0
  41. nextrec/utils/types.py +2 -0
  42. {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/METADATA +5 -5
  43. {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/RECORD +46 -41
  44. {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/licenses/LICENSE +1 -1
  45. {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/WHEEL +0 -0
  46. {nextrec-0.4.30.dist-info → nextrec-0.4.32.dist-info}/entry_points.txt +0 -0
nextrec/models/tree_base/base.py (new file)
@@ -0,0 +1,693 @@
+"""
+Tree-based model base for NextRec.
+
+This module provides a lightweight adapter to plug tree models (xgboost/lightgbm/catboost)
+into the NextRec training/prediction pipeline.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import pickle
+from pathlib import Path
+from typing import Any, Iterable, Literal, overload
+
+import numpy as np
+import pandas as pd
+
+from nextrec.basic.features import (
+    DenseFeature,
+    FeatureSet,
+    SequenceFeature,
+    SparseFeature,
+)
+from nextrec.basic.loggers import colorize, format_kv, setup_logger
+from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
+from nextrec.basic.session import create_session, get_save_path
+from nextrec.data.dataloader import RecDataLoader
+from nextrec.data.data_processing import get_column_data
+from nextrec.utils.console import display_metrics_table
+from nextrec.utils.data import FILE_FORMAT_CONFIG, check_streaming_support
+from nextrec.utils.feature import to_list
+from nextrec.utils.torch_utils import to_numpy
+
+
+class TreeBaseModel(FeatureSet):
+    model_file_suffix = "bin"
+
+    @property
+    def model_name(self) -> str:
+        return self.__class__.__name__.lower()
+
+    @property
+    def default_task(self) -> str:
+        return "binary"
+
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        target: list[str] | str | None = None,
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        device: str = "cpu",
+        session_id: str | None = None,
+        model_params: dict[str, Any] | None = None,
+        sequence_pooling: str = "mean",
+        **kwargs: Any,
+    ):
+        self.device = device
+        self.model_params = dict(model_params or {})  # tree model parameters
+        if kwargs:
+            self.model_params.update(kwargs)
+        self.sequence_pooling = sequence_pooling
+
+        self.set_all_features(
+            dense_features, sparse_features, sequence_features, target, id_columns
+        )
+        self.task = task or self.default_task
+
+        self.session_id = session_id
+        self.session = create_session(session_id)
+        self.session_path = self.session.root
+        self.checkpoint_path = os.path.join(
+            self.session_path,
+            f"{self.model_name.upper()}_checkpoint.{self.model_file_suffix}",
+        )
+        self.best_path = os.path.join(
+            self.session_path,
+            f"{self.model_name.upper()}_best.{self.model_file_suffix}",
+        )
+        self.features_config_path = os.path.join(
+            self.session_path, "features_config.pkl"
+        )
+
+        self.model: Any | None = None
+        self._cat_feature_indices: list[int] = []
+
+    def assert_task(self) -> None:
+        if self.target_columns and len(self.target_columns) > 1:
+            raise ValueError(
+                f"[{self.model_name}-init Error] tree models only support a single target column."
+            )
+        if isinstance(self.task, list) and len(self.task) > 1:
+            raise ValueError(
+                f"[{self.model_name}-init Error] tree models only support a single task type."
+            )
+
+    def pool_sequence(self, arr: np.ndarray, feature: SequenceFeature) -> np.ndarray:
+        if arr.ndim == 1:
+            return arr.reshape(-1, 1)
+        padding_value = feature.padding_idx
+        mask = arr != padding_value
+        if self.sequence_pooling == "sum":
+            pooled = (arr * mask).sum(axis=1)
+        elif self.sequence_pooling == "max":
+            masked = np.where(mask, arr, -np.inf)
+            pooled = np.max(masked, axis=1)
+            pooled = np.where(np.isfinite(pooled), pooled, 0.0)
+        elif self.sequence_pooling == "last":
+            idx = np.where(mask, np.arange(arr.shape[1]), -1)
+            last_idx = idx.max(axis=1)
+            pooled = np.array(
+                [arr[row, col] if col >= 0 else 0.0 for row, col in enumerate(last_idx)]
+            )
+        else:
+            counts = np.maximum(mask.sum(axis=1), 1)
+            pooled = (arr * mask).sum(axis=1) / counts
+        return pooled.reshape(-1, 1).astype(np.float32)
+
+    def features_to_matrix(self, features: dict[str, Any]) -> np.ndarray:
+        columns: list[np.ndarray] = []
+        cat_indices: list[int] = []
+        feature_offset = 0
+        for feat in self.all_features:
+            if feat.name not in features:
+                raise KeyError(
+                    f"[{self.model_name}-data Error] Missing feature '{feat.name}'."
+                )
+            arr = to_numpy(features[feat.name])
+            if isinstance(feat, SequenceFeature):
+                arr = self.pool_sequence(arr, feat)
+            if arr.ndim == 1:
+                arr = arr.reshape(-1, 1)
+            if isinstance(feat, SparseFeature):
+                for col_idx in range(arr.shape[1]):
+                    cat_indices.append(feature_offset + col_idx)
+            feature_offset += arr.shape[1]
+            columns.append(arr.astype(np.float32))
+        if columns:
+            self._cat_feature_indices = cat_indices
+            return np.concatenate(columns, axis=1)
+        return np.empty((0, 0), dtype=np.float32)
+
+    def extract_labels(self, labels: dict[str, Any] | None) -> np.ndarray | None:
+        if labels is None:
+            return None
+        if self.target_columns:
+            target = self.target_columns[0]
+            if target not in labels:
+                return None
+            return to_numpy(labels[target]).reshape(-1)
+        first_key = next(iter(labels.keys()), None)
+        if first_key is None:
+            return None
+        return to_numpy(labels[first_key]).reshape(-1)
+
+    def extract_ids(
+        self, ids: dict[str, Any] | None, id_column: str | None
+    ) -> np.ndarray | None:
+        if ids is None or id_column is None:
+            return None
+        if id_column not in ids:
+            return None
+        return np.asarray(ids[id_column]).reshape(-1)
+
+    def collect_from_dataloader(
+        self,
+        data_loader: Iterable,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        feature_chunks: list[np.ndarray] = []
+        label_chunks: list[np.ndarray] = []
+        id_chunks: list[np.ndarray] = []
+        for batch in data_loader:
+            if not isinstance(batch, dict) or "features" not in batch:
+                raise TypeError(
+                    f"[{self.model_name}-data Error] Expected batches with a 'features' dict."
+                )
+            features = batch.get("features", {})
+            labels = batch.get("labels")
+            ids = batch.get("ids")
+            X_batch = self.features_to_matrix(features)
+            feature_chunks.append(X_batch)
+            y_batch = self.extract_labels(labels)
+            if require_labels and y_batch is None:
+                raise ValueError(
+                    f"[{self.model_name}-data Error] Labels are required but missing."
+                )
+            if y_batch is not None:
+                label_chunks.append(y_batch)
+            if include_ids and id_column:
+                id_batch = self.extract_ids(ids, id_column)
+                if id_batch is not None:
+                    id_chunks.append(id_batch)
+        X_all = (
+            np.concatenate(feature_chunks, axis=0)
+            if feature_chunks
+            else np.empty((0, 0))
+        )
+        y_all = np.concatenate(label_chunks, axis=0) if label_chunks else None
+        ids_all = np.concatenate(id_chunks, axis=0) if id_chunks else None
+        return X_all, y_all, ids_all
+
+    def collect_from_table(
+        self,
+        data: dict | pd.DataFrame,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        features: dict[str, Any] = {}
+        for feat in self.all_features:
+            column = get_column_data(data, feat.name)
+            if column is None:
+                raise KeyError(
+                    f"[{self.model_name}-data Error] Missing feature '{feat.name}'."
+                )
+            features[feat.name] = column
+        X_all = self.features_to_matrix(features)
+        y_all = None
+        if require_labels:
+            label_payload: dict[str, Any] = {}
+            for name in self.target_columns:
+                column = get_column_data(data, name)
+                if column is not None:
+                    label_payload[name] = column
+            y_all = self.extract_labels(label_payload or None)
+            if y_all is None:
+                raise ValueError(
+                    f"[{self.model_name}-data Error] Labels are required but missing."
+                )
+        ids_all = None
+        if include_ids and id_column:
+            id_col = get_column_data(data, id_column)
+            if id_col is not None:
+                ids_all = np.asarray(id_col).reshape(-1)
+        return X_all, y_all, ids_all
+
+    def prepare_arrays(
+        self,
+        data: Any,
+        require_labels: bool,
+        include_ids: bool,
+        id_column: str | None,
+    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
+        if isinstance(data, (str, os.PathLike)):
+            raise TypeError(
+                f"[{self.model_name}-data Error] File paths are not supported here. "
+                "Use RecDataLoader to create a DataLoader for training."
+            )
+        if isinstance(data, (pd.DataFrame, dict)):
+            return self.collect_from_table(data, require_labels, include_ids, id_column)
+        if isinstance(data, Iterable) and hasattr(data, "__iter__"):
+            return self.collect_from_dataloader(
+                data, require_labels, include_ids, id_column
+            )
+        raise TypeError(
+            f"[{self.model_name}-data Error] Unsupported data type: {type(data)}"
+        )
+
+    def build_estimator(self, model_params: dict[str, Any], epochs: int | None):
+        raise NotImplementedError
+
+    def fit_estimator(
+        self,
+        model: Any,
+        X_train: np.ndarray,
+        y_train: np.ndarray,
+        X_valid: np.ndarray | None,
+        y_valid: np.ndarray | None,
+        cat_features: list[int],
+        **kwargs: Any,
+    ) -> Any:
+        raise NotImplementedError
+
+    def predict_scores(self, model: Any, X: np.ndarray) -> np.ndarray:
+        if hasattr(model, "predict_proba"):
+            proba = model.predict_proba(X)
+            if isinstance(proba, list):
+                proba = np.asarray(proba)
+            if proba.ndim == 2 and proba.shape[1] > 1:
+                return proba[:, 1]
+        return np.asarray(model.predict(X)).reshape(-1)
+
+    def compile(self, optimizer=None, loss=None, loss_params=None, **kwargs) -> None:
+        del optimizer, loss, loss_params, kwargs  # not used for tree models
+
+    def fit(
+        self,
+        train_data: Any,
+        valid_data: Any | None = None,
+        metrics: list[str] | dict[str, list[str]] | None = None,
+        epochs: int = 1,
+        batch_size: int = 512,
+        shuffle: bool = True,
+        num_workers: int = 0,
+        user_id_column: str | None = None,
+        ignore_label: int | float | None = None,
+        **kwargs: Any,
+    ) -> None:
+        del batch_size, shuffle, num_workers  # not used for tree models
+        self.assert_task()
+        if train_data is None:
+            raise ValueError(f"[{self.model_name}-fit Error] train_data is required.")
+
+        setup_logger(session_id=self.session_path)
+        logging.info("")
+        logging.info(colorize("[Tree Model]", color="bright_blue", bold=True))
+        logging.info(colorize("-" * 80, color="bright_blue"))
+        logging.info(format_kv("Model", self.__class__.__name__))
+        logging.info(format_kv("Session ID", self.session_id))
+        logging.info(format_kv("Device", self.device))
+
+        target_names = self.target_columns or ["label"]
+        metrics_list, task_specific_metrics, _ = configure_metrics(
+            self.task, metrics, target_names
+        )
+        need_user_id = check_user_id(metrics_list, task_specific_metrics)
+        id_column = user_id_column or (self.id_columns[0] if self.id_columns else None)
+        include_ids = need_user_id and id_column is not None
+
+        X_train, y_train, train_ids = self.prepare_arrays(
+            train_data,
+            require_labels=True,
+            include_ids=include_ids,
+            id_column=id_column,
+        )
+        X_valid = y_valid = valid_ids = None
+        if valid_data is not None:
+            X_valid, y_valid, valid_ids = self.prepare_arrays(
+                valid_data,
+                require_labels=True,
+                include_ids=include_ids,
+                id_column=id_column,
+            )
+
+        logging.info("")
+        logging.info(colorize("[Features]", color="bright_blue", bold=True))
+        logging.info(colorize("-" * 80, color="bright_blue"))
+        logging.info(format_kv("Dense features", len(self.dense_features)))
+        logging.info(format_kv("Sparse features", len(self.sparse_features)))
+        logging.info(format_kv("Sequence features", len(self.sequence_features)))
+        logging.info(format_kv("Targets", len(target_names)))
+        logging.info(format_kv("Train rows", X_train.shape[0]))
+        if X_valid is not None:
+            logging.info(format_kv("Valid rows", X_valid.shape[0]))
+
+        model = self.build_estimator(dict(self.model_params), epochs)
+        self.model = self.fit_estimator(
+            model,
+            X_train,
+            y_train,
+            X_valid,
+            y_valid,
+            self._cat_feature_indices,
+            **kwargs,
+        )
+
+        if metrics_list and y_valid is not None and X_valid is not None:
+            y_pred = self.predict_scores(self.model, X_valid)
+            metrics_dict = evaluate_metrics(
+                y_valid,
+                y_pred,
+                metrics_list,
+                self.task,
+                target_names,
+                task_specific_metrics=task_specific_metrics,
+                user_ids=valid_ids,
+                ignore_label=ignore_label,
+            )
+            display_metrics_table(
+                epoch=1,
+                epochs=1,
+                split="valid",
+                loss=None,
+                metrics=metrics_dict,
+                target_names=target_names,
+                base_metrics=metrics_list,
+            )
+
+        self.save_model()
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[True] = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> np.ndarray: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        *,
+        save_path: str | os.PathLike,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[True] = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        *,
+        save_path: str | os.PathLike,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: Literal[False] = False,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> Path: ...
+
+    def predict(
+        self,
+        data: Any,
+        batch_size: int = 512,
+        save_path: str | os.PathLike | None = None,
+        save_format: str = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        stream_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame | np.ndarray | Path | None:
+        del batch_size, num_workers  # not used for tree models
+
+        if self.model is None:
+            raise ValueError(f"[{self.model_name}-predict Error] Model is not loaded.")
+
+        predict_id_columns = to_list(id_columns) or self.id_columns
+        if include_ids is None:
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)
+
+        if save_path is not None and not return_dataframe:
+            return self.predict_streaming(
+                data=data,
+                save_path=save_path,
+                save_format=save_format,
+                include_ids=include_ids,
+                stream_chunk_size=stream_chunk_size,
+                id_columns=predict_id_columns,
+            )
+
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=None,
+                id_columns=predict_id_columns,
+            )
+            data = rec_loader.create_dataloader(
+                data=data,
+                batch_size=stream_chunk_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=stream_chunk_size,
+            )
+
+        X_all, _, ids_all = self.prepare_arrays(
+            data,
+            require_labels=False,
+            include_ids=include_ids,
+            id_column=predict_id_columns[0] if predict_id_columns else None,
+        )
+        y_pred = self.predict_scores(self.model, X_all)
+        y_pred = y_pred.reshape(-1, 1)
+
+        pred_columns = self.target_columns or ["pred"]
+        pred_df = pd.DataFrame(y_pred, columns=pred_columns[:1])
+        if include_ids and ids_all is not None:
+            id_df = pd.DataFrame({predict_id_columns[0]: ids_all})
+            output = pd.concat([id_df, pred_df], axis=1)
+        else:
+            output = pred_df if return_dataframe else y_pred
+
+        if save_path is not None:
+            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+            target_path = get_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=suffix,
+                add_timestamp=True if save_path is None else False,
+            )
+            if isinstance(output, pd.DataFrame):
+                df_to_save = output
+            else:
+                df_to_save = pd.DataFrame(y_pred, columns=pred_columns[:1])
+                if include_ids and ids_all is not None and predict_id_columns:
+                    id_df = pd.DataFrame({predict_id_columns[0]: ids_all})
+                    df_to_save = pd.concat([id_df, df_to_save], axis=1)
+            if save_format == "csv":
+                df_to_save.to_csv(target_path, index=False)
+            elif save_format == "parquet":
+                df_to_save.to_parquet(target_path, index=False)
+            elif save_format == "feather":
+                df_to_save.to_feather(target_path)
+            elif save_format == "excel":
+                df_to_save.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                df_to_save.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+            logging.info(f"Predictions saved to: {target_path}")
+        return output
+
+    def predict_streaming(
+        self,
+        data: Any,
+        save_path: str | os.PathLike,
+        save_format: str,
+        include_ids: bool,
+        stream_chunk_size: int,
+        id_columns: list[str] | None,
+    ) -> Path:
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=None,
+                id_columns=id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=stream_chunk_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=stream_chunk_size,
+            )
+        else:
+            data_loader = data
+
+        if not check_streaming_support(save_format):
+            logging.warning(
+                f"[{self.model_name}-predict Warning] Format '{save_format}' does not support streaming writes."
+            )
+
+        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session.predictions_dir,
+            default_name="predictions",
+            suffix=suffix,
+            add_timestamp=True if save_path is None else False,
+        )
+
+        header_written = False
+        parquet_writer = None
+        collected_frames: list[pd.DataFrame] = []
+        id_column = id_columns[0] if id_columns else None
+        for batch in data_loader:
+            X_batch = self.features_to_matrix(batch.get("features", {}))
+            y_pred = self.predict_scores(self.model, X_batch).reshape(-1, 1)
+            pred_df = pd.DataFrame(y_pred, columns=self.target_columns or ["pred"])
+            if include_ids and id_column:
+                ids = self.extract_ids(batch.get("ids"), id_column)
+                if ids is not None:
+                    pred_df.insert(0, id_column, ids)
+            if save_format == "csv":
+                pred_df.to_csv(
+                    target_path, mode="a", header=not header_written, index=False
+                )
+            elif save_format == "parquet":
+                try:
+                    import pyarrow as pa
+                    import pyarrow.parquet as pq
+                except ImportError as exc:  # pragma: no cover
+                    raise ImportError(
+                        f"[{self.model_name}-predict Error] Parquet streaming save requires pyarrow."
+                    ) from exc
+                table = pa.Table.from_pandas(pred_df, preserve_index=False)
+                if parquet_writer is None:
+                    parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                parquet_writer.write_table(table)
+            else:
+                collected_frames.append(pred_df)
+            header_written = True
+        if parquet_writer is not None:
+            parquet_writer.close()
+        if collected_frames:
+            combined_df = pd.concat(collected_frames, ignore_index=True)
+            if save_format == "feather":
+                combined_df.to_feather(target_path)
+            elif save_format == "excel":
+                combined_df.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                combined_df.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+        return target_path
+
+    def save_model(self, save_path: str | os.PathLike | None = None) -> Path:
+        if self.model is None:
+            raise ValueError(f"[{self.model_name}-save Error] Model is not fitted.")
+        target_path = get_save_path(
+            path=save_path,
+            default_dir=self.session_path,
+            default_name=self.model_name.upper(),
+            suffix=self.model_file_suffix,
+            add_timestamp=True if save_path is None else False,
+        )
+        self.save_model_file(self.model, target_path)
+        with open(self.features_config_path, "wb") as handle:
+            pickle.dump(
+                {
+                    "all_features": self.all_features,
+                    "target": self.target_columns,
+                    "id_columns": self.id_columns,
+                },
+                handle,
+            )
+        return target_path
+
+    def save_model_file(self, model: Any, path: Path) -> None:
+        raise NotImplementedError
+
+    def load_model(
+        self,
+        save_path: str | os.PathLike,
+        map_location: str | None = None,
+        verbose: bool = True,
+    ) -> None:
+        del map_location
+        model_path = Path(save_path)
+        if model_path.is_dir():
+            candidates = sorted(model_path.glob(f"*.{self.model_file_suffix}"))
+            if not candidates:
+                raise FileNotFoundError(
+                    f"[{self.model_name}-load Error] No model file found in {model_path}"
+                )
+            model_path = candidates[-1]
+        if not model_path.exists():
+            raise FileNotFoundError(
+                f"[{self.model_name}-load Error] Model file does not exist: {model_path}"
+            )
+        self.model = self.load_model_file(model_path)
+        config_path = model_path.parent / "features_config.pkl"
+        if config_path.exists():
+            with open(config_path, "rb") as handle:
+                cfg = pickle.load(handle)
+            all_features = cfg.get("all_features", [])
+            dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+            sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+            sequence_features = [
+                f for f in all_features if isinstance(f, SequenceFeature)
+            ]
+            self.set_all_features(
+                dense_features=dense_features,
+                sparse_features=sparse_features,
+                sequence_features=sequence_features,
+                target=cfg.get("target"),
+                id_columns=cfg.get("id_columns"),
+            )
+        if verbose:
+            logging.info(f"Model loaded from: {model_path}")
+
+    def load_model_file(self, path: Path) -> Any:
+        raise NotImplementedError
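
The new base class leaves `build_estimator`, `fit_estimator`, `save_model_file`, and `load_model_file` to subclasses. The sketch below is a minimal, hypothetical adapter showing how those hooks could be filled in with LightGBM; it is not the `lightgbm.py`, `xgboost.py`, or `catboost.py` adapter shipped in 0.4.32 (their contents are not shown in this diff), and the `SketchGBMModel` name and wiring are assumptions for illustration only.

```python
# Hypothetical sketch only: one way a subclass could implement the TreeBaseModel hooks.
from pathlib import Path
from typing import Any

import numpy as np

from nextrec.models.tree_base.base import TreeBaseModel


class SketchGBMModel(TreeBaseModel):
    """Illustrative LightGBM adapter; not the adapter shipped in the wheel."""

    def build_estimator(self, model_params: dict[str, Any], epochs: int | None):
        # Assumes lightgbm is installed; map "epochs" onto n_estimators if given.
        from lightgbm import LGBMClassifier

        if epochs is not None:
            model_params.setdefault("n_estimators", epochs)
        return LGBMClassifier(**model_params)

    def fit_estimator(
        self,
        model: Any,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_valid: np.ndarray | None,
        y_valid: np.ndarray | None,
        cat_features: list[int],
        **kwargs: Any,
    ) -> Any:
        # cat_features is ignored in this sketch; a real adapter could forward it.
        eval_set = (
            [(X_valid, y_valid)] if X_valid is not None and y_valid is not None else None
        )
        model.fit(X_train, y_train, eval_set=eval_set, **kwargs)
        return model

    def save_model_file(self, model: Any, path: Path) -> None:
        # LGBMClassifier exposes the underlying Booster after fitting.
        model.booster_.save_model(str(path))

    def load_model_file(self, path: Path) -> Any:
        import lightgbm as lgb

        return lgb.Booster(model_file=str(path))


# Possible usage, assuming feature definitions already exist:
# model = SketchGBMModel(dense_features=dense, sparse_features=sparse, target="label")
# model.fit(train_df, valid_data=valid_df, metrics=["auc"])
# preds = model.predict(valid_df)
```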