nextrec 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +312 -318
  8. nextrec/cli.py +5 -10
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +12 -13
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +2 -2
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/ffm.py +0 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +0 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
  54. nextrec-0.4.9.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,52 +2,52 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 18/12/2025
5
+ Checkpoint: edit on 19/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
9
+ import getpass
10
+ import logging
9
11
  import os
10
- import tqdm
11
12
  import pickle
12
- import logging
13
- import getpass
14
13
  import socket
14
+ from pathlib import Path
15
+ from typing import Any, Literal, Union
16
+
15
17
  import numpy as np
16
18
  import pandas as pd
17
19
  import torch
20
+ import torch.distributed as dist
18
21
  import torch.nn as nn
19
22
  import torch.nn.functional as F
20
- import torch.distributed as dist
21
-
22
- from pathlib import Path
23
- from typing import Union, Literal, Any
23
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
24
  from torch.utils.data import DataLoader
25
25
  from torch.utils.data.distributed import DistributedSampler
26
- from torch.nn.parallel import DistributedDataParallel as DDP
27
26
 
27
+ from nextrec import __version__
28
28
  from nextrec.basic.callback import (
29
- EarlyStopper,
30
- CallbackList,
31
29
  Callback,
30
+ CallbackList,
32
31
  CheckpointSaver,
32
+ EarlyStopper,
33
33
  LearningRateScheduler,
34
34
  )
35
35
  from nextrec.basic.features import (
36
36
  DenseFeature,
37
- SparseFeature,
38
- SequenceFeature,
39
37
  FeatureSet,
38
+ SequenceFeature,
39
+ SparseFeature,
40
40
  )
41
- from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
42
-
43
- from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
44
- from nextrec.basic.session import resolve_save_path, create_session
45
- from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
46
-
47
- from nextrec.data.dataloader import build_tensors_from_data
48
- from nextrec.data.batch_utils import collate_fn, batch_to_dict
41
+ from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
42
+ from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
43
+ from nextrec.basic.session import create_session, resolve_save_path
44
+ from nextrec.data.batch_utils import batch_to_dict, collate_fn
49
45
  from nextrec.data.data_processing import get_column_data, get_user_ids
50
-
46
+ from nextrec.data.dataloader import (
47
+ RecDataLoader,
48
+ TensorDictDataset,
49
+ build_tensors_from_data,
50
+ )
51
51
  from nextrec.loss import (
52
52
  BPRLoss,
53
53
  HingeLoss,
@@ -56,15 +56,16 @@ from nextrec.loss import (
56
56
  TripletLoss,
57
57
  get_loss_fn,
58
58
  )
59
- from nextrec.utils.tensor import to_tensor
60
- from nextrec.utils.device import configure_device
61
- from nextrec.utils.optimizer import get_optimizer, get_scheduler
62
- from nextrec.utils.distributed import (
59
+ from nextrec.utils.console import display_metrics_table, progress
60
+ from nextrec.utils.torch_utils import (
61
+ add_distributed_sampler,
62
+ configure_device,
63
63
  gather_numpy,
64
+ get_optimizer,
65
+ get_scheduler,
64
66
  init_process_group,
65
- add_distributed_sampler,
67
+ to_tensor,
66
68
  )
67
- from nextrec import __version__
68
69
 
69
70
 
70
71
  class BaseModel(FeatureSet, nn.Module):
@@ -90,6 +91,7 @@ class BaseModel(FeatureSet, nn.Module):
90
91
  dense_l2_reg: float = 0.0,
91
92
  device: str = "cpu",
92
93
  early_stop_patience: int = 20,
94
+ max_metrics_samples: int | None = 200000,
93
95
  session_id: str | None = None,
94
96
  callbacks: list[Callback] | None = None,
95
97
  distributed: bool = False,
@@ -116,6 +118,7 @@ class BaseModel(FeatureSet, nn.Module):
116
118
 
117
119
  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
118
120
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
121
+ max_metrics_samples: Max samples to keep for training metrics. None disables limit.
119
122
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
120
123
  callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
121
124
 
@@ -145,7 +148,7 @@ class BaseModel(FeatureSet, nn.Module):
145
148
  self.session_path = self.session.root # pwd/session_id, path for this session
146
149
  self.checkpoint_path = os.path.join(
147
150
  self.session_path, self.model_name + "_checkpoint.pt"
148
- ) # example: pwd/session_id/DeepFM_checkpoint.pt
151
+ ) # e.g., pwd/session_id/DeepFM_checkpoint.pt
149
152
  self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
150
153
  self.features_config_path = os.path.join(
151
154
  self.session_path, "features_config.pkl"
@@ -166,6 +169,9 @@ class BaseModel(FeatureSet, nn.Module):
166
169
  self.loss_weight = None
167
170
 
168
171
  self.early_stop_patience = early_stop_patience
172
+ self.max_metrics_samples = (
173
+ None if max_metrics_samples is None else int(max_metrics_samples)
174
+ )
169
175
  self.max_gradient_norm = 1.0
170
176
  self.logger_initialized = False
171
177
  self.training_logger = None
@@ -181,17 +187,15 @@ class BaseModel(FeatureSet, nn.Module):
181
187
  include_modules = include_modules or []
182
188
  embedding_layer = getattr(self, embedding_attr, None)
183
189
  embed_dict = getattr(embedding_layer, "embed_dict", None)
184
- embedding_params: list[torch.Tensor] = []
185
190
  if embed_dict is not None:
186
- embedding_params.extend(
191
+ embedding_params = [
187
192
  embed.weight
188
193
  for embed in embed_dict.values()
189
194
  if hasattr(embed, "weight")
190
- )
195
+ ]
191
196
  else:
192
197
  weight = getattr(embedding_layer, "weight", None)
193
- if isinstance(weight, torch.Tensor):
194
- embedding_params.append(weight)
198
+ embedding_params = [weight] if isinstance(weight, torch.Tensor) else []
195
199
 
196
200
  existing_embedding_ids = {id(param) for param in self.embedding_params}
197
201
  for param in embedding_params:
@@ -213,10 +217,12 @@ class BaseModel(FeatureSet, nn.Module):
213
217
  module is self
214
218
  or embedding_attr in name
215
219
  or isinstance(module, skip_types)
216
- or (include_modules and not any(inc in name for inc in include_modules))
217
- or any(exc in name for exc in exclude_modules)
218
220
  ):
219
221
  continue
222
+ if include_modules and not any(inc in name for inc in include_modules):
223
+ continue
224
+ if exclude_modules and any(exc in name for exc in exclude_modules):
225
+ continue
220
226
  if isinstance(module, nn.Linear):
221
227
  if id(module.weight) not in existing_reg_ids:
222
228
  self.regularization_weights.append(module.weight)
@@ -318,22 +324,20 @@ class BaseModel(FeatureSet, nn.Module):
318
324
  raise ValueError(
319
325
  f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
320
326
  )
321
- if not isinstance(train_data, (pd.DataFrame, dict)):
322
- raise TypeError(
323
- f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
324
- )
325
327
  if isinstance(train_data, pd.DataFrame):
326
328
  total_length = len(train_data)
327
- else:
328
- sample_key = next(
329
- iter(train_data)
330
- ) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
331
- total_length = len(train_data[sample_key]) # len(train_data['user_id'])
329
+ elif isinstance(train_data, dict):
330
+ sample_key = next(iter(train_data))
331
+ total_length = len(train_data[sample_key])
332
332
  for k, v in train_data.items():
333
333
  if len(v) != total_length:
334
334
  raise ValueError(
335
335
  f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
336
336
  )
337
+ else:
338
+ raise TypeError(
339
+ f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
340
+ )
337
341
  rng = np.random.default_rng(42)
338
342
  indices = rng.permutation(total_length)
339
343
  split_idx = int(total_length * (1 - validation_split))
@@ -343,12 +347,12 @@ class BaseModel(FeatureSet, nn.Module):
343
347
  train_split = train_data.iloc[train_indices].reset_index(drop=True)
344
348
  valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
345
349
  else:
346
- train_split = {}
347
- valid_split = {}
348
- for key, value in train_data.items():
349
- arr = np.asarray(value)
350
- train_split[key] = arr[train_indices]
351
- valid_split[key] = arr[valid_indices]
350
+ train_split = {
351
+ k: np.asarray(v)[train_indices] for k, v in train_data.items()
352
+ }
353
+ valid_split = {
354
+ k: np.asarray(v)[valid_indices] for k, v in train_data.items()
355
+ }
352
356
  train_loader = self.prepare_data_loader(
353
357
  train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
354
358
  )
@@ -403,11 +407,11 @@ class BaseModel(FeatureSet, nn.Module):
403
407
  )
404
408
 
405
409
  scheduler_params = scheduler_params or {}
406
- if isinstance(scheduler, str):
407
- self.scheduler_name = scheduler
408
- elif scheduler is None:
410
+ if scheduler is None:
409
411
  self.scheduler_name = None
410
- else: # for custom scheduler instance, need to provide class name for logging
412
+ elif isinstance(scheduler, str):
413
+ self.scheduler_name = scheduler
414
+ else:
411
415
  self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
412
416
  self.scheduler_params = scheduler_params
413
417
  self.scheduler_fn = (
@@ -418,25 +422,23 @@ class BaseModel(FeatureSet, nn.Module):
418
422
 
419
423
  self.loss_config = loss
420
424
  self.loss_params = loss_params or {}
421
- self.loss_fn = []
422
- if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
425
+ if isinstance(loss, list):
423
426
  if len(loss) != self.nums_task:
424
427
  raise ValueError(
425
428
  f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
426
429
  )
427
- loss_list = [loss[i] for i in range(self.nums_task)]
428
- else: # for example: 'bce' -> ['bce', 'bce']
430
+ loss_list = list(loss)
431
+ else:
429
432
  loss_list = [loss] * self.nums_task
430
-
431
433
  if isinstance(self.loss_params, dict):
432
- params_list = [self.loss_params] * self.nums_task
433
- else: # list[dict]
434
- params_list = [
434
+ loss_params_list = [self.loss_params] * self.nums_task
435
+ else:
436
+ loss_params_list = [
435
437
  self.loss_params[i] if i < len(self.loss_params) else {}
436
438
  for i in range(self.nums_task)
437
439
  ]
438
440
  self.loss_fn = [
439
- get_loss_fn(loss=loss_list[i], **params_list[i])
441
+ get_loss_fn(loss=loss_list[i], **loss_params_list[i])
440
442
  for i in range(self.nums_task)
441
443
  ]
442
444
 
@@ -448,10 +450,8 @@ class BaseModel(FeatureSet, nn.Module):
448
450
  raise ValueError(
449
451
  "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
450
452
  )
451
- weight_value = loss_weights[0]
452
- else:
453
- weight_value = loss_weights
454
- self.loss_weights = [float(weight_value)]
453
+ loss_weights = loss_weights[0]
454
+ self.loss_weights = [float(loss_weights)]
455
455
  else:
456
456
  if isinstance(loss_weights, (int, float)):
457
457
  weights = [float(loss_weights)] * self.nums_task
@@ -484,7 +484,9 @@ class BaseModel(FeatureSet, nn.Module):
484
484
  y_true = y_true.view(-1, 1)
485
485
  if y_pred.shape != y_true.shape:
486
486
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
487
- task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
487
+ task_dim = (
488
+ self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
489
+ )
488
490
  if task_dim == 1:
489
491
  loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
490
492
  else:
@@ -495,12 +497,11 @@ class BaseModel(FeatureSet, nn.Module):
495
497
  # multi-task
496
498
  if y_pred.shape != y_true.shape:
497
499
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
498
- if hasattr(
499
- self, "prediction_layer"
500
- ): # we need to use registered task_slices for multi-task and multi-class
501
- slices = self.prediction_layer.task_slices # type: ignore
502
- else:
503
- slices = [(i, i + 1) for i in range(self.nums_task)]
500
+ slices = (
501
+ self.prediction_layer.task_slices # type: ignore
502
+ if hasattr(self, "prediction_layer")
503
+ else [(i, i + 1) for i in range(self.nums_task)]
504
+ )
504
505
  task_losses = []
505
506
  for i, (start, end) in enumerate(slices): # type: ignore
506
507
  y_pred_i = y_pred[:, start:end]
@@ -520,6 +521,9 @@ class BaseModel(FeatureSet, nn.Module):
520
521
  sampler=None,
521
522
  return_dataset: bool = False,
522
523
  ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
524
+ """
525
+ Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
526
+ """
523
527
  if isinstance(data, DataLoader):
524
528
  return (data, None) if return_dataset else data
525
529
  tensors = build_tensors_from_data(
@@ -626,54 +630,55 @@ class BaseModel(FeatureSet, nn.Module):
626
630
  )
627
631
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
628
632
 
629
- # Setup default callbacks if none exist
630
- if len(self.callbacks.callbacks) == 0:
631
- if self.nums_task == 1:
632
- monitor_metric = f"val_{self.metrics[0]}"
633
- else:
634
- monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
635
-
636
- if self.early_stop_patience > 0:
637
- self.callbacks.append(
638
- EarlyStopper(
639
- monitor=monitor_metric,
640
- patience=self.early_stop_patience,
641
- mode=self.best_metrics_mode,
642
- restore_best_weights=not self.distributed,
643
- verbose=1 if self.is_main_process else 0,
644
- )
633
+ # Setup default callbacks if missing
634
+ if self.nums_task == 1:
635
+ monitor_metric = f"val_{self.metrics[0]}"
636
+ else:
637
+ monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
638
+
639
+ existing_callbacks = self.callbacks.callbacks
640
+ has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
641
+ has_checkpoint = any(
642
+ isinstance(cb, CheckpointSaver) for cb in existing_callbacks
643
+ )
644
+ has_lr_scheduler = any(
645
+ isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
646
+ )
647
+
648
+ if self.early_stop_patience > 0 and not has_early_stop:
649
+ self.callbacks.append(
650
+ EarlyStopper(
651
+ monitor=monitor_metric,
652
+ patience=self.early_stop_patience,
653
+ mode=self.best_metrics_mode,
654
+ restore_best_weights=not self.distributed,
655
+ verbose=1 if self.is_main_process else 0,
645
656
  )
657
+ )
646
658
 
647
- if self.is_main_process:
648
- self.callbacks.append(
649
- CheckpointSaver(
650
- save_path=self.best_path,
651
- monitor=monitor_metric,
652
- mode=self.best_metrics_mode,
653
- save_best_only=True,
654
- verbose=1,
655
- )
659
+ if self.is_main_process and not has_checkpoint:
660
+ self.callbacks.append(
661
+ CheckpointSaver(
662
+ best_path=self.best_path,
663
+ checkpoint_path=self.checkpoint_path,
664
+ monitor=monitor_metric,
665
+ mode=self.best_metrics_mode,
666
+ save_best_only=True,
667
+ verbose=1,
656
668
  )
669
+ )
657
670
 
658
- if self.scheduler_fn is not None:
659
- self.callbacks.append(
660
- LearningRateScheduler(
661
- scheduler=self.scheduler_fn,
662
- verbose=1 if self.is_main_process else 0,
663
- )
671
+ if self.scheduler_fn is not None and not has_lr_scheduler:
672
+ self.callbacks.append(
673
+ LearningRateScheduler(
674
+ scheduler=self.scheduler_fn,
675
+ verbose=1 if self.is_main_process else 0,
664
676
  )
677
+ )
665
678
 
666
679
  self.callbacks.set_model(self)
667
680
  self.callbacks.set_params(
668
- {
669
- "epochs": epochs,
670
- "batch_size": batch_size,
671
- "metrics": self.metrics,
672
- }
673
- )
674
-
675
- self.early_stopper = EarlyStopper(
676
- patience=self.early_stop_patience, mode=self.best_metrics_mode
681
+ {"epochs": epochs, "batch_size": batch_size, "metrics": self.metrics}
677
682
  )
678
683
  self.best_metric = (
679
684
  float("-inf") if self.best_metrics_mode == "max" else float("inf")
@@ -685,6 +690,12 @@ class BaseModel(FeatureSet, nn.Module):
685
690
  self.epoch_index = 0
686
691
  self.stop_training = False
687
692
  self.best_checkpoint_path = self.best_path
693
+ use_ddp_sampler = (
694
+ auto_distributed_sampler
695
+ and self.distributed
696
+ and dist.is_available()
697
+ and dist.is_initialized()
698
+ )
688
699
 
689
700
  if not auto_distributed_sampler and self.distributed and self.is_main_process:
690
701
  logging.info(
@@ -697,12 +708,7 @@ class BaseModel(FeatureSet, nn.Module):
697
708
  train_sampler: DistributedSampler | None = None
698
709
  if validation_split is not None and valid_data is None:
699
710
  train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
700
- if (
701
- auto_distributed_sampler
702
- and self.distributed
703
- and dist.is_available()
704
- and dist.is_initialized()
705
- ):
711
+ if use_ddp_sampler:
706
712
  base_dataset = getattr(train_loader, "dataset", None)
707
713
  if base_dataset is not None and not isinstance(
708
714
  getattr(train_loader, "sampler", None), DistributedSampler
@@ -725,7 +731,7 @@ class BaseModel(FeatureSet, nn.Module):
725
731
  )
726
732
  else:
727
733
  if isinstance(train_data, DataLoader):
728
- if auto_distributed_sampler and self.distributed:
734
+ if use_ddp_sampler:
729
735
  train_loader, train_sampler = add_distributed_sampler(
730
736
  train_data,
731
737
  distributed=self.distributed,
@@ -749,15 +755,9 @@ class BaseModel(FeatureSet, nn.Module):
749
755
  )
750
756
  assert isinstance(
751
757
  result, tuple
752
- ), "Expected tuple from prepare_data_loader with return_dataset=True"
758
+ ), "[BaseModel-fit Error] Expected tuple from prepare_data_loader with return_dataset=True, but got something else."
753
759
  loader, dataset = result
754
- if (
755
- auto_distributed_sampler
756
- and self.distributed
757
- and dataset is not None
758
- and dist.is_available()
759
- and dist.is_initialized()
760
- ):
760
+ if use_ddp_sampler and dataset is not None:
761
761
  train_sampler = DistributedSampler(
762
762
  dataset,
763
763
  num_replicas=self.world_size,
@@ -802,34 +802,42 @@ class BaseModel(FeatureSet, nn.Module):
802
802
  except TypeError: # streaming data loader does not supported len()
803
803
  self.steps_per_epoch = None
804
804
  is_streaming = True
805
+ self.collect_train_metrics = not is_streaming
806
+ if is_streaming and self.is_main_process:
807
+ logging.info(
808
+ colorize(
809
+ "[Training Info] Streaming mode detected; training metrics collection is disabled to avoid memory growth.",
810
+ color="yellow",
811
+ )
812
+ )
805
813
 
806
814
  if self.is_main_process:
807
815
  self.summary()
808
816
  logging.info("")
809
- if self.training_logger and self.training_logger.enable_tensorboard:
810
- tb_dir = self.training_logger.tensorboard_logdir
811
- if tb_dir:
812
- user = getpass.getuser()
813
- host = socket.gethostname()
814
- tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
815
- ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
816
- logging.info(
817
- colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
818
- )
819
- logging.info(colorize("To view logs, run:", color="cyan"))
820
- logging.info(colorize(f" {tb_cmd}", color="cyan"))
821
- logging.info(colorize("Then SSH port forward:", color="cyan"))
822
- logging.info(colorize(f" {ssh_hint}", color="cyan"))
817
+ tb_dir = (
818
+ self.training_logger.tensorboard_logdir
819
+ if self.training_logger and self.training_logger.enable_tensorboard
820
+ else None
821
+ )
822
+ if tb_dir:
823
+ user = getpass.getuser()
824
+ host = socket.gethostname()
825
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
826
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
827
+ logging.info(
828
+ colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
829
+ )
830
+ logging.info(colorize("To view logs, run:", color="cyan"))
831
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
832
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
833
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
823
834
 
824
835
  logging.info("")
825
- logging.info(colorize("=" * 80, bold=True))
826
- if is_streaming:
827
- logging.info(colorize("Start streaming training", bold=True))
828
- else:
829
- logging.info(colorize("Start training", bold=True))
830
- logging.info(colorize("=" * 80, bold=True))
836
+ logging.info(colorize("[Training]", color="bright_blue", bold=True))
837
+ logging.info(colorize("-" * 80, color="bright_blue"))
838
+ logging.info(format_kv("Start training", f"{epochs} epochs"))
839
+ logging.info(format_kv("Model device", self.device))
831
840
  logging.info("")
832
- logging.info(colorize(f"Model device: {self.device}", bold=True))
833
841
 
834
842
  self.callbacks.on_train_begin()
835
843
 
@@ -852,128 +860,77 @@ class BaseModel(FeatureSet, nn.Module):
852
860
  and isinstance(train_loader.sampler, DistributedSampler)
853
861
  ):
854
862
  train_loader.sampler.set_epoch(epoch)
855
- # Type guard: ensure train_loader is DataLoader for train_epoch
863
+
856
864
  if not isinstance(train_loader, DataLoader):
857
865
  raise TypeError(
858
866
  f"Expected DataLoader for training, got {type(train_loader)}"
859
867
  )
860
868
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
861
- if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
869
+ if isinstance(
870
+ train_result, tuple
871
+ ): # [avg_loss, metrics_dict], e.g., (0.5, {'auc': 0.75, 'logloss': 0.45})
862
872
  train_loss, train_metrics = train_result
863
873
  else:
864
874
  train_loss = train_result
865
875
  train_metrics = None
866
876
 
867
- train_log_payload: dict[str, float] = {}
868
- # handle logging for single-task and multi-task
869
- if self.nums_task == 1:
870
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
871
- if train_metrics:
872
- metrics_str = ", ".join(
873
- [f"{k}={v:.4f}" for k, v in train_metrics.items()]
874
- )
875
- log_str += f", {metrics_str}"
876
- if self.is_main_process:
877
- logging.info(colorize(log_str))
878
- train_log_payload["loss"] = float(train_loss)
879
- if train_metrics:
880
- train_log_payload.update(train_metrics)
881
- else:
882
- total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
883
- log_str = (
884
- f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
877
+ logging.info("")
878
+ train_log_payload = {
879
+ "loss": (
880
+ float(np.sum(train_loss))
881
+ if isinstance(train_loss, np.ndarray)
882
+ else float(train_loss)
885
883
  )
886
- if train_metrics:
887
- # group metrics by task
888
- task_metrics = {}
889
- for metric_key, metric_value in train_metrics.items():
890
- for target_name in self.target_columns:
891
- if metric_key.endswith(f"_{target_name}"):
892
- if target_name not in task_metrics:
893
- task_metrics[target_name] = {}
894
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
895
- task_metrics[target_name][metric_name] = metric_value
896
- break
897
- if task_metrics:
898
- task_metric_strs = []
899
- for target_name in self.target_columns:
900
- if target_name in task_metrics:
901
- metrics_str = ", ".join(
902
- [
903
- f"{k}={v:.4f}"
904
- for k, v in task_metrics[target_name].items()
905
- ]
906
- )
907
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
908
- log_str += ", " + ", ".join(task_metric_strs)
909
- if self.is_main_process:
910
- logging.info(colorize(log_str))
911
- train_log_payload["loss"] = float(total_loss_val)
912
- if train_metrics:
913
- train_log_payload.update(train_metrics)
884
+ }
885
+ if train_metrics:
886
+ train_log_payload.update(train_metrics)
887
+
888
+ display_metrics_table(
889
+ epoch=epoch + 1,
890
+ epochs=epochs,
891
+ split="Train",
892
+ loss=train_loss,
893
+ metrics=train_metrics,
894
+ target_names=self.target_columns,
895
+ base_metrics=(
896
+ self.metrics
897
+ if isinstance(getattr(self, "metrics", None), list)
898
+ else None
899
+ ),
900
+ is_main_process=self.is_main_process,
901
+ colorize=lambda s: colorize(s),
902
+ )
914
903
  if self.training_logger:
915
904
  self.training_logger.log_metrics(
916
905
  train_log_payload, step=epoch + 1, split="train"
917
906
  )
918
907
  if valid_loader is not None:
919
- # Call on_validation_begin
920
908
  self.callbacks.on_validation_begin()
921
-
922
- # pass user_ids only if needed for GAUC metric
923
909
  val_metrics = self.evaluate(
924
910
  valid_loader,
925
911
  user_ids=valid_user_ids if self.needs_user_ids else None,
926
912
  num_workers=num_workers,
927
- ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
928
- if self.nums_task == 1:
929
- metrics_str = ", ".join(
930
- [f"{k}={v:.4f}" for k, v in val_metrics.items()]
931
- )
932
- if self.is_main_process:
933
- logging.info(
934
- colorize(
935
- f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
936
- color="cyan",
937
- )
938
- )
939
- else:
940
- # multi task metrics
941
- task_metrics = {}
942
- for metric_key, metric_value in val_metrics.items():
943
- for target_name in self.target_columns:
944
- if metric_key.endswith(f"_{target_name}"):
945
- if target_name not in task_metrics:
946
- task_metrics[target_name] = {}
947
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
948
- task_metrics[target_name][metric_name] = metric_value
949
- break
950
- task_metric_strs = []
951
- for target_name in self.target_columns:
952
- if target_name in task_metrics:
953
- metrics_str = ", ".join(
954
- [
955
- f"{k}={v:.4f}"
956
- for k, v in task_metrics[target_name].items()
957
- ]
958
- )
959
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
960
- if self.is_main_process:
961
- logging.info(
962
- colorize(
963
- f" Epoch {epoch + 1}/{epochs} - Valid: "
964
- + ", ".join(task_metric_strs),
965
- color="cyan",
966
- )
967
- )
968
-
969
- # Call on_validation_end
913
+ )
914
+ display_metrics_table(
915
+ epoch=epoch + 1,
916
+ epochs=epochs,
917
+ split="Valid",
918
+ loss=None,
919
+ metrics=val_metrics,
920
+ target_names=self.target_columns,
921
+ base_metrics=(
922
+ self.metrics
923
+ if isinstance(getattr(self, "metrics", None), list)
924
+ else None
925
+ ),
926
+ is_main_process=self.is_main_process,
927
+ colorize=lambda s: colorize(" " + s, color="cyan"),
928
+ )
970
929
  self.callbacks.on_validation_end()
971
930
  if val_metrics and self.training_logger:
972
931
  self.training_logger.log_metrics(
973
932
  val_metrics, step=epoch + 1, split="valid"
974
933
  )
975
-
976
- # Handle empty validation metrics
977
934
  if not val_metrics:
978
935
  if self.is_main_process:
979
936
  logging.info(
@@ -983,15 +940,10 @@ class BaseModel(FeatureSet, nn.Module):
983
940
  )
984
941
  )
985
942
  continue
986
-
987
- # Prepare epoch logs for callbacks
988
943
  epoch_logs = {**train_log_payload}
989
- if val_metrics:
990
- # Add val_ prefix to validation metrics
991
- for k, v in val_metrics.items():
992
- epoch_logs[f"val_{k}"] = v
944
+ for k, v in val_metrics.items():
945
+ epoch_logs[f"val_{k}"] = v
993
946
  else:
994
- # No validation data
995
947
  epoch_logs = {**train_log_payload}
996
948
  if self.is_main_process:
997
949
  self.save_model(
@@ -1018,13 +970,13 @@ class BaseModel(FeatureSet, nn.Module):
1018
970
  if self.distributed and dist.is_available() and dist.is_initialized():
1019
971
  dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
1020
972
  if self.is_main_process:
1021
- logging.info(" ")
1022
- logging.info(colorize("Training finished.", bold=True))
1023
- logging.info(" ")
973
+ logging.info("")
974
+ logging.info(colorize("Training finished.", color="bright_blue", bold=True))
975
+ logging.info("")
1024
976
  if valid_loader is not None:
1025
977
  if self.is_main_process:
1026
978
  logging.info(
1027
- colorize(f"Load best model from: {self.best_checkpoint_path}")
979
+ format_kv("Load best model from", self.best_checkpoint_path)
1028
980
  )
1029
981
  if os.path.exists(self.best_checkpoint_path):
1030
982
  self.load_model(
@@ -1051,14 +1003,18 @@ class BaseModel(FeatureSet, nn.Module):
1051
1003
  num_batches = 0
1052
1004
  y_true_list = []
1053
1005
  y_pred_list = []
1006
+ collect_metrics = getattr(self, "collect_train_metrics", True)
1007
+ max_samples = getattr(self, "max_metrics_samples", None)
1008
+ collected_samples = 0
1009
+ metrics_capped = False
1054
1010
 
1055
1011
  user_ids_list = [] if self.needs_user_ids else None
1056
1012
  tqdm_disable = not self.is_main_process
1057
1013
  if self.steps_per_epoch is not None:
1058
1014
  batch_iter = enumerate(
1059
- tqdm.tqdm(
1015
+ progress(
1060
1016
  train_loader,
1061
- desc=f"Epoch {self.epoch_index + 1}",
1017
+ description=f"Epoch {self.epoch_index + 1}",
1062
1018
  total=self.steps_per_epoch,
1063
1019
  disable=tqdm_disable,
1064
1020
  )
@@ -1066,7 +1022,11 @@ class BaseModel(FeatureSet, nn.Module):
1066
1022
  else:
1067
1023
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
1068
1024
  batch_iter = enumerate(
1069
- tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
1025
+ progress(
1026
+ train_loader,
1027
+ description=desc,
1028
+ disable=tqdm_disable,
1029
+ )
1070
1030
  )
1071
1031
  for batch_index, batch_data in batch_iter:
1072
1032
  batch_dict = batch_to_dict(batch_data)
@@ -1085,16 +1045,34 @@ class BaseModel(FeatureSet, nn.Module):
1085
1045
  self.optimizer_fn.step()
1086
1046
  accumulated_loss += loss.item()
1087
1047
 
1088
- if y_true is not None:
1089
- y_true_list.append(y_true.detach().cpu().numpy())
1090
- if self.needs_user_ids and user_ids_list is not None:
1091
- batch_user_id = get_user_ids(
1092
- data=batch_dict, id_columns=self.id_columns
1093
- )
1094
- if batch_user_id is not None:
1095
- user_ids_list.append(batch_user_id)
1096
- if y_pred is not None and isinstance(y_pred, torch.Tensor):
1097
- y_pred_list.append(y_pred.detach().cpu().numpy())
1048
+ if (
1049
+ collect_metrics
1050
+ and y_true is not None
1051
+ and isinstance(y_pred, torch.Tensor)
1052
+ ):
1053
+ batch_size = int(y_true.size(0))
1054
+ if max_samples is not None and collected_samples >= max_samples:
1055
+ collect_metrics = False
1056
+ metrics_capped = True
1057
+ else:
1058
+ take_count = batch_size
1059
+ if (
1060
+ max_samples is not None
1061
+ and collected_samples + batch_size > max_samples
1062
+ ):
1063
+ take_count = max_samples - collected_samples
1064
+ metrics_capped = True
1065
+ collect_metrics = False
1066
+ if take_count > 0:
1067
+ y_true_list.append(y_true[:take_count].detach().cpu().numpy())
1068
+ y_pred_list.append(y_pred[:take_count].detach().cpu().numpy())
1069
+ if self.needs_user_ids and user_ids_list is not None:
1070
+ batch_user_id = get_user_ids(
1071
+ data=batch_dict, id_columns=self.id_columns
1072
+ )
1073
+ if batch_user_id is not None:
1074
+ user_ids_list.append(batch_user_id[:take_count])
1075
+ collected_samples += take_count
1098
1076
  num_batches += 1
1099
1077
  if self.distributed and dist.is_available() and dist.is_initialized():
1100
1078
  loss_tensor = torch.tensor(
@@ -1120,6 +1098,14 @@ class BaseModel(FeatureSet, nn.Module):
1120
1098
  gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
1121
1099
  )
1122
1100
 
1101
+ if metrics_capped and self.is_main_process:
1102
+ logging.info(
1103
+ colorize(
1104
+ f"[Training Info] Training metrics capped at {max_samples} samples to limit memory usage.",
1105
+ color="yellow",
1106
+ )
1107
+ )
1108
+
1123
1109
  if (
1124
1110
  y_true_all is not None
1125
1111
  and y_pred_all is not None
@@ -1258,11 +1244,15 @@ class BaseModel(FeatureSet, nn.Module):
1258
1244
  )
1259
1245
  if batch_user_id is not None:
1260
1246
  collected_user_ids.append(batch_user_id)
1261
- if self.is_main_process:
1262
- logging.info(" ")
1263
- logging.info(
1264
- colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1265
- )
1247
+ # if self.is_main_process:
1248
+ # logging.info("")
1249
+ # logging.info(
1250
+ # colorize(
1251
+ # format_kv(
1252
+ # "Evaluation batches processed", batch_count
1253
+ # ),
1254
+ # )
1255
+ # )
1266
1256
  y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
1267
1257
  y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
1268
1258
 
@@ -1301,10 +1291,15 @@ class BaseModel(FeatureSet, nn.Module):
1301
1291
  )
1302
1292
  )
1303
1293
  return {}
1304
- if self.is_main_process:
1305
- logging.info(
1306
- colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
1307
- )
1294
+ # if self.is_main_process:
1295
+ # logging.info(
1296
+ # colorize(
1297
+ # format_kv(
1298
+ # "Evaluation samples", y_true_all.shape[0]
1299
+ # ),
1300
+ # )
1301
+ # )
1302
+ logging.info("")
1308
1303
  metrics_dict = evaluate_metrics(
1309
1304
  y_true=y_true_all,
1310
1305
  y_pred=y_pred_all,
@@ -1396,7 +1391,7 @@ class BaseModel(FeatureSet, nn.Module):
1396
1391
  id_arrays = None
1397
1392
 
1398
1393
  with torch.no_grad():
1399
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
1394
+ for batch_data in progress(data_loader, description="Predicting"):
1400
1395
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
1401
1396
  X_input, _ = self.get_input(batch_dict, require_labels=False)
1402
1397
  y_pred = self(X_input)
@@ -1417,10 +1412,9 @@ class BaseModel(FeatureSet, nn.Module):
1417
1412
  if id_np.ndim == 1
1418
1413
  else id_np
1419
1414
  )
1420
- if len(y_pred_list) > 0:
1421
- y_pred_all = np.concatenate(y_pred_list, axis=0)
1422
- else:
1423
- y_pred_all = np.array([])
1415
+ y_pred_all = (
1416
+ np.concatenate(y_pred_list, axis=0) if y_pred_list else np.array([])
1417
+ )
1424
1418
 
1425
1419
  if y_pred_all.ndim == 1:
1426
1420
  y_pred_all = y_pred_all.reshape(-1, 1)
@@ -1428,22 +1422,22 @@ class BaseModel(FeatureSet, nn.Module):
1428
1422
  num_outputs = len(self.target_columns) if self.target_columns else 1
1429
1423
  y_pred_all = y_pred_all.reshape(0, num_outputs)
1430
1424
  num_outputs = y_pred_all.shape[1]
1431
- pred_columns: list[str] = []
1432
- if self.target_columns:
1433
- for name in self.target_columns[:num_outputs]:
1434
- pred_columns.append(f"{name}")
1425
+ pred_columns: list[str] = (
1426
+ list(self.target_columns[:num_outputs]) if self.target_columns else []
1427
+ )
1435
1428
  while len(pred_columns) < num_outputs:
1436
1429
  pred_columns.append(f"pred_{len(pred_columns)}")
1437
1430
  if include_ids and predict_id_columns:
1438
- id_arrays = {}
1439
- for id_name, pieces in id_buffers.items():
1440
- if pieces:
1441
- concatenated = np.concatenate(
1431
+ id_arrays = {
1432
+ id_name: (
1433
+ np.concatenate(
1442
1434
  [p.reshape(p.shape[0], -1) for p in pieces], axis=0
1443
- )
1444
- id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
1445
- else:
1446
- id_arrays[id_name] = np.array([], dtype=np.int64)
1435
+ ).reshape(-1)
1436
+ if pieces
1437
+ else np.array([], dtype=np.int64)
1438
+ )
1439
+ for id_name, pieces in id_buffers.items()
1440
+ }
1447
1441
  if return_dataframe:
1448
1442
  id_df = pd.DataFrame(id_arrays)
1449
1443
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
@@ -1544,7 +1538,7 @@ class BaseModel(FeatureSet, nn.Module):
1544
1538
  collected_frames = [] # only used when return_dataframe is True
1545
1539
 
1546
1540
  with torch.no_grad():
1547
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
1541
+ for batch_data in progress(data_loader, description="Predicting"):
1548
1542
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
1549
1543
  X_input, _ = self.get_input(batch_dict, require_labels=False)
1550
1544
  y_pred = self.forward(X_input)
@@ -1555,25 +1549,24 @@ class BaseModel(FeatureSet, nn.Module):
1555
1549
  y_pred_np = y_pred_np.reshape(-1, 1)
1556
1550
  if pred_columns is None:
1557
1551
  num_outputs = y_pred_np.shape[1]
1558
- pred_columns = []
1559
- if self.target_columns:
1560
- for name in self.target_columns[:num_outputs]:
1561
- pred_columns.append(f"{name}")
1552
+ pred_columns = (
1553
+ list(self.target_columns[:num_outputs])
1554
+ if self.target_columns
1555
+ else []
1556
+ )
1562
1557
  while len(pred_columns) < num_outputs:
1563
1558
  pred_columns.append(f"pred_{len(pred_columns)}")
1564
1559
 
1565
- id_arrays_batch = {}
1566
- if include_ids and id_columns and batch_dict.get("ids"):
1567
- for id_name in id_columns:
1568
- if id_name not in batch_dict["ids"]:
1569
- continue
1570
- id_tensor = batch_dict["ids"][id_name]
1571
- id_np = (
1572
- id_tensor.detach().cpu().numpy()
1573
- if isinstance(id_tensor, torch.Tensor)
1574
- else np.asarray(id_tensor)
1575
- )
1576
- id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
1560
+ ids = batch_dict.get("ids") if include_ids and id_columns else None
1561
+ id_arrays_batch = {
1562
+ id_name: (
1563
+ ids[id_name].detach().cpu().numpy()
1564
+ if isinstance(ids[id_name], torch.Tensor)
1565
+ else np.asarray(ids[id_name])
1566
+ ).reshape(-1)
1567
+ for id_name in (id_columns or [])
1568
+ if ids and id_name in ids
1569
+ }
1577
1570
 
1578
1571
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
1579
1572
  if id_arrays_batch:
@@ -1775,13 +1768,13 @@ class BaseModel(FeatureSet, nn.Module):
1775
1768
  def summary(self):
1776
1769
  logger = logging.getLogger()
1777
1770
 
1778
- logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1771
+ logger.info("")
1779
1772
  logger.info(
1780
1773
  colorize(
1781
1774
  f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1782
1775
  )
1783
1776
  )
1784
- logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1777
+ logger.info("")
1785
1778
 
1786
1779
  logger.info("")
1787
1780
  logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
@@ -1903,6 +1896,7 @@ class BaseModel(FeatureSet, nn.Module):
1903
1896
  logger.info("Other Settings:")
1904
1897
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1905
1898
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1899
+ logger.info(f" Max Metrics Samples: {self.max_metrics_samples}")
1906
1900
  logger.info(f" Session ID: {self.session_id}")
1907
1901
  logger.info(f" Features Config Path: {self.features_config_path}")
1908
1902
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -2296,7 +2290,7 @@ class BaseMatchModel(BaseModel):
2296
2290
 
2297
2291
  embeddings_list = []
2298
2292
  with torch.no_grad():
2299
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
2293
+ for batch_data in progress(data_loader, description="Encoding users"):
2300
2294
  batch_dict = batch_to_dict(batch_data, include_ids=False)
2301
2295
  user_input = self.get_user_features(batch_dict["features"])
2302
2296
  user_emb = self.user_tower(user_input)
@@ -2316,7 +2310,7 @@ class BaseMatchModel(BaseModel):
2316
2310
 
2317
2311
  embeddings_list = []
2318
2312
  with torch.no_grad():
2319
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
2313
+ for batch_data in progress(data_loader, description="Encoding items"):
2320
2314
  batch_dict = batch_to_dict(batch_data, include_ids=False)
2321
2315
  item_input = self.get_item_features(batch_dict["features"])
2322
2316
  item_emb = self.item_tower(item_input)