torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,198 @@
+ """Experiment tracking utilities for Torch-RecHub.
+
+ This module exposes lightweight adapters for common visualization and
+ experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
+ TensorBoardX.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional, Union
+
+
+ class BaseLogger(ABC):
+     """Base interface for experiment tracking backends.
+
+     Methods
+     -------
+     log_metrics(metrics, step=None)
+         Record scalar metrics at a given step.
+     log_hyperparams(params)
+         Store hyperparameters and run configuration.
+     finish()
+         Flush pending logs and release resources.
+     """
+
+     @abstractmethod
+     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+         """Log metrics to the tracking backend.
+
+         Parameters
+         ----------
+         metrics : dict of str to Any
+             Metric name-value pairs to record.
+         step : int, optional
+             Explicit global step or epoch index. When ``None``, the backend
+             uses its own default step handling.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def log_hyperparams(self, params: Dict[str, Any]) -> None:
+         """Log experiment hyperparameters.
+
+         Parameters
+         ----------
+         params : dict of str to Any
+             Hyperparameters or configuration values to persist with the run.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def finish(self) -> None:
+         """Finalize logging and free any backend resources."""
+         raise NotImplementedError
+
+
+ class WandbLogger(BaseLogger):
+     """Weights & Biases logger implementation.
+
+     Parameters
+     ----------
+     project : str
+         Name of the wandb project to log to.
+     name : str, optional
+         Display name for the run.
+     config : dict, optional
+         Initial hyperparameter configuration to record.
+     tags : list of str, optional
+         Optional tags for grouping runs.
+     notes : str, optional
+         Long-form notes shown in the run overview.
+     dir : str, optional
+         Local directory for wandb artifacts and cache.
+     **kwargs : dict
+         Additional keyword arguments forwarded to ``wandb.init``.
+
+     Raises
+     ------
+     ImportError
+         If ``wandb`` is not installed in the current environment.
+     """
+
+     def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
+         try:
+             import wandb
+             self._wandb = wandb
+         except ImportError:
+             raise ImportError("wandb is not installed. Install it with: pip install wandb")
+
+         self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
+
+     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+         if step is not None:
+             self._wandb.log(metrics, step=step)
+         else:
+             self._wandb.log(metrics)
+
+     def log_hyperparams(self, params: Dict[str, Any]) -> None:
+         if self.run is not None:
+             self.run.config.update(params)
+
+     def finish(self) -> None:
+         if self.run is not None:
+             self.run.finish()
+
+
+ class SwanLabLogger(BaseLogger):
+     """SwanLab logger implementation.
+
+     Parameters
+     ----------
+     project : str, optional
+         Project identifier for grouping experiments.
+     experiment_name : str, optional
+         Display name for the experiment or run.
+     description : str, optional
+         Text description shown alongside the run.
+     config : dict, optional
+         Hyperparameters or configuration to log at startup.
+     logdir : str, optional
+         Directory where logs and artifacts are stored.
+     **kwargs : dict
+         Additional keyword arguments forwarded to ``swanlab.init``.
+
+     Raises
+     ------
+     ImportError
+         If ``swanlab`` is not installed in the current environment.
+     """
+
+     def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
+         try:
+             import swanlab
+             self._swanlab = swanlab
+         except ImportError:
+             raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
+
+         self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
+
+     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+         if step is not None:
+             self._swanlab.log(metrics, step=step)
+         else:
+             self._swanlab.log(metrics)
+
+     def log_hyperparams(self, params: Dict[str, Any]) -> None:
+         if self.run is not None:
+             self.run.config.update(params)
+
+     def finish(self) -> None:
+         self._swanlab.finish()
+
+
+ class TensorBoardXLogger(BaseLogger):
+     """TensorBoardX logger implementation.
+
+     Parameters
+     ----------
+     log_dir : str
+         Directory where event files will be written.
+     comment : str, default=""
+         Comment appended to the log directory name.
+     **kwargs : dict
+         Additional keyword arguments forwarded to
+         ``tensorboardX.SummaryWriter``.
+
+     Raises
+     ------
+     ImportError
+         If ``tensorboardX`` is not installed in the current environment.
+     """
+
+     def __init__(self, log_dir: str, comment: str = "", **kwargs):
+         try:
+             from tensorboardX import SummaryWriter
+             self._SummaryWriter = SummaryWriter
+         except ImportError:
+             raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
+
+         self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
+         self._step = 0
+
+     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+         if step is None:
+             step = self._step
+             self._step += 1
+
+         for key, value in metrics.items():
+             if value is not None:
+                 if isinstance(value, (int, float)):
+                     self.writer.add_scalar(key, value, step)
+
+     def log_hyperparams(self, params: Dict[str, Any]) -> None:
+         hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
+         self.writer.add_text("hyperparameters", hparam_str, 0)
+
+     def finish(self) -> None:
+         if self.writer is not None:
+             self.writer.close()
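All three adapters above share the same three-call contract (log_hyperparams, log_metrics, finish), so calling code never needs backend-specific logic. Below is a minimal sketch of driving one adapter directly; the metric names are illustrative, and the import path is a guess since this diff does not show file names.

    # Hypothetical standalone use of the TensorBoardX adapter defined above.
    # The import path below is an assumption; adjust it to wherever the module
    # actually lives inside the installed package.
    # from torch_rechub.utils.logger import TensorBoardXLogger

    logger = TensorBoardXLogger(log_dir="./runs", comment="_demo")
    logger.log_hyperparams({"learning_rate": 1e-3, "n_epoch": 3})
    for epoch in range(3):
        # Only int/float values are written as scalars; other values are skipped.
        logger.log_metrics({"train/loss": 1.0 / (epoch + 1)}, step=epoch)
    logger.finish()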
@@ -0,0 +1,67 @@
+ """Utilities for converting array-like data structures into PyTorch tensors."""
+
+ import numpy.typing as npt
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import pyarrow.types as pt
+ import torch
+
+
+ def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
+     """
+     Convert a PyArrow array to a PyTorch tensor.
+
+     Parameters
+     ----------
+     arr : pa.Array
+         The given PyArrow array.
+
+     Returns
+     -------
+     torch.Tensor
+         The resulting PyTorch tensor.
+
+     Raises
+     ------
+     TypeError
+         If the array type or the value type (when nested) is unsupported.
+     ValueError
+         If the nested array is ragged (rows of unequal length).
+     """
+     if _is_supported_scalar(arr.type):
+         arr = pc.cast(arr, pa.float32())
+         return torch.from_numpy(_to_writable_numpy(arr))
+
+     if not _is_supported_list(arr.type):
+         raise TypeError(f"Unsupported array type: {arr.type}")
+
+     if not _is_supported_scalar(val_type := arr.type.value_type):
+         raise TypeError(f"Unsupported value type in the nested array: {val_type}")
+
+     if len(pc.unique(pc.list_value_length(arr))) > 1:
+         raise ValueError("Cannot convert a ragged nested array.")
+
+     arr = pc.cast(arr, pa.list_(pa.float32()))
+     np_arr = _to_writable_numpy(arr.values)  # type: ignore[attr-defined]
+
+     # For an empty list-of-lists, define the output shape as (0, 0); otherwise infer the width.
+     return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
+
+
+ # helper functions
+
+
+ def _is_supported_list(t: pa.DataType) -> bool:
+     """Check if the given PyArrow data type is a supported list type."""
+     return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
+
+
+ def _is_supported_scalar(t: pa.DataType) -> bool:
+     """Check if the given PyArrow data type is a supported scalar type."""
+     return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
+
+
+ def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
+     """Dump a PyArrow array into a writable NumPy array."""
+     # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
+     # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
+     return arr.to_numpy(writable=True, zero_copy_only=False)
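For concreteness, the conversion rules above behave roughly as in the short sketch below (the arrays are made up for illustration, and pa_array_to_tensor is assumed importable from wherever this module ships in the package):

    import pyarrow as pa

    # Scalar column: any bool/int/float array is cast to float32, giving shape (n,).
    t1 = pa_array_to_tensor(pa.array([1, 2, 3]))           # tensor([1., 2., 3.])

    # Equal-length list column: values are flattened, then reshaped to (n_rows, row_len).
    t2 = pa_array_to_tensor(pa.array([[1, 2], [3, 4]]))    # shape (2, 2)

    # Ragged rows are rejected.
    # pa_array_to_tensor(pa.array([[1], [2, 3]]))          # raises ValueError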
@@ -0,0 +1,120 @@
+ """Dataset implementations providing streaming, batch-wise data access for PyTorch."""
+
+ import os
+ import typing as ty
+
+ import pyarrow.dataset as pd
+ import torch
+ from torch.utils.data import IterableDataset, get_worker_info
+
+ from .convert import pa_array_to_tensor
+
+ # Type for a path to a file
+ _FilePath = ty.Union[str, os.PathLike]
+
+ # The default batch size when reading a Parquet dataset
+ _DEFAULT_BATCH_SIZE = 1024
+
+
+ class ParquetIterableDataset(IterableDataset):
+     """
+     IterableDataset that streams data from one or more Parquet files.
+
+     Parameters
+     ----------
+     file_paths : list[_FilePath]
+         Paths to Parquet files.
+     columns : list[str], optional
+         Column names to select. If ``None``, all columns are read.
+     batch_size : int, default ``_DEFAULT_BATCH_SIZE``
+         Number of rows per streamed batch.
+
+     Notes
+     -----
+     This dataset reads data lazily and never loads the entire Parquet dataset into memory.
+     The current worker receives a partition of ``file_paths`` and builds its own PyArrow
+     Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
+     tensors created via NumPy, one batch at a time.
+
+     Examples
+     --------
+     >>> ds = ParquetIterableDataset(
+     ...     ["/data/train1.parquet", "/data/train2.parquet"],
+     ...     columns=["x", "y", "label"],
+     ...     batch_size=1024,
+     ... )
+     >>> loader = DataLoader(ds, batch_size=None)
+     >>> # Now iterate over batches.
+     >>> for batch in loader:
+     ...     x, y, label = batch["x"], batch["y"], batch["label"]
+     ...     # Do some work.
+     ...     ...
+     """
+
+     def __init__(
+         self,
+         file_paths: ty.Sequence[_FilePath],
+         /,
+         columns: ty.Optional[ty.Sequence[str]] = None,
+         batch_size: int = _DEFAULT_BATCH_SIZE,
+     ) -> None:
+         """Initialize this instance."""
+         self._file_paths = tuple(map(str, file_paths))
+         self._columns = None if columns is None else tuple(columns)
+         self._batch_size = batch_size
+
+     def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
+         """
+         Stream Parquet data as mapped PyTorch tensors.
+
+         Build a PyArrow Dataset from the current worker's assigned file partition, then
+         create a Scanner to lazily read batches of the selected columns. Each batch is
+         converted to a dict mapping column names to PyTorch tensors (via NumPy).
+
+         Returns
+         -------
+         Iterator[dict[str, torch.Tensor]]
+             An iterator that yields one converted batch at a time.
+         """
+         if not (partition := self._get_partition()):
+             return
+
+         # Build the dataset for the current worker.
+         ds = pd.dataset(partition, format="parquet")
+
+         # Create a scanner. This does not read data.
+         columns = None if self._columns is None else list(self._columns)
+         scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
+
+         for batch in scanner.to_batches():
+             data_dict: dict[str, torch.Tensor] = {}
+             for name, array in zip(batch.column_names, batch.columns):
+                 data_dict[name] = pa_array_to_tensor(array)
+             yield data_dict
+
+     # private interfaces
+
+     def _get_partition(self) -> tuple[str, ...]:
+         """
+         Get the partition of file paths for the current worker.
+
+         This method splits the full list of file paths into contiguous partitions of
+         nearly equal size, based on the total number of workers and the current worker ID.
+
+         If running in the main process (i.e., no worker information is available), the
+         entire list of file paths is returned.
+
+         Returns
+         -------
+         tuple[str, ...]
+             The partition of file paths for the current worker.
+         """
+         if (info := get_worker_info()) is None:
+             return self._file_paths
+
+         n = len(self._file_paths)
+         per_worker = (n + info.num_workers - 1) // info.num_workers
+
+         start = info.id * per_worker
+         end = n if (end := start + per_worker) > n else end
+         return self._file_paths[start:end]
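The worker split in _get_partition is a plain ceiling division over the file list. As a worked example with hypothetical file names, five files across two DataLoader workers are assigned as follows:

    # per_worker = ceil(5 / 2) = 3
    files = ("a.parquet", "b.parquet", "c.parquet", "d.parquet", "e.parquet")
    num_workers = 2
    n = len(files)
    per_worker = (n + num_workers - 1) // num_workers
    for worker_id in range(num_workers):
        start = worker_id * per_worker
        end = min(start + per_worker, n)
        print(worker_id, files[start:end])
    # 0 ('a.parquet', 'b.parquet', 'c.parquet')
    # 1 ('d.parquet', 'e.parquet')

Because each yielded item is already a full batch, the DataLoader in the docstring example is created with batch_size=None so it does not re-batch the dictionaries.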
@@ -43,6 +43,7 @@ class CTRTrainer(object):
          gpus=None,
          loss_mode=True,
          model_path="./",
+         model_logger=None,
      ):
          self.model = model  # for uniform weights save method in one gpu or multi gpu
          if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
          self.model_path = model_path
          # Initialize regularization loss
          self.reg_loss_fn = RegularizationLoss(**regularization_params)
+         self.model_logger = model_logger

      def train_one_epoch(self, data_loader, log_interval=10):
          self.model.train()
          total_loss = 0
+         epoch_loss = 0
+         batch_count = 0
          tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
          for i, (x_dict, y) in enumerate(tk0):
              x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
              loss.backward()
              self.optimizer.step()
              total_loss += loss.item()
+             epoch_loss += loss.item()
+             batch_count += 1
              if (i + 1) % log_interval == 0:
                  tk0.set_postfix(loss=total_loss / log_interval)
                  total_loss = 0

+         # Return average epoch loss
+         return epoch_loss / batch_count if batch_count > 0 else 0
+
      def fit(self, train_dataloader, val_dataloader=None):
+         for logger in self._iter_loggers():
+             logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})
+
          for epoch_i in range(self.n_epoch):
              print('epoch:', epoch_i)
-             self.train_one_epoch(train_dataloader)
+             train_loss = self.train_one_epoch(train_dataloader)
+
+             for logger in self._iter_loggers():
+                 logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
              if self.scheduler is not None:
                  if epoch_i % self.scheduler.step_size == 0:
                      print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                  self.scheduler.step()  # update lr in epoch level by scheduler
+
              if val_dataloader:
                  auc = self.evaluate(self.model, val_dataloader)
                  print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                 for logger in self._iter_loggers():
+                     logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                  if self.early_stopper.stop_training(auc, self.model.state_dict()):
                      print(f'validation: best auc: {self.early_stopper.best_auc}')
                      self.model.load_state_dict(self.early_stopper.best_weights)
                      break
+
          torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model

+         for logger in self._iter_loggers():
+             logger.finish()
+
+     def _iter_loggers(self):
+         """Return logger instances as a list.
+
+         Returns
+         -------
+         list
+             Active logger instances. Empty when ``model_logger`` is ``None``.
+         """
+         if self.model_logger is None:
+             return []
+         if isinstance(self.model_logger, (list, tuple)):
+             return list(self.model_logger)
+         return [self.model_logger]
+
      def evaluate(self, model, data_loader):
          model.eval()
          targets, predicts = list(), list()
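The 'learning_rate' value recorded by fit() is read from the optimizer's first parameter group. A tiny self-contained check of that PyTorch accessor (the dummy parameter exists only for illustration):

    import torch

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    print(opt.param_groups[0]["lr"])   # 0.001 -- the value fit() logs as 'learning_rate'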
@@ -189,3 +228,100 @@ class CTRTrainer(object):

          exporter = ONNXExporter(model, device=export_device)
          return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+
+     def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+         """Visualize the model's computation graph.
+
+         This method generates a visual representation of the model architecture,
+         showing layer connections, tensor shapes, and nested module structures.
+         It automatically extracts feature information from the model.
+
+         Parameters
+         ----------
+         input_data : dict, optional
+             Example input dict {feature_name: tensor}.
+             If not provided, dummy inputs will be generated automatically.
+         batch_size : int, default=2
+             Batch size for auto-generated dummy input.
+         seq_length : int, default=10
+             Sequence length for SequenceFeature.
+         depth : int, default=3
+             Visualization depth; higher values show more detail.
+             Set to -1 to show all layers.
+         show_shapes : bool, default=True
+             Whether to display tensor shapes.
+         expand_nested : bool, default=True
+             Whether to expand nested modules.
+         save_path : str, optional
+             Path to save the graph image (.pdf, .svg, .png).
+             If None, displays inline in Jupyter or opens the system viewer.
+         graph_name : str, default="model"
+             Name for the graph.
+         device : str, optional
+             Device for model execution. If None, defaults to 'cpu'.
+         dpi : int, default=300
+             Resolution in dots per inch for the output image.
+             Higher values produce sharper images suitable for papers.
+         **kwargs : dict
+             Additional arguments passed to torchview.draw_graph().
+
+         Returns
+         -------
+         ComputationGraph
+             A torchview ComputationGraph object.
+
+         Raises
+         ------
+         ImportError
+             If torchview or graphviz is not installed.
+
+         Notes
+         -----
+         Default display behavior when ``save_path`` is None (the default):
+
+         - In Jupyter/IPython: automatically displays the graph inline
+         - In a Python script: opens the graph with the system default viewer
+
+         Examples
+         --------
+         >>> trainer = CTRTrainer(model, ...)
+         >>> trainer.fit(train_dl, val_dl)
+         >>>
+         >>> # Auto-display in Jupyter (no save_path needed)
+         >>> trainer.visualization(depth=4)
+         >>>
+         >>> # Save to high-DPI PNG for papers
+         >>> trainer.visualization(save_path="model.png", dpi=300)
+         """
+         from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+         if not TORCHVIEW_AVAILABLE:
+             raise ImportError(
+                 "Visualization requires torchview. "
+                 "Install with: pip install torch-rechub[visualization]\n"
+                 "Also ensure graphviz is installed on your system:\n"
+                 " - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                 " - macOS: brew install graphviz\n"
+                 " - Windows: choco install graphviz"
+             )
+
+         # Handle DataParallel wrapped model
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Use the provided device or default to 'cpu'
+         viz_device = device if device is not None else 'cpu'
+
+         return visualize_model(
+             model,
+             input_data=input_data,
+             batch_size=batch_size,
+             seq_length=seq_length,
+             depth=depth,
+             show_shapes=show_shapes,
+             expand_nested=expand_nested,
+             save_path=save_path,
+             graph_name=graph_name,
+             device=viz_device,
+             dpi=dpi,
+             **kwargs
+         )
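Putting the trainer-side changes together, the new model_logger hook and the visualization helper might be wired up as sketched below. This is a sketch only: the model, dataloaders, and the remaining CTRTrainer constructor arguments are assumed to be defined elsewhere and are not part of this diff.

    # Assumed to exist elsewhere: model, train_dataloader, val_dataloader,
    # plus whatever optimizer/scheduler arguments CTRTrainer normally takes.
    tb_logger = TensorBoardXLogger(log_dir="./runs")

    trainer = CTRTrainer(model, model_logger=tb_logger)   # a list of loggers also works
    trainer.fit(train_dataloader, val_dataloader)         # logs hyperparams once, metrics per epoch, then finish()

    # Optional: render the model graph (requires torchview and graphviz).
    trainer.visualization(save_path="model.png", dpi=300)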