nextrec 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +312 -318
- nextrec/cli.py +5 -10
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +12 -13
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/ffm.py +0 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +0 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +8 -7
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
- nextrec-0.4.9.dist-info/RECORD +70 -0
- nextrec/utils/cli_utils.py +0 -58
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8.dist-info/RECORD +0 -71
- /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
|
@@ -2,52 +2,52 @@
|
|
|
2
2
|
Base Model & Base Match Model Class
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 19/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
+
import getpass
|
|
10
|
+
import logging
|
|
9
11
|
import os
|
|
10
|
-
import tqdm
|
|
11
12
|
import pickle
|
|
12
|
-
import logging
|
|
13
|
-
import getpass
|
|
14
13
|
import socket
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Literal, Union
|
|
16
|
+
|
|
15
17
|
import numpy as np
|
|
16
18
|
import pandas as pd
|
|
17
19
|
import torch
|
|
20
|
+
import torch.distributed as dist
|
|
18
21
|
import torch.nn as nn
|
|
19
22
|
import torch.nn.functional as F
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from pathlib import Path
|
|
23
|
-
from typing import Union, Literal, Any
|
|
23
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
24
24
|
from torch.utils.data import DataLoader
|
|
25
25
|
from torch.utils.data.distributed import DistributedSampler
|
|
26
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
27
26
|
|
|
27
|
+
from nextrec import __version__
|
|
28
28
|
from nextrec.basic.callback import (
|
|
29
|
-
EarlyStopper,
|
|
30
|
-
CallbackList,
|
|
31
29
|
Callback,
|
|
30
|
+
CallbackList,
|
|
32
31
|
CheckpointSaver,
|
|
32
|
+
EarlyStopper,
|
|
33
33
|
LearningRateScheduler,
|
|
34
34
|
)
|
|
35
35
|
from nextrec.basic.features import (
|
|
36
36
|
DenseFeature,
|
|
37
|
-
SparseFeature,
|
|
38
|
-
SequenceFeature,
|
|
39
37
|
FeatureSet,
|
|
38
|
+
SequenceFeature,
|
|
39
|
+
SparseFeature,
|
|
40
40
|
)
|
|
41
|
-
from nextrec.
|
|
42
|
-
|
|
43
|
-
from nextrec.basic.
|
|
44
|
-
from nextrec.
|
|
45
|
-
from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
|
|
46
|
-
|
|
47
|
-
from nextrec.data.dataloader import build_tensors_from_data
|
|
48
|
-
from nextrec.data.batch_utils import collate_fn, batch_to_dict
|
|
41
|
+
from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
|
|
42
|
+
from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
|
|
43
|
+
from nextrec.basic.session import create_session, resolve_save_path
|
|
44
|
+
from nextrec.data.batch_utils import batch_to_dict, collate_fn
|
|
49
45
|
from nextrec.data.data_processing import get_column_data, get_user_ids
|
|
50
|
-
|
|
46
|
+
from nextrec.data.dataloader import (
|
|
47
|
+
RecDataLoader,
|
|
48
|
+
TensorDictDataset,
|
|
49
|
+
build_tensors_from_data,
|
|
50
|
+
)
|
|
51
51
|
from nextrec.loss import (
|
|
52
52
|
BPRLoss,
|
|
53
53
|
HingeLoss,
|
|
@@ -56,15 +56,16 @@ from nextrec.loss import (
|
|
|
56
56
|
TripletLoss,
|
|
57
57
|
get_loss_fn,
|
|
58
58
|
)
|
|
59
|
-
from nextrec.utils.
|
|
60
|
-
from nextrec.utils.
|
|
61
|
-
|
|
62
|
-
|
|
59
|
+
from nextrec.utils.console import display_metrics_table, progress
|
|
60
|
+
from nextrec.utils.torch_utils import (
|
|
61
|
+
add_distributed_sampler,
|
|
62
|
+
configure_device,
|
|
63
63
|
gather_numpy,
|
|
64
|
+
get_optimizer,
|
|
65
|
+
get_scheduler,
|
|
64
66
|
init_process_group,
|
|
65
|
-
|
|
67
|
+
to_tensor,
|
|
66
68
|
)
|
|
67
|
-
from nextrec import __version__
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
class BaseModel(FeatureSet, nn.Module):
|
|
@@ -90,6 +91,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
90
91
|
dense_l2_reg: float = 0.0,
|
|
91
92
|
device: str = "cpu",
|
|
92
93
|
early_stop_patience: int = 20,
|
|
94
|
+
max_metrics_samples: int | None = 200000,
|
|
93
95
|
session_id: str | None = None,
|
|
94
96
|
callbacks: list[Callback] | None = None,
|
|
95
97
|
distributed: bool = False,
|
|
@@ -116,6 +118,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
116
118
|
|
|
117
119
|
device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
|
|
118
120
|
early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
|
|
121
|
+
max_metrics_samples: Max samples to keep for training metrics. None disables limit.
|
|
119
122
|
session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
|
|
120
123
|
callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
|
|
121
124
|
|
|
@@ -145,7 +148,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
145
148
|
self.session_path = self.session.root # pwd/session_id, path for this session
|
|
146
149
|
self.checkpoint_path = os.path.join(
|
|
147
150
|
self.session_path, self.model_name + "_checkpoint.pt"
|
|
148
|
-
) #
|
|
151
|
+
) # e.g., pwd/session_id/DeepFM_checkpoint.pt
|
|
149
152
|
self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
|
|
150
153
|
self.features_config_path = os.path.join(
|
|
151
154
|
self.session_path, "features_config.pkl"
|
|
@@ -166,6 +169,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
166
169
|
self.loss_weight = None
|
|
167
170
|
|
|
168
171
|
self.early_stop_patience = early_stop_patience
|
|
172
|
+
self.max_metrics_samples = (
|
|
173
|
+
None if max_metrics_samples is None else int(max_metrics_samples)
|
|
174
|
+
)
|
|
169
175
|
self.max_gradient_norm = 1.0
|
|
170
176
|
self.logger_initialized = False
|
|
171
177
|
self.training_logger = None
|
|
@@ -181,17 +187,15 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
181
187
|
include_modules = include_modules or []
|
|
182
188
|
embedding_layer = getattr(self, embedding_attr, None)
|
|
183
189
|
embed_dict = getattr(embedding_layer, "embed_dict", None)
|
|
184
|
-
embedding_params: list[torch.Tensor] = []
|
|
185
190
|
if embed_dict is not None:
|
|
186
|
-
embedding_params
|
|
191
|
+
embedding_params = [
|
|
187
192
|
embed.weight
|
|
188
193
|
for embed in embed_dict.values()
|
|
189
194
|
if hasattr(embed, "weight")
|
|
190
|
-
|
|
195
|
+
]
|
|
191
196
|
else:
|
|
192
197
|
weight = getattr(embedding_layer, "weight", None)
|
|
193
|
-
if isinstance(weight, torch.Tensor)
|
|
194
|
-
embedding_params.append(weight)
|
|
198
|
+
embedding_params = [weight] if isinstance(weight, torch.Tensor) else []
|
|
195
199
|
|
|
196
200
|
existing_embedding_ids = {id(param) for param in self.embedding_params}
|
|
197
201
|
for param in embedding_params:
|
|
@@ -213,10 +217,12 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
213
217
|
module is self
|
|
214
218
|
or embedding_attr in name
|
|
215
219
|
or isinstance(module, skip_types)
|
|
216
|
-
or (include_modules and not any(inc in name for inc in include_modules))
|
|
217
|
-
or any(exc in name for exc in exclude_modules)
|
|
218
220
|
):
|
|
219
221
|
continue
|
|
222
|
+
if include_modules and not any(inc in name for inc in include_modules):
|
|
223
|
+
continue
|
|
224
|
+
if exclude_modules and any(exc in name for exc in exclude_modules):
|
|
225
|
+
continue
|
|
220
226
|
if isinstance(module, nn.Linear):
|
|
221
227
|
if id(module.weight) not in existing_reg_ids:
|
|
222
228
|
self.regularization_weights.append(module.weight)
|
|
@@ -318,22 +324,20 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
318
324
|
raise ValueError(
|
|
319
325
|
f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
|
|
320
326
|
)
|
|
321
|
-
if not isinstance(train_data, (pd.DataFrame, dict)):
|
|
322
|
-
raise TypeError(
|
|
323
|
-
f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
|
|
324
|
-
)
|
|
325
327
|
if isinstance(train_data, pd.DataFrame):
|
|
326
328
|
total_length = len(train_data)
|
|
327
|
-
|
|
328
|
-
sample_key = next(
|
|
329
|
-
|
|
330
|
-
) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
|
|
331
|
-
total_length = len(train_data[sample_key]) # len(train_data['user_id'])
|
|
329
|
+
elif isinstance(train_data, dict):
|
|
330
|
+
sample_key = next(iter(train_data))
|
|
331
|
+
total_length = len(train_data[sample_key])
|
|
332
332
|
for k, v in train_data.items():
|
|
333
333
|
if len(v) != total_length:
|
|
334
334
|
raise ValueError(
|
|
335
335
|
f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
|
|
336
336
|
)
|
|
337
|
+
else:
|
|
338
|
+
raise TypeError(
|
|
339
|
+
f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
|
|
340
|
+
)
|
|
337
341
|
rng = np.random.default_rng(42)
|
|
338
342
|
indices = rng.permutation(total_length)
|
|
339
343
|
split_idx = int(total_length * (1 - validation_split))
|
|
@@ -343,12 +347,12 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
343
347
|
train_split = train_data.iloc[train_indices].reset_index(drop=True)
|
|
344
348
|
valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
|
|
345
349
|
else:
|
|
346
|
-
train_split = {
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
350
|
+
train_split = {
|
|
351
|
+
k: np.asarray(v)[train_indices] for k, v in train_data.items()
|
|
352
|
+
}
|
|
353
|
+
valid_split = {
|
|
354
|
+
k: np.asarray(v)[valid_indices] for k, v in train_data.items()
|
|
355
|
+
}
|
|
352
356
|
train_loader = self.prepare_data_loader(
|
|
353
357
|
train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
|
|
354
358
|
)
|
|
@@ -403,11 +407,11 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
403
407
|
)
|
|
404
408
|
|
|
405
409
|
scheduler_params = scheduler_params or {}
|
|
406
|
-
if
|
|
407
|
-
self.scheduler_name = scheduler
|
|
408
|
-
elif scheduler is None:
|
|
410
|
+
if scheduler is None:
|
|
409
411
|
self.scheduler_name = None
|
|
410
|
-
|
|
412
|
+
elif isinstance(scheduler, str):
|
|
413
|
+
self.scheduler_name = scheduler
|
|
414
|
+
else:
|
|
411
415
|
self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
|
|
412
416
|
self.scheduler_params = scheduler_params
|
|
413
417
|
self.scheduler_fn = (
|
|
@@ -418,25 +422,23 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
418
422
|
|
|
419
423
|
self.loss_config = loss
|
|
420
424
|
self.loss_params = loss_params or {}
|
|
421
|
-
|
|
422
|
-
if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
|
|
425
|
+
if isinstance(loss, list):
|
|
423
426
|
if len(loss) != self.nums_task:
|
|
424
427
|
raise ValueError(
|
|
425
428
|
f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
|
|
426
429
|
)
|
|
427
|
-
loss_list =
|
|
428
|
-
else:
|
|
430
|
+
loss_list = list(loss)
|
|
431
|
+
else:
|
|
429
432
|
loss_list = [loss] * self.nums_task
|
|
430
|
-
|
|
431
433
|
if isinstance(self.loss_params, dict):
|
|
432
|
-
|
|
433
|
-
else:
|
|
434
|
-
|
|
434
|
+
loss_params_list = [self.loss_params] * self.nums_task
|
|
435
|
+
else:
|
|
436
|
+
loss_params_list = [
|
|
435
437
|
self.loss_params[i] if i < len(self.loss_params) else {}
|
|
436
438
|
for i in range(self.nums_task)
|
|
437
439
|
]
|
|
438
440
|
self.loss_fn = [
|
|
439
|
-
get_loss_fn(loss=loss_list[i], **
|
|
441
|
+
get_loss_fn(loss=loss_list[i], **loss_params_list[i])
|
|
440
442
|
for i in range(self.nums_task)
|
|
441
443
|
]
|
|
442
444
|
|
|
@@ -448,10 +450,8 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
448
450
|
raise ValueError(
|
|
449
451
|
"[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
|
|
450
452
|
)
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
weight_value = loss_weights
|
|
454
|
-
self.loss_weights = [float(weight_value)]
|
|
453
|
+
loss_weights = loss_weights[0]
|
|
454
|
+
self.loss_weights = [float(loss_weights)]
|
|
455
455
|
else:
|
|
456
456
|
if isinstance(loss_weights, (int, float)):
|
|
457
457
|
weights = [float(loss_weights)] * self.nums_task
|
|
@@ -484,7 +484,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
484
484
|
y_true = y_true.view(-1, 1)
|
|
485
485
|
if y_pred.shape != y_true.shape:
|
|
486
486
|
raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
|
|
487
|
-
task_dim =
|
|
487
|
+
task_dim = (
|
|
488
|
+
self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
|
|
489
|
+
)
|
|
488
490
|
if task_dim == 1:
|
|
489
491
|
loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
|
|
490
492
|
else:
|
|
@@ -495,12 +497,11 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
495
497
|
# multi-task
|
|
496
498
|
if y_pred.shape != y_true.shape:
|
|
497
499
|
raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
|
|
498
|
-
|
|
499
|
-
self
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
slices = [(i, i + 1) for i in range(self.nums_task)]
|
|
500
|
+
slices = (
|
|
501
|
+
self.prediction_layer.task_slices # type: ignore
|
|
502
|
+
if hasattr(self, "prediction_layer")
|
|
503
|
+
else [(i, i + 1) for i in range(self.nums_task)]
|
|
504
|
+
)
|
|
504
505
|
task_losses = []
|
|
505
506
|
for i, (start, end) in enumerate(slices): # type: ignore
|
|
506
507
|
y_pred_i = y_pred[:, start:end]
|
|
@@ -520,6 +521,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
520
521
|
sampler=None,
|
|
521
522
|
return_dataset: bool = False,
|
|
522
523
|
) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
|
|
524
|
+
"""
|
|
525
|
+
Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
|
|
526
|
+
"""
|
|
523
527
|
if isinstance(data, DataLoader):
|
|
524
528
|
return (data, None) if return_dataset else data
|
|
525
529
|
tensors = build_tensors_from_data(
|
|
@@ -626,54 +630,55 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
626
630
|
)
|
|
627
631
|
) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
|
|
628
632
|
|
|
629
|
-
# Setup default callbacks if
|
|
630
|
-
if
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
633
|
+
# Setup default callbacks if missing
|
|
634
|
+
if self.nums_task == 1:
|
|
635
|
+
monitor_metric = f"val_{self.metrics[0]}"
|
|
636
|
+
else:
|
|
637
|
+
monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
|
|
638
|
+
|
|
639
|
+
existing_callbacks = self.callbacks.callbacks
|
|
640
|
+
has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
|
|
641
|
+
has_checkpoint = any(
|
|
642
|
+
isinstance(cb, CheckpointSaver) for cb in existing_callbacks
|
|
643
|
+
)
|
|
644
|
+
has_lr_scheduler = any(
|
|
645
|
+
isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
|
|
646
|
+
)
|
|
647
|
+
|
|
648
|
+
if self.early_stop_patience > 0 and not has_early_stop:
|
|
649
|
+
self.callbacks.append(
|
|
650
|
+
EarlyStopper(
|
|
651
|
+
monitor=monitor_metric,
|
|
652
|
+
patience=self.early_stop_patience,
|
|
653
|
+
mode=self.best_metrics_mode,
|
|
654
|
+
restore_best_weights=not self.distributed,
|
|
655
|
+
verbose=1 if self.is_main_process else 0,
|
|
645
656
|
)
|
|
657
|
+
)
|
|
646
658
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
659
|
+
if self.is_main_process and not has_checkpoint:
|
|
660
|
+
self.callbacks.append(
|
|
661
|
+
CheckpointSaver(
|
|
662
|
+
best_path=self.best_path,
|
|
663
|
+
checkpoint_path=self.checkpoint_path,
|
|
664
|
+
monitor=monitor_metric,
|
|
665
|
+
mode=self.best_metrics_mode,
|
|
666
|
+
save_best_only=True,
|
|
667
|
+
verbose=1,
|
|
656
668
|
)
|
|
669
|
+
)
|
|
657
670
|
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
)
|
|
671
|
+
if self.scheduler_fn is not None and not has_lr_scheduler:
|
|
672
|
+
self.callbacks.append(
|
|
673
|
+
LearningRateScheduler(
|
|
674
|
+
scheduler=self.scheduler_fn,
|
|
675
|
+
verbose=1 if self.is_main_process else 0,
|
|
664
676
|
)
|
|
677
|
+
)
|
|
665
678
|
|
|
666
679
|
self.callbacks.set_model(self)
|
|
667
680
|
self.callbacks.set_params(
|
|
668
|
-
{
|
|
669
|
-
"epochs": epochs,
|
|
670
|
-
"batch_size": batch_size,
|
|
671
|
-
"metrics": self.metrics,
|
|
672
|
-
}
|
|
673
|
-
)
|
|
674
|
-
|
|
675
|
-
self.early_stopper = EarlyStopper(
|
|
676
|
-
patience=self.early_stop_patience, mode=self.best_metrics_mode
|
|
681
|
+
{"epochs": epochs, "batch_size": batch_size, "metrics": self.metrics}
|
|
677
682
|
)
|
|
678
683
|
self.best_metric = (
|
|
679
684
|
float("-inf") if self.best_metrics_mode == "max" else float("inf")
|
|
@@ -685,6 +690,12 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
685
690
|
self.epoch_index = 0
|
|
686
691
|
self.stop_training = False
|
|
687
692
|
self.best_checkpoint_path = self.best_path
|
|
693
|
+
use_ddp_sampler = (
|
|
694
|
+
auto_distributed_sampler
|
|
695
|
+
and self.distributed
|
|
696
|
+
and dist.is_available()
|
|
697
|
+
and dist.is_initialized()
|
|
698
|
+
)
|
|
688
699
|
|
|
689
700
|
if not auto_distributed_sampler and self.distributed and self.is_main_process:
|
|
690
701
|
logging.info(
|
|
@@ -697,12 +708,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
697
708
|
train_sampler: DistributedSampler | None = None
|
|
698
709
|
if validation_split is not None and valid_data is None:
|
|
699
710
|
train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
|
|
700
|
-
if
|
|
701
|
-
auto_distributed_sampler
|
|
702
|
-
and self.distributed
|
|
703
|
-
and dist.is_available()
|
|
704
|
-
and dist.is_initialized()
|
|
705
|
-
):
|
|
711
|
+
if use_ddp_sampler:
|
|
706
712
|
base_dataset = getattr(train_loader, "dataset", None)
|
|
707
713
|
if base_dataset is not None and not isinstance(
|
|
708
714
|
getattr(train_loader, "sampler", None), DistributedSampler
|
|
@@ -725,7 +731,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
725
731
|
)
|
|
726
732
|
else:
|
|
727
733
|
if isinstance(train_data, DataLoader):
|
|
728
|
-
if
|
|
734
|
+
if use_ddp_sampler:
|
|
729
735
|
train_loader, train_sampler = add_distributed_sampler(
|
|
730
736
|
train_data,
|
|
731
737
|
distributed=self.distributed,
|
|
@@ -749,15 +755,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
749
755
|
)
|
|
750
756
|
assert isinstance(
|
|
751
757
|
result, tuple
|
|
752
|
-
), "Expected tuple from prepare_data_loader with return_dataset=True"
|
|
758
|
+
), "[BaseModel-fit Error] Expected tuple from prepare_data_loader with return_dataset=True, but got something else."
|
|
753
759
|
loader, dataset = result
|
|
754
|
-
if
|
|
755
|
-
auto_distributed_sampler
|
|
756
|
-
and self.distributed
|
|
757
|
-
and dataset is not None
|
|
758
|
-
and dist.is_available()
|
|
759
|
-
and dist.is_initialized()
|
|
760
|
-
):
|
|
760
|
+
if use_ddp_sampler and dataset is not None:
|
|
761
761
|
train_sampler = DistributedSampler(
|
|
762
762
|
dataset,
|
|
763
763
|
num_replicas=self.world_size,
|
|
@@ -802,34 +802,42 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
802
802
|
except TypeError: # streaming data loader does not supported len()
|
|
803
803
|
self.steps_per_epoch = None
|
|
804
804
|
is_streaming = True
|
|
805
|
+
self.collect_train_metrics = not is_streaming
|
|
806
|
+
if is_streaming and self.is_main_process:
|
|
807
|
+
logging.info(
|
|
808
|
+
colorize(
|
|
809
|
+
"[Training Info] Streaming mode detected; training metrics collection is disabled to avoid memory growth.",
|
|
810
|
+
color="yellow",
|
|
811
|
+
)
|
|
812
|
+
)
|
|
805
813
|
|
|
806
814
|
if self.is_main_process:
|
|
807
815
|
self.summary()
|
|
808
816
|
logging.info("")
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
if
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
817
|
+
tb_dir = (
|
|
818
|
+
self.training_logger.tensorboard_logdir
|
|
819
|
+
if self.training_logger and self.training_logger.enable_tensorboard
|
|
820
|
+
else None
|
|
821
|
+
)
|
|
822
|
+
if tb_dir:
|
|
823
|
+
user = getpass.getuser()
|
|
824
|
+
host = socket.gethostname()
|
|
825
|
+
tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
|
|
826
|
+
ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
|
|
827
|
+
logging.info(
|
|
828
|
+
colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
|
|
829
|
+
)
|
|
830
|
+
logging.info(colorize("To view logs, run:", color="cyan"))
|
|
831
|
+
logging.info(colorize(f" {tb_cmd}", color="cyan"))
|
|
832
|
+
logging.info(colorize("Then SSH port forward:", color="cyan"))
|
|
833
|
+
logging.info(colorize(f" {ssh_hint}", color="cyan"))
|
|
823
834
|
|
|
824
835
|
logging.info("")
|
|
825
|
-
logging.info(colorize("="
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
logging.info(colorize("Start training", bold=True))
|
|
830
|
-
logging.info(colorize("=" * 80, bold=True))
|
|
836
|
+
logging.info(colorize("[Training]", color="bright_blue", bold=True))
|
|
837
|
+
logging.info(colorize("-" * 80, color="bright_blue"))
|
|
838
|
+
logging.info(format_kv("Start training", f"{epochs} epochs"))
|
|
839
|
+
logging.info(format_kv("Model device", self.device))
|
|
831
840
|
logging.info("")
|
|
832
|
-
logging.info(colorize(f"Model device: {self.device}", bold=True))
|
|
833
841
|
|
|
834
842
|
self.callbacks.on_train_begin()
|
|
835
843
|
|
|
@@ -852,128 +860,77 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
852
860
|
and isinstance(train_loader.sampler, DistributedSampler)
|
|
853
861
|
):
|
|
854
862
|
train_loader.sampler.set_epoch(epoch)
|
|
855
|
-
|
|
863
|
+
|
|
856
864
|
if not isinstance(train_loader, DataLoader):
|
|
857
865
|
raise TypeError(
|
|
858
866
|
f"Expected DataLoader for training, got {type(train_loader)}"
|
|
859
867
|
)
|
|
860
868
|
train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
|
|
861
|
-
if isinstance(
|
|
869
|
+
if isinstance(
|
|
870
|
+
train_result, tuple
|
|
871
|
+
): # [avg_loss, metrics_dict], e.g., (0.5, {'auc': 0.75, 'logloss': 0.45})
|
|
862
872
|
train_loss, train_metrics = train_result
|
|
863
873
|
else:
|
|
864
874
|
train_loss = train_result
|
|
865
875
|
train_metrics = None
|
|
866
876
|
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
[f"{k}={v:.4f}" for k, v in train_metrics.items()]
|
|
874
|
-
)
|
|
875
|
-
log_str += f", {metrics_str}"
|
|
876
|
-
if self.is_main_process:
|
|
877
|
-
logging.info(colorize(log_str))
|
|
878
|
-
train_log_payload["loss"] = float(train_loss)
|
|
879
|
-
if train_metrics:
|
|
880
|
-
train_log_payload.update(train_metrics)
|
|
881
|
-
else:
|
|
882
|
-
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
883
|
-
log_str = (
|
|
884
|
-
f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
877
|
+
logging.info("")
|
|
878
|
+
train_log_payload = {
|
|
879
|
+
"loss": (
|
|
880
|
+
float(np.sum(train_loss))
|
|
881
|
+
if isinstance(train_loss, np.ndarray)
|
|
882
|
+
else float(train_loss)
|
|
885
883
|
)
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
]
|
|
906
|
-
)
|
|
907
|
-
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
908
|
-
log_str += ", " + ", ".join(task_metric_strs)
|
|
909
|
-
if self.is_main_process:
|
|
910
|
-
logging.info(colorize(log_str))
|
|
911
|
-
train_log_payload["loss"] = float(total_loss_val)
|
|
912
|
-
if train_metrics:
|
|
913
|
-
train_log_payload.update(train_metrics)
|
|
884
|
+
}
|
|
885
|
+
if train_metrics:
|
|
886
|
+
train_log_payload.update(train_metrics)
|
|
887
|
+
|
|
888
|
+
display_metrics_table(
|
|
889
|
+
epoch=epoch + 1,
|
|
890
|
+
epochs=epochs,
|
|
891
|
+
split="Train",
|
|
892
|
+
loss=train_loss,
|
|
893
|
+
metrics=train_metrics,
|
|
894
|
+
target_names=self.target_columns,
|
|
895
|
+
base_metrics=(
|
|
896
|
+
self.metrics
|
|
897
|
+
if isinstance(getattr(self, "metrics", None), list)
|
|
898
|
+
else None
|
|
899
|
+
),
|
|
900
|
+
is_main_process=self.is_main_process,
|
|
901
|
+
colorize=lambda s: colorize(s),
|
|
902
|
+
)
|
|
914
903
|
if self.training_logger:
|
|
915
904
|
self.training_logger.log_metrics(
|
|
916
905
|
train_log_payload, step=epoch + 1, split="train"
|
|
917
906
|
)
|
|
918
907
|
if valid_loader is not None:
|
|
919
|
-
# Call on_validation_begin
|
|
920
908
|
self.callbacks.on_validation_begin()
|
|
921
|
-
|
|
922
|
-
# pass user_ids only if needed for GAUC metric
|
|
923
909
|
val_metrics = self.evaluate(
|
|
924
910
|
valid_loader,
|
|
925
911
|
user_ids=valid_user_ids if self.needs_user_ids else None,
|
|
926
912
|
num_workers=num_workers,
|
|
927
|
-
)
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
for target_name in self.target_columns:
|
|
944
|
-
if metric_key.endswith(f"_{target_name}"):
|
|
945
|
-
if target_name not in task_metrics:
|
|
946
|
-
task_metrics[target_name] = {}
|
|
947
|
-
metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
|
|
948
|
-
task_metrics[target_name][metric_name] = metric_value
|
|
949
|
-
break
|
|
950
|
-
task_metric_strs = []
|
|
951
|
-
for target_name in self.target_columns:
|
|
952
|
-
if target_name in task_metrics:
|
|
953
|
-
metrics_str = ", ".join(
|
|
954
|
-
[
|
|
955
|
-
f"{k}={v:.4f}"
|
|
956
|
-
for k, v in task_metrics[target_name].items()
|
|
957
|
-
]
|
|
958
|
-
)
|
|
959
|
-
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
960
|
-
if self.is_main_process:
|
|
961
|
-
logging.info(
|
|
962
|
-
colorize(
|
|
963
|
-
f" Epoch {epoch + 1}/{epochs} - Valid: "
|
|
964
|
-
+ ", ".join(task_metric_strs),
|
|
965
|
-
color="cyan",
|
|
966
|
-
)
|
|
967
|
-
)
|
|
968
|
-
|
|
969
|
-
# Call on_validation_end
|
|
913
|
+
)
|
|
914
|
+
display_metrics_table(
|
|
915
|
+
epoch=epoch + 1,
|
|
916
|
+
epochs=epochs,
|
|
917
|
+
split="Valid",
|
|
918
|
+
loss=None,
|
|
919
|
+
metrics=val_metrics,
|
|
920
|
+
target_names=self.target_columns,
|
|
921
|
+
base_metrics=(
|
|
922
|
+
self.metrics
|
|
923
|
+
if isinstance(getattr(self, "metrics", None), list)
|
|
924
|
+
else None
|
|
925
|
+
),
|
|
926
|
+
is_main_process=self.is_main_process,
|
|
927
|
+
colorize=lambda s: colorize(" " + s, color="cyan"),
|
|
928
|
+
)
|
|
970
929
|
self.callbacks.on_validation_end()
|
|
971
930
|
if val_metrics and self.training_logger:
|
|
972
931
|
self.training_logger.log_metrics(
|
|
973
932
|
val_metrics, step=epoch + 1, split="valid"
|
|
974
933
|
)
|
|
975
|
-
|
|
976
|
-
# Handle empty validation metrics
|
|
977
934
|
if not val_metrics:
|
|
978
935
|
if self.is_main_process:
|
|
979
936
|
logging.info(
|
|
@@ -983,15 +940,10 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
983
940
|
)
|
|
984
941
|
)
|
|
985
942
|
continue
|
|
986
|
-
|
|
987
|
-
# Prepare epoch logs for callbacks
|
|
988
943
|
epoch_logs = {**train_log_payload}
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
for k, v in val_metrics.items():
|
|
992
|
-
epoch_logs[f"val_{k}"] = v
|
|
944
|
+
for k, v in val_metrics.items():
|
|
945
|
+
epoch_logs[f"val_{k}"] = v
|
|
993
946
|
else:
|
|
994
|
-
# No validation data
|
|
995
947
|
epoch_logs = {**train_log_payload}
|
|
996
948
|
if self.is_main_process:
|
|
997
949
|
self.save_model(
|
|
@@ -1018,13 +970,13 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1018
970
|
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
1019
971
|
dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
|
|
1020
972
|
if self.is_main_process:
|
|
1021
|
-
logging.info("
|
|
1022
|
-
logging.info(colorize("Training finished.", bold=True))
|
|
1023
|
-
logging.info("
|
|
973
|
+
logging.info("")
|
|
974
|
+
logging.info(colorize("Training finished.", color="bright_blue", bold=True))
|
|
975
|
+
logging.info("")
|
|
1024
976
|
if valid_loader is not None:
|
|
1025
977
|
if self.is_main_process:
|
|
1026
978
|
logging.info(
|
|
1027
|
-
|
|
979
|
+
format_kv("Load best model from", self.best_checkpoint_path)
|
|
1028
980
|
)
|
|
1029
981
|
if os.path.exists(self.best_checkpoint_path):
|
|
1030
982
|
self.load_model(
|
|
@@ -1051,14 +1003,18 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1051
1003
|
num_batches = 0
|
|
1052
1004
|
y_true_list = []
|
|
1053
1005
|
y_pred_list = []
|
|
1006
|
+
collect_metrics = getattr(self, "collect_train_metrics", True)
|
|
1007
|
+
max_samples = getattr(self, "max_metrics_samples", None)
|
|
1008
|
+
collected_samples = 0
|
|
1009
|
+
metrics_capped = False
|
|
1054
1010
|
|
|
1055
1011
|
user_ids_list = [] if self.needs_user_ids else None
|
|
1056
1012
|
tqdm_disable = not self.is_main_process
|
|
1057
1013
|
if self.steps_per_epoch is not None:
|
|
1058
1014
|
batch_iter = enumerate(
|
|
1059
|
-
|
|
1015
|
+
progress(
|
|
1060
1016
|
train_loader,
|
|
1061
|
-
|
|
1017
|
+
description=f"Epoch {self.epoch_index + 1}",
|
|
1062
1018
|
total=self.steps_per_epoch,
|
|
1063
1019
|
disable=tqdm_disable,
|
|
1064
1020
|
)
|
|
@@ -1066,7 +1022,11 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1066
1022
|
else:
|
|
1067
1023
|
desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
|
|
1068
1024
|
batch_iter = enumerate(
|
|
1069
|
-
|
|
1025
|
+
progress(
|
|
1026
|
+
train_loader,
|
|
1027
|
+
description=desc,
|
|
1028
|
+
disable=tqdm_disable,
|
|
1029
|
+
)
|
|
1070
1030
|
)
|
|
1071
1031
|
for batch_index, batch_data in batch_iter:
|
|
1072
1032
|
batch_dict = batch_to_dict(batch_data)
|
|
@@ -1085,16 +1045,34 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1085
1045
|
self.optimizer_fn.step()
|
|
1086
1046
|
accumulated_loss += loss.item()
|
|
1087
1047
|
|
|
1088
|
-
if
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
)
|
|
1094
|
-
if
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1048
|
+
if (
|
|
1049
|
+
collect_metrics
|
|
1050
|
+
and y_true is not None
|
|
1051
|
+
and isinstance(y_pred, torch.Tensor)
|
|
1052
|
+
):
|
|
1053
|
+
batch_size = int(y_true.size(0))
|
|
1054
|
+
if max_samples is not None and collected_samples >= max_samples:
|
|
1055
|
+
collect_metrics = False
|
|
1056
|
+
metrics_capped = True
|
|
1057
|
+
else:
|
|
1058
|
+
take_count = batch_size
|
|
1059
|
+
if (
|
|
1060
|
+
max_samples is not None
|
|
1061
|
+
and collected_samples + batch_size > max_samples
|
|
1062
|
+
):
|
|
1063
|
+
take_count = max_samples - collected_samples
|
|
1064
|
+
metrics_capped = True
|
|
1065
|
+
collect_metrics = False
|
|
1066
|
+
if take_count > 0:
|
|
1067
|
+
y_true_list.append(y_true[:take_count].detach().cpu().numpy())
|
|
1068
|
+
y_pred_list.append(y_pred[:take_count].detach().cpu().numpy())
|
|
1069
|
+
if self.needs_user_ids and user_ids_list is not None:
|
|
1070
|
+
batch_user_id = get_user_ids(
|
|
1071
|
+
data=batch_dict, id_columns=self.id_columns
|
|
1072
|
+
)
|
|
1073
|
+
if batch_user_id is not None:
|
|
1074
|
+
user_ids_list.append(batch_user_id[:take_count])
|
|
1075
|
+
collected_samples += take_count
|
|
1098
1076
|
num_batches += 1
|
|
1099
1077
|
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
1100
1078
|
loss_tensor = torch.tensor(
|
|
@@ -1120,6 +1098,14 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1120
1098
|
gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
|
|
1121
1099
|
)
|
|
1122
1100
|
|
|
1101
|
+
if metrics_capped and self.is_main_process:
|
|
1102
|
+
logging.info(
|
|
1103
|
+
colorize(
|
|
1104
|
+
f"[Training Info] Training metrics capped at {max_samples} samples to limit memory usage.",
|
|
1105
|
+
color="yellow",
|
|
1106
|
+
)
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1123
1109
|
if (
|
|
1124
1110
|
y_true_all is not None
|
|
1125
1111
|
and y_pred_all is not None
|
|
@@ -1258,11 +1244,15 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1258
1244
|
)
|
|
1259
1245
|
if batch_user_id is not None:
|
|
1260
1246
|
collected_user_ids.append(batch_user_id)
|
|
1261
|
-
if self.is_main_process:
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1247
|
+
# if self.is_main_process:
|
|
1248
|
+
# logging.info("")
|
|
1249
|
+
# logging.info(
|
|
1250
|
+
# colorize(
|
|
1251
|
+
# format_kv(
|
|
1252
|
+
# "Evaluation batches processed", batch_count
|
|
1253
|
+
# ),
|
|
1254
|
+
# )
|
|
1255
|
+
# )
|
|
1266
1256
|
y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
|
|
1267
1257
|
y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
|
|
1268
1258
|
|
|
@@ -1301,10 +1291,15 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1301
1291
|
)
|
|
1302
1292
|
)
|
|
1303
1293
|
return {}
|
|
1304
|
-
if self.is_main_process:
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1294
|
+
# if self.is_main_process:
|
|
1295
|
+
# logging.info(
|
|
1296
|
+
# colorize(
|
|
1297
|
+
# format_kv(
|
|
1298
|
+
# "Evaluation samples", y_true_all.shape[0]
|
|
1299
|
+
# ),
|
|
1300
|
+
# )
|
|
1301
|
+
# )
|
|
1302
|
+
logging.info("")
|
|
1308
1303
|
metrics_dict = evaluate_metrics(
|
|
1309
1304
|
y_true=y_true_all,
|
|
1310
1305
|
y_pred=y_pred_all,
|
|
@@ -1396,7 +1391,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1396
1391
|
id_arrays = None
|
|
1397
1392
|
|
|
1398
1393
|
with torch.no_grad():
|
|
1399
|
-
for batch_data in
|
|
1394
|
+
for batch_data in progress(data_loader, description="Predicting"):
|
|
1400
1395
|
batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
|
|
1401
1396
|
X_input, _ = self.get_input(batch_dict, require_labels=False)
|
|
1402
1397
|
y_pred = self(X_input)
|
|
@@ -1417,10 +1412,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1417
1412
|
if id_np.ndim == 1
|
|
1418
1413
|
else id_np
|
|
1419
1414
|
)
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
y_pred_all = np.array([])
|
|
1415
|
+
y_pred_all = (
|
|
1416
|
+
np.concatenate(y_pred_list, axis=0) if y_pred_list else np.array([])
|
|
1417
|
+
)
|
|
1424
1418
|
|
|
1425
1419
|
if y_pred_all.ndim == 1:
|
|
1426
1420
|
y_pred_all = y_pred_all.reshape(-1, 1)
|
|
@@ -1428,22 +1422,22 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1428
1422
|
num_outputs = len(self.target_columns) if self.target_columns else 1
|
|
1429
1423
|
y_pred_all = y_pred_all.reshape(0, num_outputs)
|
|
1430
1424
|
num_outputs = y_pred_all.shape[1]
|
|
1431
|
-
pred_columns: list[str] =
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
pred_columns.append(f"{name}")
|
|
1425
|
+
pred_columns: list[str] = (
|
|
1426
|
+
list(self.target_columns[:num_outputs]) if self.target_columns else []
|
|
1427
|
+
)
|
|
1435
1428
|
while len(pred_columns) < num_outputs:
|
|
1436
1429
|
pred_columns.append(f"pred_{len(pred_columns)}")
|
|
1437
1430
|
if include_ids and predict_id_columns:
|
|
1438
|
-
id_arrays = {
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
concatenated = np.concatenate(
|
|
1431
|
+
id_arrays = {
|
|
1432
|
+
id_name: (
|
|
1433
|
+
np.concatenate(
|
|
1442
1434
|
[p.reshape(p.shape[0], -1) for p in pieces], axis=0
|
|
1443
|
-
)
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1435
|
+
).reshape(-1)
|
|
1436
|
+
if pieces
|
|
1437
|
+
else np.array([], dtype=np.int64)
|
|
1438
|
+
)
|
|
1439
|
+
for id_name, pieces in id_buffers.items()
|
|
1440
|
+
}
|
|
1447
1441
|
if return_dataframe:
|
|
1448
1442
|
id_df = pd.DataFrame(id_arrays)
|
|
1449
1443
|
pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
|
|
@@ -1544,7 +1538,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1544
1538
|
collected_frames = [] # only used when return_dataframe is True
|
|
1545
1539
|
|
|
1546
1540
|
with torch.no_grad():
|
|
1547
|
-
for batch_data in
|
|
1541
|
+
for batch_data in progress(data_loader, description="Predicting"):
|
|
1548
1542
|
batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
|
|
1549
1543
|
X_input, _ = self.get_input(batch_dict, require_labels=False)
|
|
1550
1544
|
y_pred = self.forward(X_input)
|
|
@@ -1555,25 +1549,24 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1555
1549
|
y_pred_np = y_pred_np.reshape(-1, 1)
|
|
1556
1550
|
if pred_columns is None:
|
|
1557
1551
|
num_outputs = y_pred_np.shape[1]
|
|
1558
|
-
pred_columns =
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1552
|
+
pred_columns = (
|
|
1553
|
+
list(self.target_columns[:num_outputs])
|
|
1554
|
+
if self.target_columns
|
|
1555
|
+
else []
|
|
1556
|
+
)
|
|
1562
1557
|
while len(pred_columns) < num_outputs:
|
|
1563
1558
|
pred_columns.append(f"pred_{len(pred_columns)}")
|
|
1564
1559
|
|
|
1565
|
-
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1572
|
-
|
|
1573
|
-
|
|
1574
|
-
|
|
1575
|
-
)
|
|
1576
|
-
id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
|
|
1560
|
+
ids = batch_dict.get("ids") if include_ids and id_columns else None
|
|
1561
|
+
id_arrays_batch = {
|
|
1562
|
+
id_name: (
|
|
1563
|
+
ids[id_name].detach().cpu().numpy()
|
|
1564
|
+
if isinstance(ids[id_name], torch.Tensor)
|
|
1565
|
+
else np.asarray(ids[id_name])
|
|
1566
|
+
).reshape(-1)
|
|
1567
|
+
for id_name in (id_columns or [])
|
|
1568
|
+
if ids and id_name in ids
|
|
1569
|
+
}
|
|
1577
1570
|
|
|
1578
1571
|
df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
|
|
1579
1572
|
if id_arrays_batch:
|
|
@@ -1775,13 +1768,13 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1775
1768
|
def summary(self):
|
|
1776
1769
|
logger = logging.getLogger()
|
|
1777
1770
|
|
|
1778
|
-
logger.info(
|
|
1771
|
+
logger.info("")
|
|
1779
1772
|
logger.info(
|
|
1780
1773
|
colorize(
|
|
1781
1774
|
f"Model Summary: {self.model_name}", color="bright_blue", bold=True
|
|
1782
1775
|
)
|
|
1783
1776
|
)
|
|
1784
|
-
logger.info(
|
|
1777
|
+
logger.info("")
|
|
1785
1778
|
|
|
1786
1779
|
logger.info("")
|
|
1787
1780
|
logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
|
|
@@ -1903,6 +1896,7 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
1903
1896
|
logger.info("Other Settings:")
|
|
1904
1897
|
logger.info(f" Early Stop Patience: {self.early_stop_patience}")
|
|
1905
1898
|
logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
|
|
1899
|
+
logger.info(f" Max Metrics Samples: {self.max_metrics_samples}")
|
|
1906
1900
|
logger.info(f" Session ID: {self.session_id}")
|
|
1907
1901
|
logger.info(f" Features Config Path: {self.features_config_path}")
|
|
1908
1902
|
logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
|
|
@@ -2296,7 +2290,7 @@ class BaseMatchModel(BaseModel):
|
|
|
2296
2290
|
|
|
2297
2291
|
embeddings_list = []
|
|
2298
2292
|
with torch.no_grad():
|
|
2299
|
-
for batch_data in
|
|
2293
|
+
for batch_data in progress(data_loader, description="Encoding users"):
|
|
2300
2294
|
batch_dict = batch_to_dict(batch_data, include_ids=False)
|
|
2301
2295
|
user_input = self.get_user_features(batch_dict["features"])
|
|
2302
2296
|
user_emb = self.user_tower(user_input)
|
|
@@ -2316,7 +2310,7 @@ class BaseMatchModel(BaseModel):
|
|
|
2316
2310
|
|
|
2317
2311
|
embeddings_list = []
|
|
2318
2312
|
with torch.no_grad():
|
|
2319
|
-
for batch_data in
|
|
2313
|
+
for batch_data in progress(data_loader, description="Encoding items"):
|
|
2320
2314
|
batch_dict = batch_to_dict(batch_data, include_ids=False)
|
|
2321
2315
|
item_input = self.get_item_features(batch_dict["features"])
|
|
2322
2316
|
item_emb = self.item_tower(item_input)
|