nextrec 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +337 -328
  8. nextrec/cli.py +25 -4
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +12 -13
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +2 -2
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/ffm.py +0 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +0 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +12 -10
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/METADATA +8 -7
  54. nextrec-0.4.9.dist-info/RECORD +70 -0
  55. nextrec/utils/device.py +0 -78
  56. nextrec/utils/distributed.py +0 -141
  57. nextrec/utils/file.py +0 -92
  58. nextrec/utils/initializer.py +0 -79
  59. nextrec/utils/optimizer.py +0 -75
  60. nextrec/utils/tensor.py +0 -72
  61. nextrec-0.4.7.dist-info/RECORD +0 -70
  62. /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
  63. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
  64. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
  65. {nextrec-0.4.7.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,52 +2,52 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 18/12/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
9
+ import getpass
10
+ import logging
9
11
  import os
10
- import tqdm
11
12
  import pickle
12
- import logging
13
- import getpass
14
13
  import socket
14
+ from pathlib import Path
15
+ from typing import Any, Literal, Union
16
+
15
17
  import numpy as np
16
18
  import pandas as pd
17
19
  import torch
20
+ import torch.distributed as dist
18
21
  import torch.nn as nn
19
22
  import torch.nn.functional as F
20
- import torch.distributed as dist
21
-
22
- from pathlib import Path
23
- from typing import Union, Literal, Any
23
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
24
  from torch.utils.data import DataLoader
25
25
  from torch.utils.data.distributed import DistributedSampler
26
- from torch.nn.parallel import DistributedDataParallel as DDP
27
26
 
27
+ from nextrec import __version__
28
28
  from nextrec.basic.callback import (
29
- EarlyStopper,
30
- CallbackList,
31
29
  Callback,
30
+ CallbackList,
32
31
  CheckpointSaver,
32
+ EarlyStopper,
33
33
  LearningRateScheduler,
34
34
  )
35
35
  from nextrec.basic.features import (
36
36
  DenseFeature,
37
- SparseFeature,
38
- SequenceFeature,
39
37
  FeatureSet,
38
+ SequenceFeature,
39
+ SparseFeature,
40
40
  )
41
- from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
42
-
43
- from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
44
- from nextrec.basic.session import resolve_save_path, create_session
45
- from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
46
-
47
- from nextrec.data.dataloader import build_tensors_from_data
48
- from nextrec.data.batch_utils import collate_fn, batch_to_dict
41
+ from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
42
+ from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
43
+ from nextrec.basic.session import create_session, resolve_save_path
44
+ from nextrec.data.batch_utils import batch_to_dict, collate_fn
49
45
  from nextrec.data.data_processing import get_column_data, get_user_ids
50
-
46
+ from nextrec.data.dataloader import (
47
+ RecDataLoader,
48
+ TensorDictDataset,
49
+ build_tensors_from_data,
50
+ )
51
51
  from nextrec.loss import (
52
52
  BPRLoss,
53
53
  HingeLoss,
@@ -55,17 +55,17 @@ from nextrec.loss import (
55
55
  SampledSoftmaxLoss,
56
56
  TripletLoss,
57
57
  get_loss_fn,
58
- get_loss_kwargs,
59
58
  )
60
- from nextrec.utils.tensor import to_tensor
61
- from nextrec.utils.device import configure_device
62
- from nextrec.utils.optimizer import get_optimizer, get_scheduler
63
- from nextrec.utils.distributed import (
59
+ from nextrec.utils.console import display_metrics_table, progress
60
+ from nextrec.utils.torch_utils import (
61
+ add_distributed_sampler,
62
+ configure_device,
64
63
  gather_numpy,
64
+ get_optimizer,
65
+ get_scheduler,
65
66
  init_process_group,
66
- add_distributed_sampler,
67
+ to_tensor,
67
68
  )
68
- from nextrec import __version__
69
69
 
70
70
 
71
71
  class BaseModel(FeatureSet, nn.Module):
@@ -91,6 +91,7 @@ class BaseModel(FeatureSet, nn.Module):
91
91
  dense_l2_reg: float = 0.0,
92
92
  device: str = "cpu",
93
93
  early_stop_patience: int = 20,
94
+ max_metrics_samples: int | None = 200000,
94
95
  session_id: str | None = None,
95
96
  callbacks: list[Callback] | None = None,
96
97
  distributed: bool = False,
@@ -117,6 +118,7 @@ class BaseModel(FeatureSet, nn.Module):
117
118
 
118
119
  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
119
120
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
121
+ max_metrics_samples: Max samples to keep for training metrics. None disables limit.
120
122
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
121
123
  callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
122
124
 
@@ -146,7 +148,7 @@ class BaseModel(FeatureSet, nn.Module):
146
148
  self.session_path = self.session.root # pwd/session_id, path for this session
147
149
  self.checkpoint_path = os.path.join(
148
150
  self.session_path, self.model_name + "_checkpoint.pt"
149
- ) # example: pwd/session_id/DeepFM_checkpoint.pt
151
+ ) # e.g., pwd/session_id/DeepFM_checkpoint.pt
150
152
  self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
151
153
  self.features_config_path = os.path.join(
152
154
  self.session_path, "features_config.pkl"
@@ -167,6 +169,9 @@ class BaseModel(FeatureSet, nn.Module):
167
169
  self.loss_weight = None
168
170
 
169
171
  self.early_stop_patience = early_stop_patience
172
+ self.max_metrics_samples = (
173
+ None if max_metrics_samples is None else int(max_metrics_samples)
174
+ )
170
175
  self.max_gradient_norm = 1.0
171
176
  self.logger_initialized = False
172
177
  self.training_logger = None
@@ -182,15 +187,15 @@ class BaseModel(FeatureSet, nn.Module):
182
187
  include_modules = include_modules or []
183
188
  embedding_layer = getattr(self, embedding_attr, None)
184
189
  embed_dict = getattr(embedding_layer, "embed_dict", None)
185
- embedding_params: list[torch.Tensor] = []
186
190
  if embed_dict is not None:
187
- embedding_params.extend(
188
- embed.weight for embed in embed_dict.values() if hasattr(embed, "weight")
189
- )
191
+ embedding_params = [
192
+ embed.weight
193
+ for embed in embed_dict.values()
194
+ if hasattr(embed, "weight")
195
+ ]
190
196
  else:
191
197
  weight = getattr(embedding_layer, "weight", None)
192
- if isinstance(weight, torch.Tensor):
193
- embedding_params.append(weight)
198
+ embedding_params = [weight] if isinstance(weight, torch.Tensor) else []
194
199
 
195
200
  existing_embedding_ids = {id(param) for param in self.embedding_params}
196
201
  for param in embedding_params:
@@ -212,10 +217,12 @@ class BaseModel(FeatureSet, nn.Module):
212
217
  module is self
213
218
  or embedding_attr in name
214
219
  or isinstance(module, skip_types)
215
- or (include_modules and not any(inc in name for inc in include_modules))
216
- or any(exc in name for exc in exclude_modules)
217
220
  ):
218
221
  continue
222
+ if include_modules and not any(inc in name for inc in include_modules):
223
+ continue
224
+ if exclude_modules and any(exc in name for exc in exclude_modules):
225
+ continue
219
226
  if isinstance(module, nn.Linear):
220
227
  if id(module.weight) not in existing_reg_ids:
221
228
  self.regularization_weights.append(module.weight)
@@ -317,22 +324,20 @@ class BaseModel(FeatureSet, nn.Module):
317
324
  raise ValueError(
318
325
  f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
319
326
  )
320
- if not isinstance(train_data, (pd.DataFrame, dict)):
321
- raise TypeError(
322
- f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
323
- )
324
327
  if isinstance(train_data, pd.DataFrame):
325
328
  total_length = len(train_data)
326
- else:
327
- sample_key = next(
328
- iter(train_data)
329
- ) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
330
- total_length = len(train_data[sample_key]) # len(train_data['user_id'])
329
+ elif isinstance(train_data, dict):
330
+ sample_key = next(iter(train_data))
331
+ total_length = len(train_data[sample_key])
331
332
  for k, v in train_data.items():
332
333
  if len(v) != total_length:
333
334
  raise ValueError(
334
335
  f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
335
336
  )
337
+ else:
338
+ raise TypeError(
339
+ f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
340
+ )
336
341
  rng = np.random.default_rng(42)
337
342
  indices = rng.permutation(total_length)
338
343
  split_idx = int(total_length * (1 - validation_split))
@@ -342,12 +347,12 @@ class BaseModel(FeatureSet, nn.Module):
342
347
  train_split = train_data.iloc[train_indices].reset_index(drop=True)
343
348
  valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
344
349
  else:
345
- train_split = {}
346
- valid_split = {}
347
- for key, value in train_data.items():
348
- arr = np.asarray(value)
349
- train_split[key] = arr[train_indices]
350
- valid_split[key] = arr[valid_indices]
350
+ train_split = {
351
+ k: np.asarray(v)[train_indices] for k, v in train_data.items()
352
+ }
353
+ valid_split = {
354
+ k: np.asarray(v)[valid_indices] for k, v in train_data.items()
355
+ }
351
356
  train_loader = self.prepare_data_loader(
352
357
  train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
353
358
  )
@@ -402,11 +407,11 @@ class BaseModel(FeatureSet, nn.Module):
402
407
  )
403
408
 
404
409
  scheduler_params = scheduler_params or {}
405
- if isinstance(scheduler, str):
406
- self.scheduler_name = scheduler
407
- elif scheduler is None:
410
+ if scheduler is None:
408
411
  self.scheduler_name = None
409
- else: # for custom scheduler instance, need to provide class name for logging
412
+ elif isinstance(scheduler, str):
413
+ self.scheduler_name = scheduler
414
+ else:
410
415
  self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
411
416
  self.scheduler_params = scheduler_params
412
417
  self.scheduler_fn = (
@@ -417,25 +422,23 @@ class BaseModel(FeatureSet, nn.Module):
417
422
 
418
423
  self.loss_config = loss
419
424
  self.loss_params = loss_params or {}
420
- self.loss_fn = []
421
- if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
425
+ if isinstance(loss, list):
422
426
  if len(loss) != self.nums_task:
423
427
  raise ValueError(
424
428
  f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
425
429
  )
426
- loss_list = [loss[i] for i in range(self.nums_task)]
427
- else: # for example: 'bce' -> ['bce', 'bce']
430
+ loss_list = list(loss)
431
+ else:
428
432
  loss_list = [loss] * self.nums_task
429
-
430
433
  if isinstance(self.loss_params, dict):
431
- params_list = [self.loss_params] * self.nums_task
432
- else: # list[dict]
433
- params_list = [
434
+ loss_params_list = [self.loss_params] * self.nums_task
435
+ else:
436
+ loss_params_list = [
434
437
  self.loss_params[i] if i < len(self.loss_params) else {}
435
438
  for i in range(self.nums_task)
436
439
  ]
437
440
  self.loss_fn = [
438
- get_loss_fn(loss=loss_list[i], **params_list[i])
441
+ get_loss_fn(loss=loss_list[i], **loss_params_list[i])
439
442
  for i in range(self.nums_task)
440
443
  ]
441
444
 
@@ -447,10 +450,8 @@ class BaseModel(FeatureSet, nn.Module):
447
450
  raise ValueError(
448
451
  "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
449
452
  )
450
- weight_value = loss_weights[0]
451
- else:
452
- weight_value = loss_weights
453
- self.loss_weights = [float(weight_value)]
453
+ loss_weights = loss_weights[0]
454
+ self.loss_weights = [float(loss_weights)]
454
455
  else:
455
456
  if isinstance(loss_weights, (int, float)):
456
457
  weights = [float(loss_weights)] * self.nums_task
@@ -483,7 +484,9 @@ class BaseModel(FeatureSet, nn.Module):
483
484
  y_true = y_true.view(-1, 1)
484
485
  if y_pred.shape != y_true.shape:
485
486
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
486
- task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
487
+ task_dim = (
488
+ self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
489
+ )
487
490
  if task_dim == 1:
488
491
  loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
489
492
  else:
@@ -494,12 +497,11 @@ class BaseModel(FeatureSet, nn.Module):
494
497
  # multi-task
495
498
  if y_pred.shape != y_true.shape:
496
499
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
497
- if hasattr(
498
- self, "prediction_layer"
499
- ): # we need to use registered task_slices for multi-task and multi-class
500
- slices = self.prediction_layer.task_slices # type: ignore
501
- else:
502
- slices = [(i, i + 1) for i in range(self.nums_task)]
500
+ slices = (
501
+ self.prediction_layer.task_slices # type: ignore
502
+ if hasattr(self, "prediction_layer")
503
+ else [(i, i + 1) for i in range(self.nums_task)]
504
+ )
503
505
  task_losses = []
504
506
  for i, (start, end) in enumerate(slices): # type: ignore
505
507
  y_pred_i = y_pred[:, start:end]
@@ -519,6 +521,9 @@ class BaseModel(FeatureSet, nn.Module):
519
521
  sampler=None,
520
522
  return_dataset: bool = False,
521
523
  ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
524
+ """
525
+ Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
526
+ """
522
527
  if isinstance(data, DataLoader):
523
528
  return (data, None) if return_dataset else data
524
529
  tensors = build_tensors_from_data(
@@ -625,54 +630,55 @@ class BaseModel(FeatureSet, nn.Module):
625
630
  )
626
631
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
627
632
 
628
- # Setup default callbacks if none exist
629
- if len(self.callbacks.callbacks) == 0:
630
- if self.nums_task == 1:
631
- monitor_metric = f"val_{self.metrics[0]}"
632
- else:
633
- monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
634
-
635
- if self.early_stop_patience > 0:
636
- self.callbacks.append(
637
- EarlyStopper(
638
- monitor=monitor_metric,
639
- patience=self.early_stop_patience,
640
- mode=self.best_metrics_mode,
641
- restore_best_weights=not self.distributed,
642
- verbose=1 if self.is_main_process else 0,
643
- )
633
+ # Setup default callbacks if missing
634
+ if self.nums_task == 1:
635
+ monitor_metric = f"val_{self.metrics[0]}"
636
+ else:
637
+ monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
638
+
639
+ existing_callbacks = self.callbacks.callbacks
640
+ has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
641
+ has_checkpoint = any(
642
+ isinstance(cb, CheckpointSaver) for cb in existing_callbacks
643
+ )
644
+ has_lr_scheduler = any(
645
+ isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
646
+ )
647
+
648
+ if self.early_stop_patience > 0 and not has_early_stop:
649
+ self.callbacks.append(
650
+ EarlyStopper(
651
+ monitor=monitor_metric,
652
+ patience=self.early_stop_patience,
653
+ mode=self.best_metrics_mode,
654
+ restore_best_weights=not self.distributed,
655
+ verbose=1 if self.is_main_process else 0,
644
656
  )
657
+ )
645
658
 
646
- if self.is_main_process:
647
- self.callbacks.append(
648
- CheckpointSaver(
649
- save_path=self.best_path,
650
- monitor=monitor_metric,
651
- mode=self.best_metrics_mode,
652
- save_best_only=True,
653
- verbose=1,
654
- )
659
+ if self.is_main_process and not has_checkpoint:
660
+ self.callbacks.append(
661
+ CheckpointSaver(
662
+ best_path=self.best_path,
663
+ checkpoint_path=self.checkpoint_path,
664
+ monitor=monitor_metric,
665
+ mode=self.best_metrics_mode,
666
+ save_best_only=True,
667
+ verbose=1,
655
668
  )
669
+ )
656
670
 
657
- if self.scheduler_fn is not None:
658
- self.callbacks.append(
659
- LearningRateScheduler(
660
- scheduler=self.scheduler_fn,
661
- verbose=1 if self.is_main_process else 0,
662
- )
671
+ if self.scheduler_fn is not None and not has_lr_scheduler:
672
+ self.callbacks.append(
673
+ LearningRateScheduler(
674
+ scheduler=self.scheduler_fn,
675
+ verbose=1 if self.is_main_process else 0,
663
676
  )
677
+ )
664
678
 
665
679
  self.callbacks.set_model(self)
666
680
  self.callbacks.set_params(
667
- {
668
- "epochs": epochs,
669
- "batch_size": batch_size,
670
- "metrics": self.metrics,
671
- }
672
- )
673
-
674
- self.early_stopper = EarlyStopper(
675
- patience=self.early_stop_patience, mode=self.best_metrics_mode
681
+ {"epochs": epochs, "batch_size": batch_size, "metrics": self.metrics}
676
682
  )
677
683
  self.best_metric = (
678
684
  float("-inf") if self.best_metrics_mode == "max" else float("inf")
@@ -684,6 +690,12 @@ class BaseModel(FeatureSet, nn.Module):
684
690
  self.epoch_index = 0
685
691
  self.stop_training = False
686
692
  self.best_checkpoint_path = self.best_path
693
+ use_ddp_sampler = (
694
+ auto_distributed_sampler
695
+ and self.distributed
696
+ and dist.is_available()
697
+ and dist.is_initialized()
698
+ )
687
699
 
688
700
  if not auto_distributed_sampler and self.distributed and self.is_main_process:
689
701
  logging.info(
@@ -696,12 +708,7 @@ class BaseModel(FeatureSet, nn.Module):
696
708
  train_sampler: DistributedSampler | None = None
697
709
  if validation_split is not None and valid_data is None:
698
710
  train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
699
- if (
700
- auto_distributed_sampler
701
- and self.distributed
702
- and dist.is_available()
703
- and dist.is_initialized()
704
- ):
711
+ if use_ddp_sampler:
705
712
  base_dataset = getattr(train_loader, "dataset", None)
706
713
  if base_dataset is not None and not isinstance(
707
714
  getattr(train_loader, "sampler", None), DistributedSampler
@@ -724,7 +731,7 @@ class BaseModel(FeatureSet, nn.Module):
724
731
  )
725
732
  else:
726
733
  if isinstance(train_data, DataLoader):
727
- if auto_distributed_sampler and self.distributed:
734
+ if use_ddp_sampler:
728
735
  train_loader, train_sampler = add_distributed_sampler(
729
736
  train_data,
730
737
  distributed=self.distributed,
@@ -739,16 +746,18 @@ class BaseModel(FeatureSet, nn.Module):
739
746
  else:
740
747
  train_loader = train_data
741
748
  else:
742
- result = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)
743
- assert isinstance(result, tuple), "Expected tuple from prepare_data_loader with return_dataset=True"
749
+ result = self.prepare_data_loader(
750
+ train_data,
751
+ batch_size=batch_size,
752
+ shuffle=shuffle,
753
+ num_workers=num_workers,
754
+ return_dataset=True,
755
+ )
756
+ assert isinstance(
757
+ result, tuple
758
+ ), "[BaseModel-fit Error] Expected tuple from prepare_data_loader with return_dataset=True, but got something else."
744
759
  loader, dataset = result
745
- if (
746
- auto_distributed_sampler
747
- and self.distributed
748
- and dataset is not None
749
- and dist.is_available()
750
- and dist.is_initialized()
751
- ):
760
+ if use_ddp_sampler and dataset is not None:
752
761
  train_sampler = DistributedSampler(
753
762
  dataset,
754
763
  num_replicas=self.world_size,
@@ -793,34 +802,42 @@ class BaseModel(FeatureSet, nn.Module):
793
802
  except TypeError: # streaming data loader does not supported len()
794
803
  self.steps_per_epoch = None
795
804
  is_streaming = True
805
+ self.collect_train_metrics = not is_streaming
806
+ if is_streaming and self.is_main_process:
807
+ logging.info(
808
+ colorize(
809
+ "[Training Info] Streaming mode detected; training metrics collection is disabled to avoid memory growth.",
810
+ color="yellow",
811
+ )
812
+ )
796
813
 
797
814
  if self.is_main_process:
798
815
  self.summary()
799
816
  logging.info("")
800
- if self.training_logger and self.training_logger.enable_tensorboard:
801
- tb_dir = self.training_logger.tensorboard_logdir
802
- if tb_dir:
803
- user = getpass.getuser()
804
- host = socket.gethostname()
805
- tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
806
- ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
807
- logging.info(
808
- colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
809
- )
810
- logging.info(colorize("To view logs, run:", color="cyan"))
811
- logging.info(colorize(f" {tb_cmd}", color="cyan"))
812
- logging.info(colorize("Then SSH port forward:", color="cyan"))
813
- logging.info(colorize(f" {ssh_hint}", color="cyan"))
817
+ tb_dir = (
818
+ self.training_logger.tensorboard_logdir
819
+ if self.training_logger and self.training_logger.enable_tensorboard
820
+ else None
821
+ )
822
+ if tb_dir:
823
+ user = getpass.getuser()
824
+ host = socket.gethostname()
825
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
826
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
827
+ logging.info(
828
+ colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
829
+ )
830
+ logging.info(colorize("To view logs, run:", color="cyan"))
831
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
832
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
833
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
814
834
 
815
835
  logging.info("")
816
- logging.info(colorize("=" * 80, bold=True))
817
- if is_streaming:
818
- logging.info(colorize("Start streaming training", bold=True))
819
- else:
820
- logging.info(colorize("Start training", bold=True))
821
- logging.info(colorize("=" * 80, bold=True))
836
+ logging.info(colorize("[Training]", color="bright_blue", bold=True))
837
+ logging.info(colorize("-" * 80, color="bright_blue"))
838
+ logging.info(format_kv("Start training", f"{epochs} epochs"))
839
+ logging.info(format_kv("Model device", self.device))
822
840
  logging.info("")
823
- logging.info(colorize(f"Model device: {self.device}", bold=True))
824
841
 
825
842
  self.callbacks.on_train_begin()
826
843
 
@@ -843,126 +860,77 @@ class BaseModel(FeatureSet, nn.Module):
843
860
  and isinstance(train_loader.sampler, DistributedSampler)
844
861
  ):
845
862
  train_loader.sampler.set_epoch(epoch)
846
- # Type guard: ensure train_loader is DataLoader for train_epoch
863
+
847
864
  if not isinstance(train_loader, DataLoader):
848
- raise TypeError(f"Expected DataLoader for training, got {type(train_loader)}")
865
+ raise TypeError(
866
+ f"Expected DataLoader for training, got {type(train_loader)}"
867
+ )
849
868
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
850
- if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
869
+ if isinstance(
870
+ train_result, tuple
871
+ ): # [avg_loss, metrics_dict], e.g., (0.5, {'auc': 0.75, 'logloss': 0.45})
851
872
  train_loss, train_metrics = train_result
852
873
  else:
853
874
  train_loss = train_result
854
875
  train_metrics = None
855
876
 
856
- train_log_payload: dict[str, float] = {}
857
- # handle logging for single-task and multi-task
858
- if self.nums_task == 1:
859
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
860
- if train_metrics:
861
- metrics_str = ", ".join(
862
- [f"{k}={v:.4f}" for k, v in train_metrics.items()]
863
- )
864
- log_str += f", {metrics_str}"
865
- if self.is_main_process:
866
- logging.info(colorize(log_str))
867
- train_log_payload["loss"] = float(train_loss)
868
- if train_metrics:
869
- train_log_payload.update(train_metrics)
870
- else:
871
- total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
872
- log_str = (
873
- f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
877
+ logging.info("")
878
+ train_log_payload = {
879
+ "loss": (
880
+ float(np.sum(train_loss))
881
+ if isinstance(train_loss, np.ndarray)
882
+ else float(train_loss)
874
883
  )
875
- if train_metrics:
876
- # group metrics by task
877
- task_metrics = {}
878
- for metric_key, metric_value in train_metrics.items():
879
- for target_name in self.target_columns:
880
- if metric_key.endswith(f"_{target_name}"):
881
- if target_name not in task_metrics:
882
- task_metrics[target_name] = {}
883
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
884
- task_metrics[target_name][metric_name] = metric_value
885
- break
886
- if task_metrics:
887
- task_metric_strs = []
888
- for target_name in self.target_columns:
889
- if target_name in task_metrics:
890
- metrics_str = ", ".join(
891
- [
892
- f"{k}={v:.4f}"
893
- for k, v in task_metrics[target_name].items()
894
- ]
895
- )
896
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
897
- log_str += ", " + ", ".join(task_metric_strs)
898
- if self.is_main_process:
899
- logging.info(colorize(log_str))
900
- train_log_payload["loss"] = float(total_loss_val)
901
- if train_metrics:
902
- train_log_payload.update(train_metrics)
884
+ }
885
+ if train_metrics:
886
+ train_log_payload.update(train_metrics)
887
+
888
+ display_metrics_table(
889
+ epoch=epoch + 1,
890
+ epochs=epochs,
891
+ split="Train",
892
+ loss=train_loss,
893
+ metrics=train_metrics,
894
+ target_names=self.target_columns,
895
+ base_metrics=(
896
+ self.metrics
897
+ if isinstance(getattr(self, "metrics", None), list)
898
+ else None
899
+ ),
900
+ is_main_process=self.is_main_process,
901
+ colorize=lambda s: colorize(s),
902
+ )
903
903
  if self.training_logger:
904
904
  self.training_logger.log_metrics(
905
905
  train_log_payload, step=epoch + 1, split="train"
906
906
  )
907
907
  if valid_loader is not None:
908
- # Call on_validation_begin
909
908
  self.callbacks.on_validation_begin()
910
-
911
- # pass user_ids only if needed for GAUC metric
912
909
  val_metrics = self.evaluate(
913
910
  valid_loader,
914
911
  user_ids=valid_user_ids if self.needs_user_ids else None,
915
912
  num_workers=num_workers,
916
- ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
917
- if self.nums_task == 1:
918
- metrics_str = ", ".join(
919
- [f"{k}={v:.4f}" for k, v in val_metrics.items()]
920
- )
921
- if self.is_main_process:
922
- logging.info(
923
- colorize(
924
- f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
925
- color="cyan",
926
- )
927
- )
928
- else:
929
- # multi task metrics
930
- task_metrics = {}
931
- for metric_key, metric_value in val_metrics.items():
932
- for target_name in self.target_columns:
933
- if metric_key.endswith(f"_{target_name}"):
934
- if target_name not in task_metrics:
935
- task_metrics[target_name] = {}
936
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
937
- task_metrics[target_name][metric_name] = metric_value
938
- break
939
- task_metric_strs = []
940
- for target_name in self.target_columns:
941
- if target_name in task_metrics:
942
- metrics_str = ", ".join(
943
- [
944
- f"{k}={v:.4f}"
945
- for k, v in task_metrics[target_name].items()
946
- ]
947
- )
948
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
949
- if self.is_main_process:
950
- logging.info(
951
- colorize(
952
- f" Epoch {epoch + 1}/{epochs} - Valid: "
953
- + ", ".join(task_metric_strs),
954
- color="cyan",
955
- )
956
- )
957
-
958
- # Call on_validation_end
913
+ )
914
+ display_metrics_table(
915
+ epoch=epoch + 1,
916
+ epochs=epochs,
917
+ split="Valid",
918
+ loss=None,
919
+ metrics=val_metrics,
920
+ target_names=self.target_columns,
921
+ base_metrics=(
922
+ self.metrics
923
+ if isinstance(getattr(self, "metrics", None), list)
924
+ else None
925
+ ),
926
+ is_main_process=self.is_main_process,
927
+ colorize=lambda s: colorize(" " + s, color="cyan"),
928
+ )
959
929
  self.callbacks.on_validation_end()
960
930
  if val_metrics and self.training_logger:
961
931
  self.training_logger.log_metrics(
962
932
  val_metrics, step=epoch + 1, split="valid"
963
933
  )
964
-
965
- # Handle empty validation metrics
966
934
  if not val_metrics:
967
935
  if self.is_main_process:
968
936
  logging.info(
@@ -972,15 +940,10 @@ class BaseModel(FeatureSet, nn.Module):
972
940
  )
973
941
  )
974
942
  continue
975
-
976
- # Prepare epoch logs for callbacks
977
943
  epoch_logs = {**train_log_payload}
978
- if val_metrics:
979
- # Add val_ prefix to validation metrics
980
- for k, v in val_metrics.items():
981
- epoch_logs[f"val_{k}"] = v
944
+ for k, v in val_metrics.items():
945
+ epoch_logs[f"val_{k}"] = v
982
946
  else:
983
- # No validation data
984
947
  epoch_logs = {**train_log_payload}
985
948
  if self.is_main_process:
986
949
  self.save_model(
@@ -1007,13 +970,13 @@ class BaseModel(FeatureSet, nn.Module):
1007
970
  if self.distributed and dist.is_available() and dist.is_initialized():
1008
971
  dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
1009
972
  if self.is_main_process:
1010
- logging.info(" ")
1011
- logging.info(colorize("Training finished.", bold=True))
1012
- logging.info(" ")
973
+ logging.info("")
974
+ logging.info(colorize("Training finished.", color="bright_blue", bold=True))
975
+ logging.info("")
1013
976
  if valid_loader is not None:
1014
977
  if self.is_main_process:
1015
978
  logging.info(
1016
- colorize(f"Load best model from: {self.best_checkpoint_path}")
979
+ format_kv("Load best model from", self.best_checkpoint_path)
1017
980
  )
1018
981
  if os.path.exists(self.best_checkpoint_path):
1019
982
  self.load_model(
@@ -1040,14 +1003,18 @@ class BaseModel(FeatureSet, nn.Module):
1040
1003
  num_batches = 0
1041
1004
  y_true_list = []
1042
1005
  y_pred_list = []
1006
+ collect_metrics = getattr(self, "collect_train_metrics", True)
1007
+ max_samples = getattr(self, "max_metrics_samples", None)
1008
+ collected_samples = 0
1009
+ metrics_capped = False
1043
1010
 
1044
1011
  user_ids_list = [] if self.needs_user_ids else None
1045
1012
  tqdm_disable = not self.is_main_process
1046
1013
  if self.steps_per_epoch is not None:
1047
1014
  batch_iter = enumerate(
1048
- tqdm.tqdm(
1015
+ progress(
1049
1016
  train_loader,
1050
- desc=f"Epoch {self.epoch_index + 1}",
1017
+ description=f"Epoch {self.epoch_index + 1}",
1051
1018
  total=self.steps_per_epoch,
1052
1019
  disable=tqdm_disable,
1053
1020
  )
@@ -1055,7 +1022,11 @@ class BaseModel(FeatureSet, nn.Module):
1055
1022
  else:
1056
1023
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
1057
1024
  batch_iter = enumerate(
1058
- tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
1025
+ progress(
1026
+ train_loader,
1027
+ description=desc,
1028
+ disable=tqdm_disable,
1029
+ )
1059
1030
  )
1060
1031
  for batch_index, batch_data in batch_iter:
1061
1032
  batch_dict = batch_to_dict(batch_data)
@@ -1074,16 +1045,34 @@ class BaseModel(FeatureSet, nn.Module):
1074
1045
  self.optimizer_fn.step()
1075
1046
  accumulated_loss += loss.item()
1076
1047
 
1077
- if y_true is not None:
1078
- y_true_list.append(y_true.detach().cpu().numpy())
1079
- if self.needs_user_ids and user_ids_list is not None:
1080
- batch_user_id = get_user_ids(
1081
- data=batch_dict, id_columns=self.id_columns
1082
- )
1083
- if batch_user_id is not None:
1084
- user_ids_list.append(batch_user_id)
1085
- if y_pred is not None and isinstance(y_pred, torch.Tensor):
1086
- y_pred_list.append(y_pred.detach().cpu().numpy())
1048
+ if (
1049
+ collect_metrics
1050
+ and y_true is not None
1051
+ and isinstance(y_pred, torch.Tensor)
1052
+ ):
1053
+ batch_size = int(y_true.size(0))
1054
+ if max_samples is not None and collected_samples >= max_samples:
1055
+ collect_metrics = False
1056
+ metrics_capped = True
1057
+ else:
1058
+ take_count = batch_size
1059
+ if (
1060
+ max_samples is not None
1061
+ and collected_samples + batch_size > max_samples
1062
+ ):
1063
+ take_count = max_samples - collected_samples
1064
+ metrics_capped = True
1065
+ collect_metrics = False
1066
+ if take_count > 0:
1067
+ y_true_list.append(y_true[:take_count].detach().cpu().numpy())
1068
+ y_pred_list.append(y_pred[:take_count].detach().cpu().numpy())
1069
+ if self.needs_user_ids and user_ids_list is not None:
1070
+ batch_user_id = get_user_ids(
1071
+ data=batch_dict, id_columns=self.id_columns
1072
+ )
1073
+ if batch_user_id is not None:
1074
+ user_ids_list.append(batch_user_id[:take_count])
1075
+ collected_samples += take_count
1087
1076
  num_batches += 1
1088
1077
  if self.distributed and dist.is_available() and dist.is_initialized():
1089
1078
  loss_tensor = torch.tensor(
@@ -1109,6 +1098,14 @@ class BaseModel(FeatureSet, nn.Module):
1109
1098
  gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
1110
1099
  )
1111
1100
 
1101
+ if metrics_capped and self.is_main_process:
1102
+ logging.info(
1103
+ colorize(
1104
+ f"[Training Info] Training metrics capped at {max_samples} samples to limit memory usage.",
1105
+ color="yellow",
1106
+ )
1107
+ )
1108
+
1112
1109
  if (
1113
1110
  y_true_all is not None
1114
1111
  and y_pred_all is not None
@@ -1247,11 +1244,15 @@ class BaseModel(FeatureSet, nn.Module):
1247
1244
  )
1248
1245
  if batch_user_id is not None:
1249
1246
  collected_user_ids.append(batch_user_id)
1250
- if self.is_main_process:
1251
- logging.info(" ")
1252
- logging.info(
1253
- colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1254
- )
1247
+ # if self.is_main_process:
1248
+ # logging.info("")
1249
+ # logging.info(
1250
+ # colorize(
1251
+ # format_kv(
1252
+ # "Evaluation batches processed", batch_count
1253
+ # ),
1254
+ # )
1255
+ # )
1255
1256
  y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
1256
1257
  y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
1257
1258
 
@@ -1290,10 +1291,15 @@ class BaseModel(FeatureSet, nn.Module):
1290
1291
  )
1291
1292
  )
1292
1293
  return {}
1293
- if self.is_main_process:
1294
- logging.info(
1295
- colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
1296
- )
1294
+ # if self.is_main_process:
1295
+ # logging.info(
1296
+ # colorize(
1297
+ # format_kv(
1298
+ # "Evaluation samples", y_true_all.shape[0]
1299
+ # ),
1300
+ # )
1301
+ # )
1302
+ logging.info("")
1297
1303
  metrics_dict = evaluate_metrics(
1298
1304
  y_true=y_true_all,
1299
1305
  y_pred=y_pred_all,
@@ -1385,7 +1391,7 @@ class BaseModel(FeatureSet, nn.Module):
1385
1391
  id_arrays = None
1386
1392
 
1387
1393
  with torch.no_grad():
1388
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
1394
+ for batch_data in progress(data_loader, description="Predicting"):
1389
1395
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
1390
1396
  X_input, _ = self.get_input(batch_dict, require_labels=False)
1391
1397
  y_pred = self(X_input)
@@ -1406,10 +1412,9 @@ class BaseModel(FeatureSet, nn.Module):
1406
1412
  if id_np.ndim == 1
1407
1413
  else id_np
1408
1414
  )
1409
- if len(y_pred_list) > 0:
1410
- y_pred_all = np.concatenate(y_pred_list, axis=0)
1411
- else:
1412
- y_pred_all = np.array([])
1415
+ y_pred_all = (
1416
+ np.concatenate(y_pred_list, axis=0) if y_pred_list else np.array([])
1417
+ )
1413
1418
 
1414
1419
  if y_pred_all.ndim == 1:
1415
1420
  y_pred_all = y_pred_all.reshape(-1, 1)
@@ -1417,22 +1422,22 @@ class BaseModel(FeatureSet, nn.Module):
1417
1422
  num_outputs = len(self.target_columns) if self.target_columns else 1
1418
1423
  y_pred_all = y_pred_all.reshape(0, num_outputs)
1419
1424
  num_outputs = y_pred_all.shape[1]
1420
- pred_columns: list[str] = []
1421
- if self.target_columns:
1422
- for name in self.target_columns[:num_outputs]:
1423
- pred_columns.append(f"{name}")
1425
+ pred_columns: list[str] = (
1426
+ list(self.target_columns[:num_outputs]) if self.target_columns else []
1427
+ )
1424
1428
  while len(pred_columns) < num_outputs:
1425
1429
  pred_columns.append(f"pred_{len(pred_columns)}")
1426
1430
  if include_ids and predict_id_columns:
1427
- id_arrays = {}
1428
- for id_name, pieces in id_buffers.items():
1429
- if pieces:
1430
- concatenated = np.concatenate(
1431
+ id_arrays = {
1432
+ id_name: (
1433
+ np.concatenate(
1431
1434
  [p.reshape(p.shape[0], -1) for p in pieces], axis=0
1432
- )
1433
- id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
1434
- else:
1435
- id_arrays[id_name] = np.array([], dtype=np.int64)
1435
+ ).reshape(-1)
1436
+ if pieces
1437
+ else np.array([], dtype=np.int64)
1438
+ )
1439
+ for id_name, pieces in id_buffers.items()
1440
+ }
1436
1441
  if return_dataframe:
1437
1442
  id_df = pd.DataFrame(id_arrays)
1438
1443
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
@@ -1533,7 +1538,7 @@ class BaseModel(FeatureSet, nn.Module):
1533
1538
  collected_frames = [] # only used when return_dataframe is True
1534
1539
 
1535
1540
  with torch.no_grad():
1536
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
1541
+ for batch_data in progress(data_loader, description="Predicting"):
1537
1542
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
1538
1543
  X_input, _ = self.get_input(batch_dict, require_labels=False)
1539
1544
  y_pred = self.forward(X_input)
@@ -1544,25 +1549,24 @@ class BaseModel(FeatureSet, nn.Module):
1544
1549
  y_pred_np = y_pred_np.reshape(-1, 1)
1545
1550
  if pred_columns is None:
1546
1551
  num_outputs = y_pred_np.shape[1]
1547
- pred_columns = []
1548
- if self.target_columns:
1549
- for name in self.target_columns[:num_outputs]:
1550
- pred_columns.append(f"{name}")
1552
+ pred_columns = (
1553
+ list(self.target_columns[:num_outputs])
1554
+ if self.target_columns
1555
+ else []
1556
+ )
1551
1557
  while len(pred_columns) < num_outputs:
1552
1558
  pred_columns.append(f"pred_{len(pred_columns)}")
1553
1559
 
1554
- id_arrays_batch = {}
1555
- if include_ids and id_columns and batch_dict.get("ids"):
1556
- for id_name in id_columns:
1557
- if id_name not in batch_dict["ids"]:
1558
- continue
1559
- id_tensor = batch_dict["ids"][id_name]
1560
- id_np = (
1561
- id_tensor.detach().cpu().numpy()
1562
- if isinstance(id_tensor, torch.Tensor)
1563
- else np.asarray(id_tensor)
1564
- )
1565
- id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
1560
+ ids = batch_dict.get("ids") if include_ids and id_columns else None
1561
+ id_arrays_batch = {
1562
+ id_name: (
1563
+ ids[id_name].detach().cpu().numpy()
1564
+ if isinstance(ids[id_name], torch.Tensor)
1565
+ else np.asarray(ids[id_name])
1566
+ ).reshape(-1)
1567
+ for id_name in (id_columns or [])
1568
+ if ids and id_name in ids
1569
+ }
1566
1570
 
1567
1571
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
1568
1572
  if id_arrays_batch:
@@ -1764,13 +1768,13 @@ class BaseModel(FeatureSet, nn.Module):
1764
1768
  def summary(self):
1765
1769
  logger = logging.getLogger()
1766
1770
 
1767
- logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1771
+ logger.info("")
1768
1772
  logger.info(
1769
1773
  colorize(
1770
1774
  f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1771
1775
  )
1772
1776
  )
1773
- logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1777
+ logger.info("")
1774
1778
 
1775
1779
  logger.info("")
1776
1780
  logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
@@ -1892,6 +1896,7 @@ class BaseModel(FeatureSet, nn.Module):
1892
1896
  logger.info("Other Settings:")
1893
1897
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1894
1898
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1899
+ logger.info(f" Max Metrics Samples: {self.max_metrics_samples}")
1895
1900
  logger.info(f" Session ID: {self.session_id}")
1896
1901
  logger.info(f" Features Config Path: {self.features_config_path}")
1897
1902
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -2085,10 +2090,10 @@ class BaseMatchModel(BaseModel):
2085
2090
  if effective_loss is None:
2086
2091
  effective_loss = default_loss_by_mode[self.training_mode]
2087
2092
  elif isinstance(effective_loss, (str,)):
2088
- if (
2089
- self.training_mode in {"pairwise", "listwise"}
2090
- and effective_loss in {"bce", "binary_crossentropy"}
2091
- ):
2093
+ if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
2094
+ "bce",
2095
+ "binary_crossentropy",
2096
+ }:
2092
2097
  effective_loss = default_loss_by_mode[self.training_mode]
2093
2098
  elif isinstance(effective_loss, list):
2094
2099
  if not effective_loss:
@@ -2115,7 +2120,9 @@ class BaseMatchModel(BaseModel):
2115
2120
  callbacks=callbacks,
2116
2121
  )
2117
2122
 
2118
- def inbatch_logits(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
2123
+ def inbatch_logits(
2124
+ self, user_emb: torch.Tensor, item_emb: torch.Tensor
2125
+ ) -> torch.Tensor:
2119
2126
  if self.similarity_metric == "dot":
2120
2127
  logits = torch.matmul(user_emb, item_emb.t())
2121
2128
  elif self.similarity_metric == "cosine":
@@ -2216,7 +2223,9 @@ class BaseMatchModel(BaseModel):
2216
2223
 
2217
2224
  eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
2218
2225
  pos_logits = logits.diag() # [B]
2219
- neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1) # [B, B-1]
2226
+ neg_logits = logits.masked_select(~eye).view(
2227
+ batch_size, batch_size - 1
2228
+ ) # [B, B-1]
2220
2229
 
2221
2230
  loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
2222
2231
  if isinstance(loss_fn, SampledSoftmaxLoss):
@@ -2281,7 +2290,7 @@ class BaseMatchModel(BaseModel):
2281
2290
 
2282
2291
  embeddings_list = []
2283
2292
  with torch.no_grad():
2284
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
2293
+ for batch_data in progress(data_loader, description="Encoding users"):
2285
2294
  batch_dict = batch_to_dict(batch_data, include_ids=False)
2286
2295
  user_input = self.get_user_features(batch_dict["features"])
2287
2296
  user_emb = self.user_tower(user_input)
@@ -2301,7 +2310,7 @@ class BaseMatchModel(BaseModel):
2301
2310
 
2302
2311
  embeddings_list = []
2303
2312
  with torch.no_grad():
2304
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
2313
+ for batch_data in progress(data_loader, description="Encoding items"):
2305
2314
  batch_dict = batch_to_dict(batch_data, include_ids=False)
2306
2315
  item_input = self.get_item_features(batch_dict["features"])
2307
2316
  item_emb = self.item_tower(item_input)