torch-rechub 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
torch_rechub/basic/tracking.py (new file)
@@ -0,0 +1,198 @@
1
+ """Experiment tracking utilities for Torch-RecHub.
2
+
3
+ This module exposes lightweight adapters for common visualization and
4
+ experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
5
+ TensorBoardX.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Dict, List, Optional, Union
10
+
11
+
12
+ class BaseLogger(ABC):
13
+ """Base interface for experiment tracking backends.
14
+
15
+ Methods
16
+ -------
17
+ log_metrics(metrics, step=None)
18
+ Record scalar metrics at a given step.
19
+ log_hyperparams(params)
20
+ Store hyperparameters and run configuration.
21
+ finish()
22
+ Flush pending logs and release resources.
23
+ """
24
+
25
+ @abstractmethod
26
+ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
27
+ """Log metrics to the tracking backend.
28
+
29
+ Parameters
30
+ ----------
31
+ metrics : dict of str to Any
32
+ Metric name-value pairs to record.
33
+ step : int, optional
34
+ Explicit global step or epoch index. When ``None``, the backend
35
+ uses its own default step handling.
36
+ """
37
+ raise NotImplementedError
38
+
39
+ @abstractmethod
40
+ def log_hyperparams(self, params: Dict[str, Any]) -> None:
41
+ """Log experiment hyperparameters.
42
+
43
+ Parameters
44
+ ----------
45
+ params : dict of str to Any
46
+ Hyperparameters or configuration values to persist with the run.
47
+ """
48
+ raise NotImplementedError
49
+
50
+ @abstractmethod
51
+ def finish(self) -> None:
52
+ """Finalize logging and free any backend resources."""
53
+ raise NotImplementedError
54
+
55
+
56
+ class WandbLogger(BaseLogger):
57
+ """Weights & Biases logger implementation.
58
+
59
+ Parameters
60
+ ----------
61
+ project : str
62
+ Name of the wandb project to log to.
63
+ name : str, optional
64
+ Display name for the run.
65
+ config : dict, optional
66
+ Initial hyperparameter configuration to record.
67
+ tags : list of str, optional
68
+ Optional tags for grouping runs.
69
+ notes : str, optional
70
+ Long-form notes shown in the run overview.
71
+ dir : str, optional
72
+ Local directory for wandb artifacts and cache.
73
+ **kwargs : dict
74
+ Additional keyword arguments forwarded to ``wandb.init``.
75
+
76
+ Raises
77
+ ------
78
+ ImportError
79
+ If ``wandb`` is not installed in the current environment.
80
+ """
81
+
82
+ def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
83
+ try:
84
+ import wandb
85
+ self._wandb = wandb
86
+ except ImportError:
87
+ raise ImportError("wandb is not installed. Install it with: pip install wandb")
88
+
89
+ self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
90
+
91
+ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
92
+ if step is not None:
93
+ self._wandb.log(metrics, step=step)
94
+ else:
95
+ self._wandb.log(metrics)
96
+
97
+ def log_hyperparams(self, params: Dict[str, Any]) -> None:
98
+ if self.run is not None:
99
+ self.run.config.update(params)
100
+
101
+ def finish(self) -> None:
102
+ if self.run is not None:
103
+ self.run.finish()
104
+
105
+
106
+ class SwanLabLogger(BaseLogger):
107
+ """SwanLab logger implementation.
108
+
109
+ Parameters
110
+ ----------
111
+ project : str, optional
112
+ Project identifier for grouping experiments.
113
+ experiment_name : str, optional
114
+ Display name for the experiment or run.
115
+ description : str, optional
116
+ Text description shown alongside the run.
117
+ config : dict, optional
118
+ Hyperparameters or configuration to log at startup.
119
+ logdir : str, optional
120
+ Directory where logs and artifacts are stored.
121
+ **kwargs : dict
122
+ Additional keyword arguments forwarded to ``swanlab.init``.
123
+
124
+ Raises
125
+ ------
126
+ ImportError
127
+ If ``swanlab`` is not installed in the current environment.
128
+ """
129
+
130
+ def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
131
+ try:
132
+ import swanlab
133
+ self._swanlab = swanlab
134
+ except ImportError:
135
+ raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
136
+
137
+ self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
138
+
139
+ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
140
+ if step is not None:
141
+ self._swanlab.log(metrics, step=step)
142
+ else:
143
+ self._swanlab.log(metrics)
144
+
145
+ def log_hyperparams(self, params: Dict[str, Any]) -> None:
146
+ if self.run is not None:
147
+ self.run.config.update(params)
148
+
149
+ def finish(self) -> None:
150
+ self._swanlab.finish()
151
+
152
+
153
+ class TensorBoardXLogger(BaseLogger):
154
+ """TensorBoardX logger implementation.
155
+
156
+ Parameters
157
+ ----------
158
+ log_dir : str
159
+ Directory where event files will be written.
160
+ comment : str, default=""
161
+ Comment appended to the log directory name.
162
+ **kwargs : dict
163
+ Additional keyword arguments forwarded to
164
+ ``tensorboardX.SummaryWriter``.
165
+
166
+ Raises
167
+ ------
168
+ ImportError
169
+ If ``tensorboardX`` is not installed in the current environment.
170
+ """
171
+
172
+ def __init__(self, log_dir: str, comment: str = "", **kwargs):
173
+ try:
174
+ from tensorboardX import SummaryWriter
175
+ self._SummaryWriter = SummaryWriter
176
+ except ImportError:
177
+ raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
178
+
179
+ self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
180
+ self._step = 0
181
+
182
+ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
183
+ if step is None:
184
+ step = self._step
185
+ self._step += 1
186
+
187
+ for key, value in metrics.items():
188
+ if value is not None:
189
+ if isinstance(value, (int, float)):
190
+ self.writer.add_scalar(key, value, step)
191
+
192
+ def log_hyperparams(self, params: Dict[str, Any]) -> None:
193
+ hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
194
+ self.writer.add_text("hyperparameters", hparam_str, 0)
195
+
196
+ def finish(self) -> None:
197
+ if self.writer is not None:
198
+ self.writer.close()
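
For reference, a minimal sketch of the logger lifecycle shared by all three adapters, assuming the optional backend is installed (e.g. via the new tracking extra, pip install "torch-rechub[tracking]"):

# Minimal sketch of the shared BaseLogger lifecycle, using the TensorBoardX backend.
# Assumes tensorboardX is installed (e.g. pip install "torch-rechub[tracking]").
from torch_rechub.basic.tracking import TensorBoardXLogger

logger = TensorBoardXLogger(log_dir="./runs/demo")
logger.log_hyperparams({"learning_rate": 1e-3, "n_epoch": 5})
for epoch in range(5):
    logger.log_metrics({"train/loss": 1.0 / (epoch + 1)}, step=epoch)
logger.finish()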
torch_rechub/data/convert.py (new file)
@@ -0,0 +1,67 @@
1
+ """Utilities for converting array-like data structures into PyTorch tensors."""
2
+
3
+ import numpy.typing as npt
4
+ import pyarrow as pa
5
+ import pyarrow.compute as pc
6
+ import pyarrow.types as pt
7
+ import torch
8
+
9
+
10
+ def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
11
+ """
12
+ Convert a PyArrow array to a PyTorch tensor.
13
+
14
+ Parameters
15
+ ----------
16
+ arr : pa.Array
17
+ The given PyArrow array.
18
+
19
+ Returns
20
+ -------
21
+ torch.Tensor: The resulting PyTorch tensor.
22
+
23
+ Raises
24
+ ------
25
+ TypeError
26
+ If the array type, or the value type of a nested array, is unsupported.
27
+ ValueError
28
+ If the nested array is ragged (rows of unequal length).
29
+ """
30
+ if _is_supported_scalar(arr.type):
31
+ arr = pc.cast(arr, pa.float32())
32
+ return torch.from_numpy(_to_writable_numpy(arr))
33
+
34
+ if not _is_supported_list(arr.type):
35
+ raise TypeError(f"Unsupported array type: {arr.type}")
36
+
37
+ if not _is_supported_scalar(val_type := arr.type.value_type):
38
+ raise TypeError(f"Unsupported value type in the nested array: {val_type}")
39
+
40
+ if len(pc.unique(pc.list_value_length(arr))) > 1:
41
+ raise ValueError("Cannot convert the ragged nested array.")
42
+
43
+ arr = pc.cast(arr, pa.list_(pa.float32()))
44
+ np_arr = _to_writable_numpy(arr.values) # type: ignore[attr-defined]
45
+
46
+ # For empty list-of-lists, define output shape as (0, 0); otherwise infer width.
47
+ return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
48
+
49
+
50
+ # helper functions
51
+
52
+
53
+ def _is_supported_list(t: pa.DataType) -> bool:
54
+ """Check if the given PyArrow data type is a supported list."""
55
+ return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
56
+
57
+
58
+ def _is_supported_scalar(t: pa.DataType) -> bool:
59
+ """Check if the given PyArrow data type is a supported scalar type."""
60
+ return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
61
+
62
+
63
+ def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
64
+ """Dump a PyArrow array into a writable NumPy array."""
65
+ # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
66
+ # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
67
+ return arr.to_numpy(writable=True, zero_copy_only=False)
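
A short sketch of the conversion rules implemented above (assumes pyarrow is installed; the module path torch_rechub.data.convert follows the RECORD entries later in this diff):

# Sketch of pa_array_to_tensor behaviour on scalar, nested and ragged arrays.
import pyarrow as pa
from torch_rechub.data.convert import pa_array_to_tensor

scalars = pa.array([1, 2, 3])                 # integer array -> cast to float32
print(pa_array_to_tensor(scalars).shape)      # torch.Size([3])

nested = pa.array([[1.0, 2.0], [3.0, 4.0]])   # equal-length rows -> 2-D tensor
print(pa_array_to_tensor(nested).shape)       # torch.Size([2, 2])

ragged = pa.array([[1.0], [2.0, 3.0]])        # rows of unequal length
# pa_array_to_tensor(ragged)                  # would raise ValueError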
torch_rechub/data/dataset.py (new file)
@@ -0,0 +1,120 @@
1
+ """Dataset implementations providing streaming, batch-wise data access for PyTorch."""
2
+
3
+ import os
4
+ import typing as ty
5
+
6
+ import pyarrow.dataset as pd
7
+ import torch
8
+ from torch.utils.data import IterableDataset, get_worker_info
9
+
10
+ from .convert import pa_array_to_tensor
11
+
12
+ # Type for path to a file
13
+ _FilePath = ty.Union[str, os.PathLike]
14
+
15
+ # The default batch size when reading a Parquet dataset
16
+ _DEFAULT_BATCH_SIZE = 1024
17
+
18
+
19
+ class ParquetIterableDataset(IterableDataset):
20
+ """
21
+ IterableDataset that streams data from one or more Parquet files.
22
+
23
+ Parameters
24
+ ----------
25
+ file_paths : list[_FilePath]
26
+ Paths to Parquet files.
27
+ columns : list[str], optional
28
+ Column names to select. If ``None``, all columns are read.
29
+ batch_size : int, default _DEFAULT_BATCH_SIZE
30
+ Number of rows per streamed batch.
31
+
32
+ Notes
33
+ -----
34
+ This dataset reads data lazily and never loads the entire Parquet dataset into memory.
35
+ The current worker receives a partition of ``file_paths`` and builds its own PyArrow
36
+ Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
37
+ tensors created via NumPy, one batch at a time.
38
+
39
+ Examples
40
+ --------
41
+ >>> ds = ParquetIterableDataset(
42
+ ... ["/data/train1.parquet", "/data/train2.parquet"],
43
+ ... columns=["x", "y", "label"],
44
+ ... batch_size=1024,
45
+ ... )
46
+ >>> loader = DataLoader(ds, batch_size=None)
47
+ >>> # Now iterate over batches.
48
+ >>> for batch in loader:
49
+ ... x, y, label = batch["x"], batch["y"], batch["label"]
50
+ ... # Do some work.
51
+ ... ...
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ file_paths: ty.Sequence[_FilePath],
57
+ /,
58
+ columns: ty.Optional[ty.Sequence[str]] = None,
59
+ batch_size: int = _DEFAULT_BATCH_SIZE,
60
+ ) -> None:
61
+ """Initialize this instance."""
62
+ self._file_paths = tuple(map(str, file_paths))
63
+ self._columns = None if columns is None else tuple(columns)
64
+ self._batch_size = batch_size
65
+
66
+ def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
67
+ """
68
+ Stream Parquet data as mapped PyTorch tensors.
69
+
70
+ Build a PyArrow Dataset from the current worker's assigned file partition, then
71
+ create a Scanner to lazily read batches of the selected columns. Each batch is
72
+ converted to a dict mapping column names to PyTorch tensors (via NumPy).
73
+
74
+ Returns
75
+ -------
76
+ Iterator[dict[str, torch.Tensor]]
77
+ An iterator that yields one converted batch at a time.
78
+ """
79
+ if not (partition := self._get_partition()):
80
+ return
81
+
82
+ # Build the dataset for the current worker.
83
+ ds = pd.dataset(partition, format="parquet")
84
+
85
+ # Create a scanner. This does not read data.
86
+ columns = None if self._columns is None else list(self._columns)
87
+ scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
88
+
89
+ for batch in scanner.to_batches():
90
+ data_dict: dict[str, torch.Tensor] = {}
91
+ for name, array in zip(batch.column_names, batch.columns):
92
+ data_dict[name] = pa_array_to_tensor(array)
93
+ yield data_dict
94
+
95
+ # private interfaces
96
+
97
+ def _get_partition(self) -> tuple[str, ...]:
98
+ """
99
+ Get the partition of file paths for the current worker.
100
+
101
+ This method splits the full list of file paths into contiguous, nearly equal-sized
102
+ partitions based on the total number of workers and the current worker ID.
103
+
104
+ If running in the main process (i.e., no worker information is available), the
105
+ entire list of file paths is returned.
106
+
107
+ Returns
108
+ -------
109
+ tuple[str, ...]
110
+ The partition of file paths for the current worker.
111
+ """
112
+ if (info := get_worker_info()) is None:
113
+ return self._file_paths
114
+
115
+ n = len(self._file_paths)
116
+ per_worker = (n + info.num_workers - 1) // info.num_workers
117
+
118
+ start = info.id * per_worker
119
+ end = n if (end := start + per_worker) > n else end
120
+ return self._file_paths[start:end]
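
The _get_partition split above is a plain ceil division over the file list; a hypothetical illustration with 5 files and 2 workers:

# Hypothetical illustration of the per-worker file partitioning used above.
files = [f"part-{i}.parquet" for i in range(5)]
num_workers = 2
per_worker = (len(files) + num_workers - 1) // num_workers  # ceil(5 / 2) = 3
for worker_id in range(num_workers):
    start = worker_id * per_worker
    end = min(start + per_worker, len(files))
    print(worker_id, files[start:end])
# 0 ['part-0.parquet', 'part-1.parquet', 'part-2.parquet']
# 1 ['part-3.parquet', 'part-4.parquet']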
torch_rechub/trainers/ctr_trainer.py
@@ -43,6 +43,7 @@ class CTRTrainer(object):
43
43
  gpus=None,
44
44
  loss_mode=True,
45
45
  model_path="./",
46
+ model_logger=None,
46
47
  ):
47
48
  self.model = model # for uniform weights save method in one gpu or multi gpu
48
49
  if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
70
71
  self.model_path = model_path
71
72
  # Initialize regularization loss
72
73
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
74
+ self.model_logger = model_logger
73
75
 
74
76
  def train_one_epoch(self, data_loader, log_interval=10):
75
77
  self.model.train()
76
78
  total_loss = 0
79
+ epoch_loss = 0
80
+ batch_count = 0
77
81
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
78
82
  for i, (x_dict, y) in enumerate(tk0):
79
83
  x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
93
97
  loss.backward()
94
98
  self.optimizer.step()
95
99
  total_loss += loss.item()
100
+ epoch_loss += loss.item()
101
+ batch_count += 1
96
102
  if (i + 1) % log_interval == 0:
97
103
  tk0.set_postfix(loss=total_loss / log_interval)
98
104
  total_loss = 0
99
105
 
106
+ # Return average epoch loss
107
+ return epoch_loss / batch_count if batch_count > 0 else 0
108
+
100
109
  def fit(self, train_dataloader, val_dataloader=None):
110
+ for logger in self._iter_loggers():
111
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})
112
+
101
113
  for epoch_i in range(self.n_epoch):
102
114
  print('epoch:', epoch_i)
103
- self.train_one_epoch(train_dataloader)
115
+ train_loss = self.train_one_epoch(train_dataloader)
116
+
117
+ for logger in self._iter_loggers():
118
+ logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
119
+
104
120
  if self.scheduler is not None:
105
121
  if epoch_i % self.scheduler.step_size == 0:
106
122
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
107
123
  self.scheduler.step() # update lr in epoch level by scheduler
124
+
108
125
  if val_dataloader:
109
126
  auc = self.evaluate(self.model, val_dataloader)
110
127
  print('epoch:', epoch_i, 'validation: auc:', auc)
128
+
129
+ for logger in self._iter_loggers():
130
+ logger.log_metrics({'val/auc': auc}, step=epoch_i)
131
+
111
132
  if self.early_stopper.stop_training(auc, self.model.state_dict()):
112
133
  print(f'validation: best auc: {self.early_stopper.best_auc}')
113
134
  self.model.load_state_dict(self.early_stopper.best_weights)
114
135
  break
136
+
115
137
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best auc model
116
138
 
139
+ for logger in self._iter_loggers():
140
+ logger.finish()
141
+
142
+ def _iter_loggers(self):
143
+ """Return logger instances as a list.
144
+
145
+ Returns
146
+ -------
147
+ list
148
+ Active logger instances. Empty when ``model_logger`` is ``None``.
149
+ """
150
+ if self.model_logger is None:
151
+ return []
152
+ if isinstance(self.model_logger, (list, tuple)):
153
+ return list(self.model_logger)
154
+ return [self.model_logger]
155
+
117
156
  def evaluate(self, model, data_loader):
118
157
  model.eval()
119
158
  targets, predicts = list(), list()
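
With these hooks in place, a trainer can be wired to one logger or a list of loggers through model_logger. A hedged sketch (model, train_dl and val_dl are placeholders assumed to be defined elsewhere):

# Sketch: model_logger accepts a single logger or a list; _iter_loggers() handles both.
# `model`, `train_dl` and `val_dl` are placeholders assumed to exist.
from torch_rechub.basic.tracking import TensorBoardXLogger, WandbLogger
from torch_rechub.trainers import CTRTrainer

loggers = [WandbLogger(project="rechub-demo"), TensorBoardXLogger(log_dir="./runs/ctr")]
trainer = CTRTrainer(model, model_logger=loggers)
trainer.fit(train_dl, val_dl)  # logs train/loss, learning_rate and val/auc per epoch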
torch_rechub/trainers/match_trainer.py
@@ -39,6 +39,7 @@ class MatchTrainer(object):
39
39
  device="cpu",
40
40
  gpus=None,
41
41
  model_path="./",
42
+ model_logger=None,
42
43
  ):
43
44
  self.model = model # for uniform weights save method in one gpu or multi gpu
44
45
  if gpus is None:
@@ -73,10 +74,13 @@ class MatchTrainer(object):
73
74
  self.model_path = model_path
74
75
  # Initialize regularization loss
75
76
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
77
+ self.model_logger = model_logger
76
78
 
77
79
  def train_one_epoch(self, data_loader, log_interval=10):
78
80
  self.model.train()
79
81
  total_loss = 0
82
+ epoch_loss = 0
83
+ batch_count = 0
80
84
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
81
85
  for i, (x_dict, y) in enumerate(tk0):
82
86
  x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
@@ -114,14 +118,26 @@ class MatchTrainer(object):
114
118
  loss.backward()
115
119
  self.optimizer.step()
116
120
  total_loss += loss.item()
121
+ epoch_loss += loss.item()
122
+ batch_count += 1
117
123
  if (i + 1) % log_interval == 0:
118
124
  tk0.set_postfix(loss=total_loss / log_interval)
119
125
  total_loss = 0
120
126
 
127
+ # Return average epoch loss
128
+ return epoch_loss / batch_count if batch_count > 0 else 0
129
+
121
130
  def fit(self, train_dataloader, val_dataloader=None):
131
+ for logger in self._iter_loggers():
132
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
133
+
122
134
  for epoch_i in range(self.n_epoch):
123
135
  print('epoch:', epoch_i)
124
- self.train_one_epoch(train_dataloader)
136
+ train_loss = self.train_one_epoch(train_dataloader)
137
+
138
+ for logger in self._iter_loggers():
139
+ logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
140
+
125
141
  if self.scheduler is not None:
126
142
  if epoch_i % self.scheduler.step_size == 0:
127
143
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -130,12 +146,34 @@ class MatchTrainer(object):
130
146
  if val_dataloader:
131
147
  auc = self.evaluate(self.model, val_dataloader)
132
148
  print('epoch:', epoch_i, 'validation: auc:', auc)
149
+
150
+ for logger in self._iter_loggers():
151
+ logger.log_metrics({'val/auc': auc}, step=epoch_i)
152
+
133
153
  if self.early_stopper.stop_training(auc, self.model.state_dict()):
134
154
  print(f'validation: best auc: {self.early_stopper.best_auc}')
135
155
  self.model.load_state_dict(self.early_stopper.best_weights)
136
156
  break
157
+
137
158
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best auc model
138
159
 
160
+ for logger in self._iter_loggers():
161
+ logger.finish()
162
+
163
+ def _iter_loggers(self):
164
+ """Return logger instances as a list.
165
+
166
+ Returns
167
+ -------
168
+ list
169
+ Active logger instances. Empty when ``model_logger`` is ``None``.
170
+ """
171
+ if self.model_logger is None:
172
+ return []
173
+ if isinstance(self.model_logger, (list, tuple)):
174
+ return list(self.model_logger)
175
+ return [self.model_logger]
176
+
139
177
  def evaluate(self, model, data_loader):
140
178
  model.eval()
141
179
  targets, predicts = list(), list()
torch_rechub/trainers/mtl_trainer.py
@@ -47,6 +47,7 @@ class MTLTrainer(object):
47
47
  device="cpu",
48
48
  gpus=None,
49
49
  model_path="./",
50
+ model_logger=None,
50
51
  ):
51
52
  self.model = model
52
53
  if gpus is None:
@@ -104,6 +105,7 @@ class MTLTrainer(object):
104
105
  self.model_path = model_path
105
106
  # Initialize regularization loss
106
107
  self.reg_loss_fn = RegularizationLoss(**regularization_params)
108
+ self.model_logger = model_logger
107
109
 
108
110
  def train_one_epoch(self, data_loader):
109
111
  self.model.train()
@@ -163,21 +165,42 @@ class MTLTrainer(object):
163
165
  def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
164
166
  total_log = []
165
167
 
168
+ # Log hyperparameters once
169
+ for logger in self._iter_loggers():
170
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
171
+
166
172
  for epoch_i in range(self.n_epoch):
167
173
  _log_per_epoch = self.train_one_epoch(train_dataloader)
168
174
 
175
+ # Collect metrics
176
+ logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
177
+ lr_value = self._current_lr()
178
+ if lr_value is not None:
179
+ logs['learning_rate'] = lr_value
180
+
169
181
  if self.scheduler is not None:
170
182
  if epoch_i % self.scheduler.step_size == 0:
171
183
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
172
184
  self.scheduler.step() # update lr in epoch level by scheduler
185
+
173
186
  scores = self.evaluate(self.model, val_dataloader)
174
187
  print('epoch:', epoch_i, 'validation scores: ', scores)
175
188
 
176
- for score in scores:
189
+ for task_id, score in enumerate(scores):
190
+ logs[f'val/task_{task_id}_score'] = score
177
191
  _log_per_epoch.append(score)
192
+ logs['auc'] = scores[self.earlystop_taskid]
193
+
194
+ if self.loss_weight:
195
+ for task_id, weight in enumerate(self.loss_weight):
196
+ logs[f'loss_weight/task_{task_id}'] = weight.item()
178
197
 
179
198
  total_log.append(_log_per_epoch)
180
199
 
200
+ # Log metrics once per epoch
201
+ for logger in self._iter_loggers():
202
+ logger.log_metrics(logs, step=epoch_i)
203
+
181
204
  if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
182
205
  print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
183
206
  self.model.load_state_dict(self.early_stopper.best_weights)
@@ -185,8 +208,33 @@ class MTLTrainer(object):
185
208
 
186
209
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
187
210
 
211
+ for logger in self._iter_loggers():
212
+ logger.finish()
213
+
188
214
  return total_log
189
215
 
216
+ def _iter_loggers(self):
217
+ """Return logger instances as a list.
218
+
219
+ Returns
220
+ -------
221
+ list
222
+ Active logger instances. Empty when ``model_logger`` is ``None``.
223
+ """
224
+ if self.model_logger is None:
225
+ return []
226
+ if isinstance(self.model_logger, (list, tuple)):
227
+ return list(self.model_logger)
228
+ return [self.model_logger]
229
+
230
+ def _current_lr(self):
231
+ """Fetch current learning rate regardless of adaptive method."""
232
+ if self.adaptive_method == "metabalance":
233
+ return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
234
+ if hasattr(self, 'optimizer'):
235
+ return self.optimizer.param_groups[0]['lr']
236
+ return None
237
+
190
238
  def evaluate(self, model, data_loader):
191
239
  model.eval()
192
240
  targets, predicts = list(), list()
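
For a two-task setup, the per-epoch logs dictionary assembled in fit above would look roughly like this (values are placeholders):

# Shape of the per-epoch metrics dict built in MTLTrainer.fit (placeholder values).
logs = {
    "train/task_0_loss": 0.42,   # from train_one_epoch
    "train/task_1_loss": 0.37,
    "learning_rate": 1e-3,       # from _current_lr()
    "val/task_0_score": 0.81,    # from evaluate()
    "val/task_1_score": 0.77,
    "auc": 0.81,                 # score of the early-stopping task (earlystop_taskid)
    # "loss_weight/task_0": ..., # only present when adaptive loss weighting is active
}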
torch_rechub/trainers/seq_trainer.py
@@ -46,7 +46,22 @@ class SeqTrainer(object):
46
46
  ... )
47
47
  """
48
48
 
49
- def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
49
+ def __init__(
50
+ self,
51
+ model,
52
+ optimizer_fn=torch.optim.Adam,
53
+ optimizer_params=None,
54
+ scheduler_fn=None,
55
+ scheduler_params=None,
56
+ n_epoch=10,
57
+ earlystop_patience=10,
58
+ device='cpu',
59
+ gpus=None,
60
+ model_path='./',
61
+ loss_type='cross_entropy',
62
+ loss_params=None,
63
+ model_logger=None
64
+ ):
50
65
  self.model = model # for uniform weights save method in one gpu or multi gpu
51
66
  if gpus is None:
52
67
  gpus = []
@@ -74,9 +89,11 @@ class SeqTrainer(object):
74
89
  loss_params = {"ignore_index": 0}
75
90
  self.loss_fn = nn.CrossEntropyLoss(**loss_params)
76
91
 
92
+ self.loss_type = loss_type
77
93
  self.n_epoch = n_epoch
78
94
  self.early_stopper = EarlyStopper(patience=earlystop_patience)
79
95
  self.model_path = model_path
96
+ self.model_logger = model_logger
80
97
 
81
98
  def fit(self, train_dataloader, val_dataloader=None):
82
99
  """训练模型.
@@ -90,10 +107,18 @@ class SeqTrainer(object):
90
107
  """
91
108
  history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
92
109
 
110
+ for logger in self._iter_loggers():
111
+ logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
112
+
93
113
  for epoch_i in range(self.n_epoch):
94
114
  print('epoch:', epoch_i)
95
115
  # Training phase
96
- self.train_one_epoch(train_dataloader)
116
+ train_loss = self.train_one_epoch(train_dataloader)
117
+ history['train_loss'].append(train_loss)
118
+
119
+ # Collect metrics
120
+ logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
121
+
97
122
  if self.scheduler is not None:
98
123
  if epoch_i % self.scheduler.step_size == 0:
99
124
  print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -105,6 +130,10 @@ class SeqTrainer(object):
105
130
  history['val_loss'].append(val_loss)
106
131
  history['val_accuracy'].append(val_accuracy)
107
132
 
133
+ logs['val/loss'] = val_loss
134
+ logs['val/accuracy'] = val_accuracy
135
+ logs['auc'] = val_accuracy # For compatibility with EarlyStopper
136
+
108
137
  print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
109
138
 
110
139
  # Early stopping
@@ -113,9 +142,30 @@ class SeqTrainer(object):
113
142
  self.model.load_state_dict(self.early_stopper.best_weights)
114
143
  break
115
144
 
145
+ for logger in self._iter_loggers():
146
+ logger.log_metrics(logs, step=epoch_i)
147
+
116
148
  torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best model
149
+
150
+ for logger in self._iter_loggers():
151
+ logger.finish()
152
+
117
153
  return history
118
154
 
155
+ def _iter_loggers(self):
156
+ """Return logger instances as a list.
157
+
158
+ Returns
159
+ -------
160
+ list
161
+ Active logger instances. Empty when ``model_logger`` is ``None``.
162
+ """
163
+ if self.model_logger is None:
164
+ return []
165
+ if isinstance(self.model_logger, (list, tuple)):
166
+ return list(self.model_logger)
167
+ return [self.model_logger]
168
+
119
169
  def train_one_epoch(self, data_loader, log_interval=10):
120
170
  """Train the model for a single epoch.
121
171
 
@@ -128,6 +178,8 @@ class SeqTrainer(object):
128
178
  """
129
179
  self.model.train()
130
180
  total_loss = 0
181
+ epoch_loss = 0
182
+ batch_count = 0
131
183
  tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
132
184
  for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
133
185
  # Move tensors to the target device
@@ -152,10 +204,15 @@ class SeqTrainer(object):
152
204
  self.optimizer.step()
153
205
 
154
206
  total_loss += loss.item()
207
+ epoch_loss += loss.item()
208
+ batch_count += 1
155
209
  if (i + 1) % log_interval == 0:
156
210
  tk0.set_postfix(loss=total_loss / log_interval)
157
211
  total_loss = 0
158
212
 
213
+ # Return average epoch loss
214
+ return epoch_loss / batch_count if batch_count > 0 else 0
215
+
159
216
  def evaluate(self, data_loader):
160
217
  """Evaluate the model on a validation/test data loader.
161
218
 
torch_rechub-0.0.6.dist-info/METADATA
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-rechub
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
5
  Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
6
  Project-URL: Documentation, https://www.torch-rechub.com
@@ -28,19 +28,26 @@ Requires-Dist: scikit-learn>=0.24.0
28
28
  Requires-Dist: torch>=1.10.0
29
29
  Requires-Dist: tqdm>=4.60.0
30
30
  Requires-Dist: transformers>=4.46.3
31
+ Provides-Extra: bigdata
32
+ Requires-Dist: pyarrow~=21.0; extra == 'bigdata'
31
33
  Provides-Extra: dev
32
34
  Requires-Dist: bandit>=1.7.0; extra == 'dev'
33
35
  Requires-Dist: flake8>=3.8.0; extra == 'dev'
34
36
  Requires-Dist: isort==5.13.2; extra == 'dev'
35
37
  Requires-Dist: mypy>=0.800; extra == 'dev'
36
38
  Requires-Dist: pre-commit>=2.20.0; extra == 'dev'
39
+ Requires-Dist: pyarrow-stubs>=20.0; extra == 'dev'
37
40
  Requires-Dist: pytest-cov>=2.0; extra == 'dev'
38
41
  Requires-Dist: pytest>=6.0; extra == 'dev'
39
42
  Requires-Dist: toml>=0.10.2; extra == 'dev'
40
43
  Requires-Dist: yapf==0.43.0; extra == 'dev'
41
44
  Provides-Extra: onnx
42
- Requires-Dist: onnx>=1.12.0; extra == 'onnx'
43
- Requires-Dist: onnxruntime>=1.12.0; extra == 'onnx'
45
+ Requires-Dist: onnx>=1.14.0; extra == 'onnx'
46
+ Requires-Dist: onnxruntime>=1.14.0; extra == 'onnx'
47
+ Provides-Extra: tracking
48
+ Requires-Dist: swanlab>=0.1.0; extra == 'tracking'
49
+ Requires-Dist: tensorboardx>=2.5; extra == 'tracking'
50
+ Requires-Dist: wandb>=0.13.0; extra == 'tracking'
44
51
  Provides-Extra: visualization
45
52
  Requires-Dist: graphviz>=0.20; extra == 'visualization'
46
53
  Requires-Dist: torchview>=0.2.6; extra == 'visualization'
@@ -89,7 +96,8 @@ Description-Content-Type: text/markdown
89
96
  * **Easy configuration:** Adjust experiment settings easily via configuration files or command-line arguments.
90
97
  * **Reproducibility:** Designed to ensure reproducible experiment results.
91
98
  * **ONNX export:** Export trained models to ONNX format for easy deployment to production.
92
- * **Other features:** e.g., support for negative sampling, multi-task learning, and more.
99
+ * **Cross-engine data processing:** PySpark-based data processing and transformation is now supported, making it easy to fit into big-data pipelines.
100
+ * **Experiment visualization and tracking:** Built-in, unified integration with three visualization/tracking tools: WandB, SwanLab, and TensorBoardX.
93
101
 
94
102
  ## 📖 Table of Contents
95
103
 
@@ -399,4 +407,4 @@ ctr_trainer.visualization(save_path="model.pdf", dpi=300) # Save as a high-resolution PDF
399
407
 
400
408
  ---
401
409
 
402
- *Last updated: [2025-12-04]*
410
+ *Last updated: [2025-12-11]*
torch_rechub-0.0.6.dist-info/RECORD
@@ -8,6 +8,10 @@ torch_rechub/basic/layers.py,sha256=URWk78dlffMOAhDVDhOhugcr4nmwEa192AI1diktC-4,
8
8
  torch_rechub/basic/loss_func.py,sha256=6bjljqpiuUP6O8-wUbGd8FSvflY5Dp_DV_57OuQVMz4,7969
9
9
  torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
10
10
  torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
11
+ torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
12
+ torch_rechub/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ torch_rechub/data/convert.py,sha256=clGFEbDSDpdZBvscWatfjtuXMZUzgy1kiEAg4w_q7VM,2241
14
+ torch_rechub/data/dataset.py,sha256=fDDQ5N3x99KPfy0Ux4LRQbFlWbLg_dvKTO1WUEbEN04,4111
11
15
  torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
16
  torch_rechub/models/generative/__init__.py,sha256=TsCdVIhOcalQwqKZKjEuNbHKyIjyclapKGNwYfFR7TM,135
13
17
  torch_rechub/models/generative/hllm.py,sha256=6Vrp5Bh0fTFHCn7C-3EqzOyc7UunOyEY9TzAKGHrW-8,9669
@@ -45,11 +49,11 @@ torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPH
45
49
  torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
46
50
  torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
47
51
  torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
48
- torch_rechub/trainers/ctr_trainer.py,sha256=ECXaK0x2_6jZVxtEazgN3hkBpSAMPeGeNtunqI_OECo,12860
49
- torch_rechub/trainers/match_trainer.py,sha256=QHZb32Rf7yp-NvEzdeiG1HQghQ76_vuu59K1IsdK60k,15055
52
+ torch_rechub/trainers/ctr_trainer.py,sha256=e0xS-W48BOixN0ogksWOcVJNKFiO3g2oNA_hlHytRqk,14138
53
+ torch_rechub/trainers/match_trainer.py,sha256=atkO-gfDuTk6lh-WvaJOh5kgn6HPzbQQN42Rvz8kyXY,16327
50
54
  torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
51
- torch_rechub/trainers/mtl_trainer.py,sha256=MjasE_QOPfGxiUW1JpYYQ2iuBSSk-lissAGp4Sw1CWk,16427
52
- torch_rechub/trainers/seq_trainer.py,sha256=uAo9XymwQupCqvm5otKW81tz1nxd3crJ2ul2r7lrEAE,17633
55
+ torch_rechub/trainers/mtl_trainer.py,sha256=n3T-ctWACSyl0awBQixOlZUQ8I5cfGyZzgKV09EF8hw,18293
56
+ torch_rechub/trainers/seq_trainer.py,sha256=pyY70kAjTWdKrnAYZynql1PPNtveYDLMB_1hbpCHa48,19217
53
57
  torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
58
  torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
55
59
  torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
@@ -58,7 +62,7 @@ torch_rechub/utils/model_utils.py,sha256=VLhSbTpupxrFyyY3NzMQ32PPmo5YHm1T96u9KDl
58
62
  torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
59
63
  torch_rechub/utils/onnx_export.py,sha256=LRHyZaR9zZJyg6xtuqQHWmusWq-yEvw9EhlmoEwcqsg,8364
60
64
  torch_rechub/utils/visualization.py,sha256=Djv8W5SkCk3P2dol5VXf0_eanIhxDwRd7fzNOQY4uiU,9506
61
- torch_rechub-0.0.5.dist-info/METADATA,sha256=7k9N1xGB4JeWzri7iA7kJbPnAJ-KhXF7vBV-_b8Ghrg,17998
62
- torch_rechub-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
- torch_rechub-0.0.5.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
64
- torch_rechub-0.0.5.dist-info/RECORD,,
65
+ torch_rechub-0.0.6.dist-info/METADATA,sha256=OihjWb0yCI1bmTEoCYAC6pI6cCgl5KS5uSrAGZwv7yY,18470
66
+ torch_rechub-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
67
+ torch_rechub-0.0.6.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
68
+ torch_rechub-0.0.6.dist-info/RECORD,,