nextrec-0.4.21-py3-none-any.whl → nextrec-0.4.23-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -1
  3. nextrec/basic/heads.py +2 -3
  4. nextrec/basic/metrics.py +1 -2
  5. nextrec/basic/model.py +115 -80
  6. nextrec/basic/summary.py +36 -2
  7. nextrec/data/preprocessor.py +137 -5
  8. nextrec/loss/__init__.py +0 -4
  9. nextrec/loss/grad_norm.py +3 -3
  10. nextrec/loss/listwise.py +19 -6
  11. nextrec/loss/pairwise.py +6 -4
  12. nextrec/loss/pointwise.py +8 -6
  13. nextrec/models/multi_task/esmm.py +3 -26
  14. nextrec/models/multi_task/mmoe.py +2 -24
  15. nextrec/models/multi_task/ple.py +13 -35
  16. nextrec/models/multi_task/poso.py +4 -28
  17. nextrec/models/multi_task/share_bottom.py +1 -24
  18. nextrec/models/ranking/afm.py +3 -27
  19. nextrec/models/ranking/autoint.py +5 -38
  20. nextrec/models/ranking/dcn.py +1 -26
  21. nextrec/models/ranking/dcn_v2.py +5 -33
  22. nextrec/models/ranking/deepfm.py +2 -29
  23. nextrec/models/ranking/dien.py +2 -28
  24. nextrec/models/ranking/din.py +2 -27
  25. nextrec/models/ranking/eulernet.py +3 -30
  26. nextrec/models/ranking/ffm.py +0 -26
  27. nextrec/models/ranking/fibinet.py +8 -32
  28. nextrec/models/ranking/fm.py +0 -29
  29. nextrec/models/ranking/lr.py +0 -30
  30. nextrec/models/ranking/masknet.py +4 -30
  31. nextrec/models/ranking/pnn.py +4 -28
  32. nextrec/models/ranking/widedeep.py +0 -32
  33. nextrec/models/ranking/xdeepfm.py +0 -30
  34. nextrec/models/retrieval/dssm.py +0 -24
  35. nextrec/models/retrieval/dssm_v2.py +0 -24
  36. nextrec/models/retrieval/mind.py +0 -20
  37. nextrec/models/retrieval/sdm.py +0 -20
  38. nextrec/models/retrieval/youtube_dnn.py +0 -21
  39. nextrec/models/sequential/hstu.py +0 -18
  40. nextrec/utils/__init__.py +5 -1
  41. nextrec/{loss/loss_utils.py → utils/loss.py} +17 -7
  42. nextrec/utils/model.py +79 -1
  43. nextrec/utils/types.py +62 -23
  44. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/METADATA +8 -6
  45. nextrec-0.4.23.dist-info/RECORD +81 -0
  46. nextrec-0.4.21.dist-info/RECORD +0 -81
  47. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/WHEEL +0 -0
  48. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/entry_points.txt +0 -0
  49. {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.21"
+ __version__ = "0.4.23"
nextrec/basic/activation.py CHANGED
@@ -9,10 +9,10 @@ Author: Yang Zhou, zyaztec@gmail.com
  import torch
  import torch.nn as nn

- from typing import Literal

  from nextrec.utils.types import ActivationName

+
  class Dice(nn.Module):
      """
      Dice activation function from the paper:
nextrec/basic/heads.py CHANGED
@@ -15,6 +15,7 @@ import torch.nn as nn
  import torch.nn.functional as F

  from nextrec.basic.layers import PredictionLayer
+ from nextrec.utils.types import TaskTypeName


  class TaskHead(nn.Module):
@@ -27,9 +28,7 @@ class TaskHead(nn.Module):

      def __init__(
          self,
-         task_type: (
-             Literal["binary", "regression"] | list[Literal["binary", "regression"]]
-         ) = "binary",
+         task_type: TaskTypeName | list[TaskTypeName] = "binary",
          task_dims: int | list[int] | None = None,
          use_bias: bool = True,
          return_logits: bool = False,
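
Note: this release replaces the repeated Literal["binary", "regression"] unions with the TaskTypeName alias from nextrec.utils.types. The alias definition itself is not shown in this diff; a minimal sketch of the pattern, assuming the alias covers at least the two members visible in this hunk (make_head is a hypothetical function for illustration):

    from typing import Literal

    # Assumption: TaskTypeName includes at least these members.
    TaskTypeName = Literal["binary", "regression"]

    def make_head(task_type: TaskTypeName | list[TaskTypeName] = "binary") -> None:
        # One alias keeps every signature short and lets new task
        # types be added in a single place instead of each call site.
        print(task_type)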
nextrec/basic/metrics.py CHANGED
@@ -2,7 +2,7 @@
  Metrics computation and configuration for model evaluation.

  Date: create on 27/10/2025
- Checkpoint: edit on 20/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
  TASK_DEFAULT_METRICS = {
      "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
      "regression": ["mse", "mae", "rmse", "r2", "mape"],
-     "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
      "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
      + [f"recall@{k}" for k in (5, 10, 20)]
      + [f"ndcg@{k}" for k in (5, 10, 20)]
nextrec/basic/model.py CHANGED
@@ -2,13 +2,14 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 28/12/2025
+ Checkpoint: edit on 29/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

  import getpass
  import logging
  import os
+ import sys
  import pickle
  import socket
  from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal

  import numpy as np
  import pandas as pd
+
+ try:
+     import swanlab  # type: ignore
+ except ModuleNotFoundError:
+     swanlab = None
+ try:
+     import wandb  # type: ignore
+ except ModuleNotFoundError:
+     wandb = None
+
  import torch
  import torch.distributed as dist
  import torch.nn as nn
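
Note: the module-level try/except makes swanlab and wandb optional dependencies. The names are always bound (to the module or to None), so later code can test against None instead of importing inside fit(). The same pattern in isolation, as a minimal sketch (log_metric is a hypothetical helper, not part of nextrec):

    # Optional-dependency import: bind the name to None when missing,
    # then guard every use with an explicit None check.
    try:
        import wandb  # type: ignore
    except ModuleNotFoundError:
        wandb = None

    def log_metric(name: str, value: float) -> None:
        if wandb is None:
            return  # silently skip when the backend is not installed
        wandb.log({name: value})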
@@ -60,8 +71,8 @@ from nextrec.loss import (
      InfoNCELoss,
      SampledSoftmaxLoss,
      TripletLoss,
-     get_loss_fn,
  )
+ from nextrec.utils.loss import get_loss_fn
  from nextrec.loss.grad_norm import get_grad_norm_shared_params
  from nextrec.utils.console import display_metrics_table, progress
  from nextrec.utils.torch_utils import (
@@ -74,8 +85,20 @@ from nextrec.utils.torch_utils import (
      to_tensor,
  )
  from nextrec.utils.config import safe_value
- from nextrec.utils.model import compute_ranking_loss
- from nextrec.utils.types import LossName, OptimizerName, SchedulerName
+ from nextrec.utils.model import (
+     compute_ranking_loss,
+     get_loss_list,
+     resolve_loss_weights,
+     get_training_modes,
+ )
+ from nextrec.utils.types import (
+     LossName,
+     OptimizerName,
+     SchedulerName,
+     TrainingModeName,
+     TaskTypeName,
+     MetricsName,
+ )


  class BaseModel(SummarySet, FeatureSet, nn.Module):
@@ -84,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          raise NotImplementedError

      @property
-     def default_task(self) -> str | list[str]:
+     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
          raise NotImplementedError

      def __init__(
@@ -94,11 +117,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          sequence_features: list[SequenceFeature] | None = None,
          target: list[str] | str | None = None,
          id_columns: list[str] | str | None = None,
-         task: str | list[str] | None = None,
-         training_mode: (
-             Literal["pointwise", "pairwise", "listwise"]
-             | list[Literal["pointwise", "pairwise", "listwise"]]
-         ) = "pointwise",
+         task: TaskTypeName | list[TaskTypeName] | None = None,
+         training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
          embedding_l1_reg: float = 0.0,
          dense_l1_reg: float = 0.0,
          embedding_l2_reg: float = 0.0,
@@ -136,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              world_size: Number of processes (defaults to env WORLD_SIZE).
              local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
              ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
+
+         Note:
+             Optimizer, scheduler, and loss are configured via compile().
          """
          super(BaseModel, self).__init__()

@@ -168,25 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              dense_features, sparse_features, sequence_features, target, id_columns
          )

-         self.task = self.default_task if task is None else task
+         self.task = task or self.default_task
          self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-         if isinstance(training_mode, list):
-             training_modes = list(training_mode)
-             if len(training_modes) != self.nums_task:
-                 raise ValueError(
-                     "[BaseModel-init Error] training_mode list length must match number of tasks."
-                 )
-         else:
-             training_modes = [training_mode] * self.nums_task
-         if any(
-             mode not in {"pointwise", "pairwise", "listwise"}
-             for mode in training_modes
-         ):
-             raise ValueError(
-                 "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-             )
-         self.training_modes = training_modes
-         self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
+         self.training_modes = get_training_modes(training_mode, self.nums_task)
+         self.training_mode = (
+             self.training_modes if self.nums_task > 1 else self.training_modes[0]
+         )

          self.embedding_l1_reg = embedding_l1_reg
          self.dense_l1_reg = dense_l1_reg
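
Note: the inline training-mode validation moves into nextrec.utils.model.get_training_modes. Its body is not shown in this diff; a minimal sketch consistent with the removed inline checks (the signature matches the call site above, the error messages are assumptions):

    def get_training_modes(training_mode, nums_task: int) -> list[str]:
        # Broadcast a scalar mode to all tasks; keep a list as-is.
        modes = (
            list(training_mode)
            if isinstance(training_mode, list)
            else [training_mode] * nums_task
        )
        if len(modes) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
        if any(m not in {"pointwise", "pairwise", "listwise"} for m in modes):
            raise ValueError(
                "training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
            )
        return modes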
@@ -194,7 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          self.dense_l2_reg = dense_l2_reg
          self.regularization_weights = []
          self.embedding_params = []
-         self.loss_weight = None
+
+         self.ignore_label = None
+         self.compiled = False

          self.max_gradient_norm = 1.0
          self.logger_initialized = False
@@ -407,6 +419,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
          loss_params: dict | list[dict] | None = None,
          loss_weights: int | float | list[int | float] | dict | str | None = None,
+         ignore_label: int | float | None = -1,
      ):
          """
          Configure the model for training.
@@ -419,34 +432,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
              loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
                  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
+             ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
          """
+         self.ignore_label = ignore_label
          default_losses = {
              "pointwise": "bce",
              "pairwise": "bpr",
              "listwise": "listnet",
          }
-         effective_loss = loss
-         if effective_loss is None:
-             loss_list = [default_losses[mode] for mode in self.training_modes]
-         elif isinstance(effective_loss, list):
-             if not effective_loss:
-                 loss_list = [default_losses[mode] for mode in self.training_modes]
-             else:
-                 if len(effective_loss) != self.nums_task:
-                     raise ValueError(
-                         f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
-                     )
-                 loss_list = list(effective_loss)
-         else:
-             loss_list = [effective_loss] * self.nums_task
-
-         for idx, mode in enumerate(self.training_modes):
-             if isinstance(loss_list[idx], str) and loss_list[idx] in {
-                 "bce",
-                 "binary_crossentropy",
-             }:
-                 if mode in {"pairwise", "listwise"}:
-                     loss_list[idx] = default_losses[mode]
+         loss_list = get_loss_list(
+             loss, self.training_modes, self.nums_task, default_losses
+         )
          self.loss_params = loss_params or {}
          optimizer_params = optimizer_params or {}
          self.optimizer_name = (
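
Note: with the new ignore_label parameter, callers can mark padded or unknown labels so they contribute no gradient. A hedged usage sketch of the new compile() signature (MyModel is a placeholder for any BaseModel subclass; construction arguments elided):

    model = MyModel(...)  # any BaseModel subclass
    model.compile(
        optimizer="adam",
        optimizer_params={"lr": 1e-3},
        loss=["bce", "bce"],      # one loss per task
        loss_weights=[1.0, 0.5],  # static task weighting
        ignore_label=-1,          # rows labeled -1 are masked out of the loss
    )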
@@ -510,36 +506,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  nums_task=self.nums_task, device=self.device, **grad_norm_params
              )
              self.loss_weights = None
-         elif loss_weights is None:
-             self.loss_weights = None
-         elif self.nums_task == 1:
-             if isinstance(loss_weights, (list, tuple)):
-                 if len(loss_weights) != 1:
-                     raise ValueError(
-                         "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
-                     )
-                 loss_weights = loss_weights[0]
-             self.loss_weights = [float(loss_weights)]  # type: ignore
          else:
-             if isinstance(loss_weights, (int, float)):
-                 weights = [float(loss_weights)] * self.nums_task
-             elif isinstance(loss_weights, (list, tuple)):
-                 weights = [float(w) for w in loss_weights]
-                 if len(weights) != self.nums_task:
-                     raise ValueError(
-                         f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
-                     )
-             else:
-                 raise TypeError(
-                     f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
-                 )
-             self.loss_weights = weights
+             self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+         self.compiled = True

      def compute_loss(self, y_pred, y_true):
          if y_true is None:
              raise ValueError(
                  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
              )
+
          # single-task
          if self.nums_task == 1:
              if y_pred.dim() == 1:
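
Note: the loss-weight normalization likewise moves behind a helper, and compile() now records completion in self.compiled. resolve_loss_weights is not shown in this diff; a sketch consistent with the removed branches (None passthrough, scalar broadcast, per-task list with a length check; error texts are assumptions):

    def resolve_loss_weights(loss_weights, nums_task: int) -> list[float] | None:
        if loss_weights is None:
            return None
        if isinstance(loss_weights, (int, float)):
            # Broadcast a scalar weight to every task.
            return [float(loss_weights)] * nums_task
        if isinstance(loss_weights, (list, tuple)):
            weights = [float(w) for w in loss_weights]
            if len(weights) != nums_task:
                raise ValueError("Number of loss_weights must match number of tasks.")
            return weights
        raise TypeError(f"loss_weights must be int, float, list or tuple, got {type(loss_weights)}")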
@@ -547,13 +523,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              if y_true.dim() == 1:
                  y_true = y_true.view(-1, 1)
              if y_pred.shape != y_true.shape:
-                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-             loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
-             if loss_fn is None:
                  raise ValueError(
-                     "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
+                     f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
                  )
+
+             loss_fn = self.loss_fn[0]
+
+             if self.ignore_label is not None:
+                 valid_mask = y_true != self.ignore_label
+                 if valid_mask.dim() > 1:
+                     valid_mask = valid_mask.all(dim=1)
+                 if not torch.any(valid_mask):  # if no valid labels, return zero loss
+                     return y_pred.sum() * 0.0
+
+                 y_pred = y_pred[valid_mask]
+                 y_true = y_true[valid_mask]
+
              mode = self.training_modes[0]
+
              task_dim = (
                  self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
              )
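
Note: the masking keeps only rows whose label differs from ignore_label; when every label is ignored it returns y_pred.sum() * 0.0, a zero that stays connected to the autograd graph so backward() remains valid. A self-contained replay of the row-masking step:

    import torch

    ignore_label = -1
    y_true = torch.tensor([[1.0], [-1.0], [0.0]])
    y_pred = torch.tensor([[0.9], [0.4], [0.2]])

    valid_mask = (y_true != ignore_label).all(dim=1)  # tensor([True, False, True])
    y_pred, y_true = y_pred[valid_mask], y_true[valid_mask]
    print(y_true.squeeze(1))  # tensor([1., 0.]) -- the -1 row is dropped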
@@ -584,7 +571,19 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          for i, (start, end) in enumerate(slices):  # type: ignore
              y_pred_i = y_pred[:, start:end]
              y_true_i = y_true[:, start:end]
+             # mask ignored labels
+             if self.ignore_label is not None:
+                 valid_mask = y_true_i != self.ignore_label
+                 if valid_mask.dim() > 1:
+                     valid_mask = valid_mask.all(dim=1)
+                 if not torch.any(valid_mask):
+                     task_losses.append(y_pred_i.sum() * 0.0)
+                     continue
+                 y_pred_i = y_pred_i[valid_mask]
+                 y_true_i = y_true_i[valid_mask]
+
              mode = self.training_modes[i]
+
              if mode in {"pairwise", "listwise"}:
                  task_loss = compute_ranking_loss(
                      training_mode=mode,
@@ -594,7 +593,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  )
              else:
                  task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+                 # task_loss = normalize_task_loss(
+                 #     task_loss, valid_count, total_count
+                 # )  # normalize by valid samples to avoid loss scale issues
              task_losses.append(task_loss)
+
          if self.grad_norm is not None:
              if self.grad_norm_shared_params is None:
                  self.grad_norm_shared_params = get_grad_norm_shared_params(
@@ -651,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          train_data=None,
          valid_data=None,
          metrics: (
-             list[str] | dict[str, list[str]] | None
+             list[MetricsName] | dict[str, list[MetricsName]] | None
          ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
          epochs: int = 1,
          shuffle: bool = True,
@@ -665,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          use_tensorboard: bool = True,
          use_wandb: bool = False,
          use_swanlab: bool = False,
+         wandb_api: str | None = None,
+         swanlab_api: str | None = None,
          wandb_kwargs: dict | None = None,
          swanlab_kwargs: dict | None = None,
          auto_ddp_sampler: bool = True,
@@ -694,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              use_tensorboard: Enable tensorboard logging.
              use_wandb: Enable Weights & Biases logging.
              use_swanlab: Enable SwanLab logging.
+             wandb_api: W&B API key for non-tty login.
+             swanlab_api: SwanLab API key for non-tty login.
              wandb_kwargs: Optional kwargs for wandb.init(...).
              swanlab_kwargs: Optional kwargs for swanlab.init(...).
              auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
@@ -711,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          )
          self.to(self.device)

+         if not self.compiled:
+             self.compile(
+                 optimizer="adam",
+                 optimizer_params={},
+                 scheduler=None,
+                 scheduler_params={},
+                 loss=None,
+                 loss_params={},
+             )
+
          if (
              self.distributed
              and dist.is_available()
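
Note: fit() now falls back to a default compile() (Adam, no scheduler, loss=None so each training mode gets its default loss) when the user never called compile(); the self.compiled flag set at the end of compile() drives this. The minimal training call thus becomes (hypothetical subclass and data):

    model = MyModel(...)   # BaseModel subclass, never compiled explicitly
    model.fit(train_data)  # fit() auto-compiles with Adam + default losses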
@@ -785,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          }
          training_config: dict = safe_value(training_config)  # type: ignore

+         if self.is_main_process:
+             is_tty = sys.stdin.isatty() and sys.stdout.isatty()
+             if not is_tty:
+                 if use_wandb and wandb_api:
+                     if wandb is None:
+                         logging.warning(
+                             "[BaseModel-fit] wandb not installed, skip wandb login."
+                         )
+                     else:
+                         wandb.login(key=wandb_api)
+                 if use_swanlab and swanlab_api:
+                     if swanlab is None:
+                         logging.warning(
+                             "[BaseModel-fit] swanlab not installed, skip swanlab login."
+                         )
+                     else:
+                         swanlab.login(api_key=swanlab_api)
+
          self.training_logger = (
              TrainingLogger(
                  session=self.session,
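
Note: in a non-interactive session (no tty), wandb.login and swanlab.login would normally block on a credential prompt, so fit() now logs in programmatically from the supplied keys, and only on the main process. Hedged usage for a headless run (the key value is a placeholder):

    model.fit(
        train_data,
        valid_data=valid_data,
        use_wandb=True,
        wandb_api="YOUR_WANDB_API_KEY",  # placeholder; read from an env var or secret store in practice
    )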
@@ -2164,7 +2199,7 @@ class BaseMatchModel(BaseModel):
              scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
              loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
              loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
-             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
+             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
          """
          if self.training_mode not in self.support_training_modes:
              raise ValueError(
nextrec/basic/summary.py CHANGED
@@ -1,5 +1,9 @@
  """
  Summary utilities for BaseModel.
+
+ Date: create on 03/12/2025
+ Checkpoint: edit on 29/12/2025
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader

  from nextrec.basic.loggers import colorize, format_kv
  from nextrec.data.data_processing import extract_label_arrays, get_data_length
+ from nextrec.utils.types import TaskTypeName


  class SummarySet:
+     model_name: str
+     dense_features: list[Any]
+     sparse_features: list[Any]
+     sequence_features: list[Any]
+     task: TaskTypeName | list[TaskTypeName]
+     target_columns: list[str]
+     nums_task: int
+     metrics: Any
+     device: Any
+     optimizer_name: str
+     optimizer_params: dict[str, Any]
+     scheduler_name: str | None
+     scheduler_params: dict[str, Any]
+     loss_config: Any
+     loss_weights: Any
+     grad_norm: Any
+     embedding_l1_reg: float
+     embedding_l2_reg: float
+     dense_l1_reg: float
+     dense_l2_reg: float
+     early_stop_patience: int
+     max_gradient_norm: float | None
+     metrics_sample_limit: int | None
+     session_id: str | None
+     features_config_path: str
+     checkpoint_path: str
+     train_data_summary: dict[str, Any] | None
+     valid_data_summary: dict[str, Any] | None
+
      def build_data_summary(
          self, data: Any, data_loader: DataLoader | None, sample_key: str
      ):
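
Note: SummarySet is a mixin; the attributes it reads are assigned by BaseModel's __init__, so these class-level annotations exist to satisfy static type checkers rather than to create values at runtime. The pattern in isolation, as a minimal sketch (SummaryMixin, Model, and describe are illustrative names, not nextrec APIs):

    class SummaryMixin:
        # Declared, never assigned here: tells the type checker these
        # attributes will exist on the combined class at runtime.
        model_name: str
        nums_task: int

        def describe(self) -> str:
            return f"{self.model_name}: {self.nums_task} task(s)"

    class Model(SummaryMixin):
        def __init__(self) -> None:
            self.model_name = "demo"
            self.nums_task = 1

    print(Model().describe())  # demo: 1 task(s)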
@@ -305,7 +339,7 @@ class SummarySet:
              lines = details.get("lines", [])
              logger.info(f"{target_name}:")
              for label, value in lines:
-                 logger.info(format_kv(label, value))
+                 logger.info(f" {format_kv(label, value)}")

          if self.valid_data_summary:
              if self.train_data_summary:
@@ -320,4 +354,4 @@ class SummarySet:
              lines = details.get("lines", [])
              logger.info(f"{target_name}:")
              for label, value in lines:
-                 logger.info(format_kv(label, value))
+                 logger.info(f" {format_kv(label, value)}")