nextrec-0.4.7-py3-none-any.whl → nextrec-0.4.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +337 -328
- nextrec/cli.py +25 -4
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +12 -13
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/ffm.py +0 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +0 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +12 -10
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/METADATA +8 -7
- nextrec-0.4.9.dist-info/RECORD +70 -0
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.7.dist-info/RECORD +0 -70
- /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
- {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
- {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,52 +2,52 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
+import getpass
+import logging
 import os
-import tqdm
 import pickle
-import logging
-import getpass
 import socket
+from pathlib import Path
+from typing import Any, Literal, Union
+
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-
-
-from pathlib import Path
-from typing import Union, Literal, Any
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from torch.nn.parallel import DistributedDataParallel as DDP
 
+from nextrec import __version__
 from nextrec.basic.callback import (
-    EarlyStopper,
-    CallbackList,
     Callback,
+    CallbackList,
     CheckpointSaver,
+    EarlyStopper,
     LearningRateScheduler,
 )
 from nextrec.basic.features import (
     DenseFeature,
-    SparseFeature,
-    SequenceFeature,
     FeatureSet,
+    SequenceFeature,
+    SparseFeature,
 )
-from nextrec.
-
-from nextrec.basic.
-from nextrec.
-from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
-
-from nextrec.data.dataloader import build_tensors_from_data
-from nextrec.data.batch_utils import collate_fn, batch_to_dict
+from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
+from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
+from nextrec.basic.session import create_session, resolve_save_path
+from nextrec.data.batch_utils import batch_to_dict, collate_fn
 from nextrec.data.data_processing import get_column_data, get_user_ids
-
+from nextrec.data.dataloader import (
+    RecDataLoader,
+    TensorDictDataset,
+    build_tensors_from_data,
+)
 from nextrec.loss import (
     BPRLoss,
     HingeLoss,
@@ -55,17 +55,17 @@ from nextrec.loss import (
     SampledSoftmaxLoss,
     TripletLoss,
     get_loss_fn,
-    get_loss_kwargs,
 )
-from nextrec.utils.
-from nextrec.utils.
-
-
+from nextrec.utils.console import display_metrics_table, progress
+from nextrec.utils.torch_utils import (
+    add_distributed_sampler,
+    configure_device,
     gather_numpy,
+    get_optimizer,
+    get_scheduler,
     init_process_group,
-
+    to_tensor,
 )
-from nextrec import __version__
 
 
 class BaseModel(FeatureSet, nn.Module):
@@ -91,6 +91,7 @@ class BaseModel(FeatureSet, nn.Module):
         dense_l2_reg: float = 0.0,
         device: str = "cpu",
         early_stop_patience: int = 20,
+        max_metrics_samples: int | None = 200000,
         session_id: str | None = None,
         callbacks: list[Callback] | None = None,
         distributed: bool = False,
@@ -117,6 +118,7 @@
 
             device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+            max_metrics_samples: Max samples to keep for training metrics. None disables limit.
             session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
             callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
 
@@ -146,7 +148,7 @@
         self.session_path = self.session.root  # pwd/session_id, path for this session
         self.checkpoint_path = os.path.join(
             self.session_path, self.model_name + "_checkpoint.pt"
-        )  #
+        )  # e.g., pwd/session_id/DeepFM_checkpoint.pt
         self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
         self.features_config_path = os.path.join(
             self.session_path, "features_config.pkl"
@@ -167,6 +169,9 @@
         self.loss_weight = None
 
         self.early_stop_patience = early_stop_patience
+        self.max_metrics_samples = (
+            None if max_metrics_samples is None else int(max_metrics_samples)
+        )
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
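The new `max_metrics_samples` argument (default 200000) bounds how many predictions the training loop retains for epoch-level metric computation; the enforcement itself lives further down in `train_epoch`. Below is a minimal standalone sketch of that capping behaviour, using made-up batch data rather than nextrec's own loaders.

```python
import numpy as np

# Illustrative sketch (not nextrec code): cap how many (y_true, y_pred) samples
# are kept around for end-of-epoch metric computation.
max_metrics_samples = 8            # nextrec defaults to 200000; tiny here for the demo
collected, y_true_buf, y_pred_buf = 0, [], []
metrics_capped = False

for _ in range(5):                 # pretend we iterate over 5 batches of size 3
    y_true = np.random.randint(0, 2, size=3)
    y_pred = np.random.rand(3)
    if metrics_capped:
        continue                   # once the cap is hit, stop collecting entirely
    take = min(len(y_true), max_metrics_samples - collected)
    if take < len(y_true):
        metrics_capped = True
    if take > 0:
        y_true_buf.append(y_true[:take])
        y_pred_buf.append(y_pred[:take])
        collected += take

print(collected)                   # 8 -> never exceeds the cap
```

Passing `None` disables the cap, matching the docstring line added in this hunk.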
@@ -182,15 +187,15 @@
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
-        embedding_params: list[torch.Tensor] = []
         if embed_dict is not None:
-            embedding_params
-                embed.weight
-
+            embedding_params = [
+                embed.weight
+                for embed in embed_dict.values()
+                if hasattr(embed, "weight")
+            ]
         else:
             weight = getattr(embedding_layer, "weight", None)
-            if isinstance(weight, torch.Tensor)
-                embedding_params.append(weight)
+            embedding_params = [weight] if isinstance(weight, torch.Tensor) else []
 
         existing_embedding_ids = {id(param) for param in self.embedding_params}
         for param in embedding_params:
@@ -212,10 +217,12 @@
                 module is self
                 or embedding_attr in name
                 or isinstance(module, skip_types)
-                or (include_modules and not any(inc in name for inc in include_modules))
-                or any(exc in name for exc in exclude_modules)
             ):
                 continue
+            if include_modules and not any(inc in name for inc in include_modules):
+                continue
+            if exclude_modules and any(exc in name for exc in exclude_modules):
+                continue
             if isinstance(module, nn.Linear):
                 if id(module.weight) not in existing_reg_ids:
                     self.regularization_weights.append(module.weight)
@@ -317,22 +324,20 @@
             raise ValueError(
                 f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
             )
-        if not isinstance(train_data, (pd.DataFrame, dict)):
-            raise TypeError(
-                f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
-            )
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
-
-            sample_key = next(
-
-            )  # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
-            total_length = len(train_data[sample_key])  # len(train_data['user_id'])
+        elif isinstance(train_data, dict):
+            sample_key = next(iter(train_data))
+            total_length = len(train_data[sample_key])
             for k, v in train_data.items():
                 if len(v) != total_length:
                     raise ValueError(
                         f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
                     )
+        else:
+            raise TypeError(
+                f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
+            )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
         split_idx = int(total_length * (1 - validation_split))
@@ -342,12 +347,12 @@
             train_split = train_data.iloc[train_indices].reset_index(drop=True)
             valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
         else:
-            train_split = {
-
-
-
-
-
+            train_split = {
+                k: np.asarray(v)[train_indices] for k, v in train_data.items()
+            }
+            valid_split = {
+                k: np.asarray(v)[valid_indices] for k, v in train_data.items()
+            }
         train_loader = self.prepare_data_loader(
             train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
         )
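The rewritten split path now treats dict inputs the same way as DataFrames: one seeded permutation, then per-field fancy indexing. A short sketch of the same idea on toy columnar data (variable names follow the diff; the data itself is invented):

```python
import numpy as np

train_data = {"user_id": [1, 2, 3, 4, 5], "label": [0, 1, 0, 1, 1]}
validation_split = 0.4

sample_key = next(iter(train_data))
total_length = len(train_data[sample_key])

rng = np.random.default_rng(42)                  # same fixed seed as the diff
indices = rng.permutation(total_length)
split_idx = int(total_length * (1 - validation_split))
train_indices, valid_indices = indices[:split_idx], indices[split_idx:]

train_split = {k: np.asarray(v)[train_indices] for k, v in train_data.items()}
valid_split = {k: np.asarray(v)[valid_indices] for k, v in train_data.items()}

print(len(train_split["user_id"]), len(valid_split["user_id"]))  # 3 2
```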
@@ -402,11 +407,11 @@
         )
 
         scheduler_params = scheduler_params or {}
-        if
-            self.scheduler_name = scheduler
-        elif scheduler is None:
+        if scheduler is None:
             self.scheduler_name = None
-
+        elif isinstance(scheduler, str):
+            self.scheduler_name = scheduler
+        else:
             self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
         self.scheduler_params = scheduler_params
         self.scheduler_fn = (
@@ -417,25 +422,23 @@
 
         self.loss_config = loss
         self.loss_params = loss_params or {}
-
-        if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
+        if isinstance(loss, list):
             if len(loss) != self.nums_task:
                 raise ValueError(
                     f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
                 )
-            loss_list =
-        else:
+            loss_list = list(loss)
+        else:
             loss_list = [loss] * self.nums_task
-
         if isinstance(self.loss_params, dict):
-
-        else:
-
+            loss_params_list = [self.loss_params] * self.nums_task
+        else:
+            loss_params_list = [
                 self.loss_params[i] if i < len(self.loss_params) else {}
                 for i in range(self.nums_task)
             ]
         self.loss_fn = [
-            get_loss_fn(loss=loss_list[i], **
+            get_loss_fn(loss=loss_list[i], **loss_params_list[i])
             for i in range(self.nums_task)
         ]
 
@@ -447,10 +450,8 @@
                 raise ValueError(
                     "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
                 )
-
-
-            weight_value = loss_weights
-            self.loss_weights = [float(weight_value)]
+            loss_weights = loss_weights[0]
+            self.loss_weights = [float(loss_weights)]
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
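The compile changes above normalise `loss` and `loss_params` into one entry per task before building `self.loss_fn`. A hedged sketch of that broadcasting, with a stand-in factory in place of nextrec's internal `get_loss_fn`:

```python
import torch.nn as nn


def get_loss_fn(loss: str, **kwargs) -> nn.Module:
    # Stand-in for nextrec.loss.get_loss_fn, only for this sketch.
    return {"bce": nn.BCELoss, "mse": nn.MSELoss}[loss](**kwargs)


nums_task = 2
loss = ["bce", "mse"]        # a single string would be broadcast to every task
loss_params = {}             # a dict applies to all tasks; a list is taken per task

loss_list = list(loss) if isinstance(loss, list) else [loss] * nums_task
if isinstance(loss_params, dict):
    loss_params_list = [loss_params] * nums_task
else:
    loss_params_list = [
        loss_params[i] if i < len(loss_params) else {} for i in range(nums_task)
    ]

loss_fn = [get_loss_fn(loss=loss_list[i], **loss_params_list[i]) for i in range(nums_task)]
print(loss_fn)               # [BCELoss(), MSELoss()]
```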
@@ -483,7 +484,9 @@
             y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-            task_dim =
+            task_dim = (
+                self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
+            )
             if task_dim == 1:
                 loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
             else:
@@ -494,12 +497,11 @@
             # multi-task
             if y_pred.shape != y_true.shape:
                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-
-            self
-
-
-
-                slices = [(i, i + 1) for i in range(self.nums_task)]
+            slices = (
+                self.prediction_layer.task_slices  # type: ignore
+                if hasattr(self, "prediction_layer")
+                else [(i, i + 1) for i in range(self.nums_task)]
+            )
             task_losses = []
             for i, (start, end) in enumerate(slices):  # type: ignore
                 y_pred_i = y_pred[:, start:end]
@@ -519,6 +521,9 @@
         sampler=None,
         return_dataset: bool = False,
     ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
+        """
+        Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+        """
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
         tensors = build_tensors_from_data(
@@ -625,54 +630,55 @@
         )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
 
-        # Setup default callbacks if
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Setup default callbacks if missing
+        if self.nums_task == 1:
+            monitor_metric = f"val_{self.metrics[0]}"
+        else:
+            monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+        existing_callbacks = self.callbacks.callbacks
+        has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
+        has_checkpoint = any(
+            isinstance(cb, CheckpointSaver) for cb in existing_callbacks
+        )
+        has_lr_scheduler = any(
+            isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
+        )
+
+        if self.early_stop_patience > 0 and not has_early_stop:
+            self.callbacks.append(
+                EarlyStopper(
+                    monitor=monitor_metric,
+                    patience=self.early_stop_patience,
+                    mode=self.best_metrics_mode,
+                    restore_best_weights=not self.distributed,
+                    verbose=1 if self.is_main_process else 0,
                 )
+            )
 
-
-
-
-
-
-
-
-
-
+        if self.is_main_process and not has_checkpoint:
+            self.callbacks.append(
+                CheckpointSaver(
+                    best_path=self.best_path,
+                    checkpoint_path=self.checkpoint_path,
+                    monitor=monitor_metric,
+                    mode=self.best_metrics_mode,
+                    save_best_only=True,
+                    verbose=1,
                 )
+            )
 
-
-
-
-
-
-        )
+        if self.scheduler_fn is not None and not has_lr_scheduler:
+            self.callbacks.append(
+                LearningRateScheduler(
+                    scheduler=self.scheduler_fn,
+                    verbose=1 if self.is_main_process else 0,
                 )
+            )
 
         self.callbacks.set_model(self)
         self.callbacks.set_params(
-            {
-                "epochs": epochs,
-                "batch_size": batch_size,
-                "metrics": self.metrics,
-            }
-        )
-
-        self.early_stopper = EarlyStopper(
-            patience=self.early_stop_patience, mode=self.best_metrics_mode
+            {"epochs": epochs, "batch_size": batch_size, "metrics": self.metrics}
         )
         self.best_metric = (
             float("-inf") if self.best_metrics_mode == "max" else float("inf")
@@ -684,6 +690,12 @@
         self.epoch_index = 0
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
+        use_ddp_sampler = (
+            auto_distributed_sampler
+            and self.distributed
+            and dist.is_available()
+            and dist.is_initialized()
+        )
 
         if not auto_distributed_sampler and self.distributed and self.is_main_process:
             logging.info(
@@ -696,12 +708,7 @@
         train_sampler: DistributedSampler | None = None
         if validation_split is not None and valid_data is None:
             train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
-            if
-                auto_distributed_sampler
-                and self.distributed
-                and dist.is_available()
-                and dist.is_initialized()
-            ):
+            if use_ddp_sampler:
                 base_dataset = getattr(train_loader, "dataset", None)
                 if base_dataset is not None and not isinstance(
                     getattr(train_loader, "sampler", None), DistributedSampler
@@ -724,7 +731,7 @@
                 )
         else:
             if isinstance(train_data, DataLoader):
-                if
+                if use_ddp_sampler:
                     train_loader, train_sampler = add_distributed_sampler(
                         train_data,
                         distributed=self.distributed,
@@ -739,16 +746,18 @@
                 else:
                     train_loader = train_data
             else:
-                result = self.prepare_data_loader(
-
+                result = self.prepare_data_loader(
+                    train_data,
+                    batch_size=batch_size,
+                    shuffle=shuffle,
+                    num_workers=num_workers,
+                    return_dataset=True,
+                )
+                assert isinstance(
+                    result, tuple
+                ), "[BaseModel-fit Error] Expected tuple from prepare_data_loader with return_dataset=True, but got something else."
                 loader, dataset = result
-                if
-                    auto_distributed_sampler
-                    and self.distributed
-                    and dataset is not None
-                    and dist.is_available()
-                    and dist.is_initialized()
-                ):
+                if use_ddp_sampler and dataset is not None:
                     train_sampler = DistributedSampler(
                         dataset,
                         num_replicas=self.world_size,
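`use_ddp_sampler` collapses the repeated `dist.is_available()` / `dist.is_initialized()` checks into one flag that gates every sampler decision below. Roughly how that gate plays out in isolation (the helper name and defaults here are illustrative, not part of nextrec):

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_loader(dataset, batch_size, distributed, auto_sampler=True, world_size=1, rank=0):
    """Sketch: attach a DistributedSampler only when DDP is actually usable."""
    use_ddp_sampler = (
        auto_sampler
        and distributed
        and dist.is_available()
        and dist.is_initialized()
    )
    sampler = (
        DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        if use_ddp_sampler
        else None
    )
    # shuffling is delegated to the sampler when one is supplied
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=sampler is None)
    return loader, sampler


loader, sampler = build_loader(TensorDataset(torch.arange(10)), batch_size=4, distributed=False)
print(sampler)  # None outside an initialized process group
```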
@@ -793,34 +802,42 @@
         except TypeError:  # streaming data loader does not supported len()
             self.steps_per_epoch = None
             is_streaming = True
+        self.collect_train_metrics = not is_streaming
+        if is_streaming and self.is_main_process:
+            logging.info(
+                colorize(
+                    "[Training Info] Streaming mode detected; training metrics collection is disabled to avoid memory growth.",
+                    color="yellow",
+                )
+            )
 
         if self.is_main_process:
             self.summary()
             logging.info("")
-
-
-            if
-
-
-
-
-
-
-
-
-
-
+            tb_dir = (
+                self.training_logger.tensorboard_logdir
+                if self.training_logger and self.training_logger.enable_tensorboard
+                else None
+            )
+            if tb_dir:
+                user = getpass.getuser()
+                host = socket.gethostname()
+                tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                logging.info(
+                    colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
+                )
+                logging.info(colorize("To view logs, run:", color="cyan"))
+                logging.info(colorize(f" {tb_cmd}", color="cyan"))
+                logging.info(colorize("Then SSH port forward:", color="cyan"))
+                logging.info(colorize(f" {ssh_hint}", color="cyan"))
 
             logging.info("")
-            logging.info(colorize("="
-
-
-
-            logging.info(colorize("Start training", bold=True))
-            logging.info(colorize("=" * 80, bold=True))
+            logging.info(colorize("[Training]", color="bright_blue", bold=True))
+            logging.info(colorize("-" * 80, color="bright_blue"))
+            logging.info(format_kv("Start training", f"{epochs} epochs"))
+            logging.info(format_kv("Model device", self.device))
             logging.info("")
-            logging.info(colorize(f"Model device: {self.device}", bold=True))
 
         self.callbacks.on_train_begin()
 
@@ -843,126 +860,77 @@
                 and isinstance(train_loader.sampler, DistributedSampler)
             ):
                 train_loader.sampler.set_epoch(epoch)
-
+
             if not isinstance(train_loader, DataLoader):
-                raise TypeError(
+                raise TypeError(
+                    f"Expected DataLoader for training, got {type(train_loader)}"
+                )
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
-            if isinstance(
+            if isinstance(
+                train_result, tuple
+            ):  # [avg_loss, metrics_dict], e.g., (0.5, {'auc': 0.75, 'logloss': 0.45})
                 train_loss, train_metrics = train_result
             else:
                 train_loss = train_result
                 train_metrics = None
 
-
-
-
-
-
-
-                        [f"{k}={v:.4f}" for k, v in train_metrics.items()]
-                    )
-                    log_str += f", {metrics_str}"
-                if self.is_main_process:
-                    logging.info(colorize(log_str))
-                train_log_payload["loss"] = float(train_loss)
-                if train_metrics:
-                    train_log_payload.update(train_metrics)
-            else:
-                total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
-                log_str = (
-                    f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
+            logging.info("")
+            train_log_payload = {
+                "loss": (
+                    float(np.sum(train_loss))
+                    if isinstance(train_loss, np.ndarray)
+                    else float(train_loss)
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    ]
-                )
-                    task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                log_str += ", " + ", ".join(task_metric_strs)
-                if self.is_main_process:
-                    logging.info(colorize(log_str))
-                train_log_payload["loss"] = float(total_loss_val)
-                if train_metrics:
-                    train_log_payload.update(train_metrics)
+            }
+            if train_metrics:
+                train_log_payload.update(train_metrics)
+
+            display_metrics_table(
+                epoch=epoch + 1,
+                epochs=epochs,
+                split="Train",
+                loss=train_loss,
+                metrics=train_metrics,
+                target_names=self.target_columns,
+                base_metrics=(
+                    self.metrics
+                    if isinstance(getattr(self, "metrics", None), list)
+                    else None
+                ),
+                is_main_process=self.is_main_process,
+                colorize=lambda s: colorize(s),
+            )
             if self.training_logger:
                 self.training_logger.log_metrics(
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
-                # Call on_validation_begin
                 self.callbacks.on_validation_begin()
-
-                # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(
                     valid_loader,
                     user_ids=valid_user_ids if self.needs_user_ids else None,
                     num_workers=num_workers,
-                )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    for target_name in self.target_columns:
-                        if metric_key.endswith(f"_{target_name}"):
-                            if target_name not in task_metrics:
-                                task_metrics[target_name] = {}
-                            metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
-                            task_metrics[target_name][metric_name] = metric_value
-                            break
-                task_metric_strs = []
-                for target_name in self.target_columns:
-                    if target_name in task_metrics:
-                        metrics_str = ", ".join(
-                            [
-                                f"{k}={v:.4f}"
-                                for k, v in task_metrics[target_name].items()
-                            ]
-                        )
-                        task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                if self.is_main_process:
-                    logging.info(
-                        colorize(
-                            f" Epoch {epoch + 1}/{epochs} - Valid: "
-                            + ", ".join(task_metric_strs),
-                            color="cyan",
-                        )
-                    )
-
-                # Call on_validation_end
+                )
+                display_metrics_table(
+                    epoch=epoch + 1,
+                    epochs=epochs,
+                    split="Valid",
+                    loss=None,
+                    metrics=val_metrics,
+                    target_names=self.target_columns,
+                    base_metrics=(
+                        self.metrics
+                        if isinstance(getattr(self, "metrics", None), list)
+                        else None
+                    ),
+                    is_main_process=self.is_main_process,
+                    colorize=lambda s: colorize(" " + s, color="cyan"),
+                )
                 self.callbacks.on_validation_end()
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
-
-                # Handle empty validation metrics
                 if not val_metrics:
                     if self.is_main_process:
                         logging.info(
@@ -972,15 +940,10 @@
                         )
                     )
                     continue
-
-                # Prepare epoch logs for callbacks
                 epoch_logs = {**train_log_payload}
-
-
-                for k, v in val_metrics.items():
-                    epoch_logs[f"val_{k}"] = v
+                for k, v in val_metrics.items():
+                    epoch_logs[f"val_{k}"] = v
             else:
-                # No validation data
                 epoch_logs = {**train_log_payload}
             if self.is_main_process:
                 self.save_model(
@@ -1007,13 +970,13 @@
         if self.distributed and dist.is_available() and dist.is_initialized():
             dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
-            logging.info("
-            logging.info(colorize("Training finished.", bold=True))
-            logging.info("
+            logging.info("")
+            logging.info(colorize("Training finished.", color="bright_blue", bold=True))
+            logging.info("")
         if valid_loader is not None:
             if self.is_main_process:
                 logging.info(
-
+                    format_kv("Load best model from", self.best_checkpoint_path)
                 )
             if os.path.exists(self.best_checkpoint_path):
                 self.load_model(
@@ -1040,14 +1003,18 @@
         num_batches = 0
         y_true_list = []
         y_pred_list = []
+        collect_metrics = getattr(self, "collect_train_metrics", True)
+        max_samples = getattr(self, "max_metrics_samples", None)
+        collected_samples = 0
+        metrics_capped = False
 
         user_ids_list = [] if self.needs_user_ids else None
         tqdm_disable = not self.is_main_process
         if self.steps_per_epoch is not None:
             batch_iter = enumerate(
-
+                progress(
                     train_loader,
-
+                    description=f"Epoch {self.epoch_index + 1}",
                     total=self.steps_per_epoch,
                     disable=tqdm_disable,
                 )
@@ -1055,7 +1022,11 @@
         else:
             desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
             batch_iter = enumerate(
-
+                progress(
+                    train_loader,
+                    description=desc,
+                    disable=tqdm_disable,
+                )
             )
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
@@ -1074,16 +1045,34 @@
             self.optimizer_fn.step()
             accumulated_loss += loss.item()
 
-            if
-
-
-
-
-            )
-            if
-
-
-
+            if (
+                collect_metrics
+                and y_true is not None
+                and isinstance(y_pred, torch.Tensor)
+            ):
+                batch_size = int(y_true.size(0))
+                if max_samples is not None and collected_samples >= max_samples:
+                    collect_metrics = False
+                    metrics_capped = True
+                else:
+                    take_count = batch_size
+                    if (
+                        max_samples is not None
+                        and collected_samples + batch_size > max_samples
+                    ):
+                        take_count = max_samples - collected_samples
+                        metrics_capped = True
+                        collect_metrics = False
+                    if take_count > 0:
+                        y_true_list.append(y_true[:take_count].detach().cpu().numpy())
+                        y_pred_list.append(y_pred[:take_count].detach().cpu().numpy())
+                        if self.needs_user_ids and user_ids_list is not None:
+                            batch_user_id = get_user_ids(
+                                data=batch_dict, id_columns=self.id_columns
+                            )
+                            if batch_user_id is not None:
+                                user_ids_list.append(batch_user_id[:take_count])
+                        collected_samples += take_count
             num_batches += 1
             if self.distributed and dist.is_available() and dist.is_initialized():
                 loss_tensor = torch.tensor(
@@ -1109,6 +1098,14 @@
             gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
         )
 
+        if metrics_capped and self.is_main_process:
+            logging.info(
+                colorize(
+                    f"[Training Info] Training metrics capped at {max_samples} samples to limit memory usage.",
+                    color="yellow",
+                )
+            )
+
         if (
             y_true_all is not None
             and y_pred_all is not None
@@ -1247,11 +1244,15 @@
                     )
                     if batch_user_id is not None:
                         collected_user_ids.append(batch_user_id)
-        if self.is_main_process:
-
-
-
-
+        # if self.is_main_process:
+        #     logging.info("")
+        #     logging.info(
+        #         colorize(
+        #             format_kv(
+        #                 "Evaluation batches processed", batch_count
+        #             ),
+        #         )
+        #     )
         y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
         y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
 
@@ -1290,10 +1291,15 @@
                 )
             )
             return {}
-        if self.is_main_process:
-
-
-
+        # if self.is_main_process:
+        #     logging.info(
+        #         colorize(
+        #             format_kv(
+        #                 "Evaluation samples", y_true_all.shape[0]
+        #             ),
+        #         )
+        #     )
+        logging.info("")
         metrics_dict = evaluate_metrics(
             y_true=y_true_all,
             y_pred=y_pred_all,
@@ -1385,7 +1391,7 @@
             id_arrays = None
 
         with torch.no_grad():
-            for batch_data in
+            for batch_data in progress(data_loader, description="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self(X_input)
@@ -1406,10 +1412,9 @@
                         if id_np.ndim == 1
                         else id_np
                     )
-
-
-
-            y_pred_all = np.array([])
+        y_pred_all = (
+            np.concatenate(y_pred_list, axis=0) if y_pred_list else np.array([])
+        )
 
         if y_pred_all.ndim == 1:
             y_pred_all = y_pred_all.reshape(-1, 1)
@@ -1417,22 +1422,22 @@
             num_outputs = len(self.target_columns) if self.target_columns else 1
             y_pred_all = y_pred_all.reshape(0, num_outputs)
         num_outputs = y_pred_all.shape[1]
-        pred_columns: list[str] =
-
-
-            pred_columns.append(f"{name}")
+        pred_columns: list[str] = (
+            list(self.target_columns[:num_outputs]) if self.target_columns else []
+        )
         while len(pred_columns) < num_outputs:
             pred_columns.append(f"pred_{len(pred_columns)}")
         if include_ids and predict_id_columns:
-            id_arrays = {
-
-
-                concatenated = np.concatenate(
+            id_arrays = {
+                id_name: (
+                    np.concatenate(
                         [p.reshape(p.shape[0], -1) for p in pieces], axis=0
-                )
-
-
-
+                    ).reshape(-1)
+                    if pieces
+                    else np.array([], dtype=np.int64)
+                )
+                for id_name, pieces in id_buffers.items()
+            }
             if return_dataframe:
                 id_df = pd.DataFrame(id_arrays)
                 pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
@@ -1533,7 +1538,7 @@
         collected_frames = []  # only used when return_dataframe is True
 
         with torch.no_grad():
-            for batch_data in
+            for batch_data in progress(data_loader, description="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
@@ -1544,25 +1549,24 @@
                     y_pred_np = y_pred_np.reshape(-1, 1)
                 if pred_columns is None:
                     num_outputs = y_pred_np.shape[1]
-                    pred_columns =
-
-
-
+                    pred_columns = (
+                        list(self.target_columns[:num_outputs])
+                        if self.target_columns
+                        else []
+                    )
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")
 
-
-
-
-
-
-
-
-
-
-
-                        )
-                        id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
+                ids = batch_dict.get("ids") if include_ids and id_columns else None
+                id_arrays_batch = {
+                    id_name: (
+                        ids[id_name].detach().cpu().numpy()
+                        if isinstance(ids[id_name], torch.Tensor)
+                        else np.asarray(ids[id_name])
+                    ).reshape(-1)
+                    for id_name in (id_columns or [])
+                    if ids and id_name in ids
+                }
 
                 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
                 if id_arrays_batch:
@@ -1764,13 +1768,13 @@
     def summary(self):
         logger = logging.getLogger()
 
-        logger.info(
+        logger.info("")
         logger.info(
             colorize(
                 f"Model Summary: {self.model_name}", color="bright_blue", bold=True
            )
         )
-        logger.info(
+        logger.info("")
 
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
@@ -1892,6 +1896,7 @@
         logger.info("Other Settings:")
         logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
         logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
+        logger.info(f"  Max Metrics Samples: {self.max_metrics_samples}")
         logger.info(f"  Session ID: {self.session_id}")
         logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")
@@ -2085,10 +2090,10 @@ class BaseMatchModel(BaseModel):
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
         elif isinstance(effective_loss, (str,)):
-            if
-
-
-
+            if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+                "bce",
+                "binary_crossentropy",
+            }:
                 effective_loss = default_loss_by_mode[self.training_mode]
         elif isinstance(effective_loss, list):
             if not effective_loss:
@@ -2115,7 +2120,9 @@
             callbacks=callbacks,
         )
 
-    def inbatch_logits(
+    def inbatch_logits(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
         if self.similarity_metric == "dot":
             logits = torch.matmul(user_emb, item_emb.t())
         elif self.similarity_metric == "cosine":
@@ -2216,7 +2223,9 @@
 
         eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
         pos_logits = logits.diag()  # [B]
-        neg_logits = logits.masked_select(~eye).view(
+        neg_logits = logits.masked_select(~eye).view(
+            batch_size, batch_size - 1
+        )  # [B, B-1]
 
         loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
         if isinstance(loss_fn, SampledSoftmaxLoss):
@@ -2281,7 +2290,7 @@
 
         embeddings_list = []
        with torch.no_grad():
-            for batch_data in
+            for batch_data in progress(data_loader, description="Encoding users"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
                 user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
@@ -2301,7 +2310,7 @@
 
         embeddings_list = []
         with torch.no_grad():
-            for batch_data in
+            for batch_data in progress(data_loader, description="Encoding items"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
                 item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)