torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +120 -0
- torch_rechub/trainers/ctr_trainer.py +137 -1
- torch_rechub/trainers/match_trainer.py +136 -1
- torch_rechub/trainers/mtl_trainer.py +146 -1
- torch_rechub/trainers/seq_trainer.py +193 -2
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/onnx_export.py +3 -136
- torch_rechub/utils/visualization.py +271 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/METADATA +68 -49
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/RECORD +15 -9
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/tracking.py
@@ -0,0 +1,198 @@
+"""Experiment tracking utilities for Torch-RecHub.
+
+This module exposes lightweight adapters for common visualization and
+experiment tracking tools, namely Weights & Biases (wandb), SwanLab, and
+TensorBoardX.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+
+class BaseLogger(ABC):
+    """Base interface for experiment tracking backends.
+
+    Methods
+    -------
+    log_metrics(metrics, step=None)
+        Record scalar metrics at a given step.
+    log_hyperparams(params)
+        Store hyperparameters and run configuration.
+    finish()
+        Flush pending logs and release resources.
+    """
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        """Log metrics to the tracking backend.
+
+        Parameters
+        ----------
+        metrics : dict of str to Any
+            Metric name-value pairs to record.
+        step : int, optional
+            Explicit global step or epoch index. When ``None``, the backend
+            uses its own default step handling.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        """Log experiment hyperparameters.
+
+        Parameters
+        ----------
+        params : dict of str to Any
+            Hyperparameters or configuration values to persist with the run.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def finish(self) -> None:
+        """Finalize logging and free any backend resources."""
+        raise NotImplementedError
+
+
+class WandbLogger(BaseLogger):
+    """Weights & Biases logger implementation.
+
+    Parameters
+    ----------
+    project : str
+        Name of the wandb project to log to.
+    name : str, optional
+        Display name for the run.
+    config : dict, optional
+        Initial hyperparameter configuration to record.
+    tags : list of str, optional
+        Optional tags for grouping runs.
+    notes : str, optional
+        Long-form notes shown in the run overview.
+    dir : str, optional
+        Local directory for wandb artifacts and cache.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``wandb.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``wandb`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: str, name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, notes: Optional[str] = None, dir: Optional[str] = None, **kwargs):
+        try:
+            import wandb
+            self._wandb = wandb
+        except ImportError:
+            raise ImportError("wandb is not installed. Install it with: pip install wandb")
+
+        self.run = self._wandb.init(project=project, name=name, config=config, tags=tags, notes=notes, dir=dir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._wandb.log(metrics, step=step)
+        else:
+            self._wandb.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        if self.run is not None:
+            self.run.finish()
+
+
+class SwanLabLogger(BaseLogger):
+    """SwanLab logger implementation.
+
+    Parameters
+    ----------
+    project : str, optional
+        Project identifier for grouping experiments.
+    experiment_name : str, optional
+        Display name for the experiment or run.
+    description : str, optional
+        Text description shown alongside the run.
+    config : dict, optional
+        Hyperparameters or configuration to log at startup.
+    logdir : str, optional
+        Directory where logs and artifacts are stored.
+    **kwargs : dict
+        Additional keyword arguments forwarded to ``swanlab.init``.
+
+    Raises
+    ------
+    ImportError
+        If ``swanlab`` is not installed in the current environment.
+    """
+
+    def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None, description: Optional[str] = None, config: Optional[Dict[str, Any]] = None, logdir: Optional[str] = None, **kwargs):
+        try:
+            import swanlab
+            self._swanlab = swanlab
+        except ImportError:
+            raise ImportError("swanlab is not installed. Install it with: pip install swanlab")
+
+        self.run = self._swanlab.init(project=project, experiment_name=experiment_name, description=description, config=config, logdir=logdir, **kwargs)
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is not None:
+            self._swanlab.log(metrics, step=step)
+        else:
+            self._swanlab.log(metrics)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        if self.run is not None:
+            self.run.config.update(params)
+
+    def finish(self) -> None:
+        self._swanlab.finish()
+
+
+class TensorBoardXLogger(BaseLogger):
+    """TensorBoardX logger implementation.
+
+    Parameters
+    ----------
+    log_dir : str
+        Directory where event files will be written.
+    comment : str, default=""
+        Comment appended to the log directory name.
+    **kwargs : dict
+        Additional keyword arguments forwarded to
+        ``tensorboardX.SummaryWriter``.
+
+    Raises
+    ------
+    ImportError
+        If ``tensorboardX`` is not installed in the current environment.
+    """
+
+    def __init__(self, log_dir: str, comment: str = "", **kwargs):
+        try:
+            from tensorboardX import SummaryWriter
+            self._SummaryWriter = SummaryWriter
+        except ImportError:
+            raise ImportError("tensorboardX is not installed. Install it with: pip install tensorboardX")
+
+        self.writer = self._SummaryWriter(log_dir=log_dir, comment=comment, **kwargs)
+        self._step = 0
+
+    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        if step is None:
+            step = self._step
+            self._step += 1
+
+        for key, value in metrics.items():
+            if value is not None:
+                if isinstance(value, (int, float)):
+                    self.writer.add_scalar(key, value, step)
+
+    def log_hyperparams(self, params: Dict[str, Any]) -> None:
+        hparam_str = "\n".join([f"{k}: {v}" for k, v in params.items()])
+        self.writer.add_text("hyperparameters", hparam_str, 0)
+
+    def finish(self) -> None:
+        if self.writer is not None:
+            self.writer.close()
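The three adapters above share the `BaseLogger` interface, so trainers can treat them interchangeably. A minimal usage sketch, assuming `tensorboardX` is installed; the `./runs` directory and metric names are illustrative:

    from torch_rechub.basic.tracking import TensorBoardXLogger

    # Writes event files under ./runs (illustrative path).
    logger = TensorBoardXLogger(log_dir="./runs")

    # Record the run configuration once, then per-epoch scalar metrics.
    logger.log_hyperparams({"n_epoch": 10, "learning_rate": 1e-3})
    logger.log_metrics({"train/loss": 0.42}, step=0)
    logger.log_metrics({"train/loss": 0.37}, step=1)

    # Flush and close the underlying SummaryWriter.
    logger.finish()

`WandbLogger` and `SwanLabLogger` are used the same way; only their constructor arguments differ.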
torch_rechub/data/__init__.py (file without changes)
torch_rechub/data/convert.py
@@ -0,0 +1,67 @@
+"""Utilities for converting array-like data structures into PyTorch tensors."""
+
+import numpy.typing as npt
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.types as pt
+import torch
+
+
+def pa_array_to_tensor(arr: pa.Array) -> torch.Tensor:
+    """
+    Convert a PyArrow array to a PyTorch tensor.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        The given PyArrow array.
+
+    Returns
+    -------
+    torch.Tensor: The result PyTorch tensor.
+
+    Raises
+    ------
+    TypeError
+        if the array type or the value type (when nested) is unsupported.
+    ValueError
+        if the nested array is ragged (unequal lengths of each row).
+    """
+    if _is_supported_scalar(arr.type):
+        arr = pc.cast(arr, pa.float32())
+        return torch.from_numpy(_to_writable_numpy(arr))
+
+    if not _is_supported_list(arr.type):
+        raise TypeError(f"Unsupported array type: {arr.type}")
+
+    if not _is_supported_scalar(val_type := arr.type.value_type):
+        raise TypeError(f"Unsupported value type in the nested array: {val_type}")
+
+    if len(pc.unique(pc.list_value_length(arr))) > 1:
+        raise ValueError("Cannot convert the ragged nested array.")
+
+    arr = pc.cast(arr, pa.list_(pa.float32()))
+    np_arr = _to_writable_numpy(arr.values)  # type: ignore[attr-defined]
+
+    # For empty list-of-lists, define output shape as (0, 0); otherwise infer width.
+    return torch.from_numpy(np_arr.reshape(len(arr), -1 if len(arr) > 0 else 0))
+
+
+# helper functions
+
+
+def _is_supported_list(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported list."""
+    return pt.is_fixed_size_list(t) or pt.is_large_list(t) or pt.is_list(t)
+
+
+def _is_supported_scalar(t: pa.DataType) -> bool:
+    """Check if the given PyArrow data type is a supported scalar type."""
+    return pt.is_boolean(t) or pt.is_floating(t) or pt.is_integer(t) or pt.is_null(t)
+
+
+def _to_writable_numpy(arr: pa.Array) -> npt.NDArray:
+    """Dump a PyArrow array into a writable NumPy array."""
+    # Force the NumPy array to be writable. PyArrow's to_numpy() often returns a
+    # read-only view for zero-copy, which PyTorch's from_numpy() does not support.
+    return arr.to_numpy(writable=True, zero_copy_only=False)
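A sketch of the conversion rules above; the data values are made up, and only the `pa_array_to_tensor` behavior shown in this hunk is assumed:

    import pyarrow as pa

    from torch_rechub.data.convert import pa_array_to_tensor

    # Scalar arrays are cast to float32 and become 1-D tensors.
    flat = pa_array_to_tensor(pa.array([1, 0, 3]))            # shape (3,)

    # Equal-length nested arrays become 2-D tensors.
    nested = pa_array_to_tensor(pa.array([[1, 2], [3, 4]]))   # shape (2, 2)

    # Ragged nested arrays are rejected.
    try:
        pa_array_to_tensor(pa.array([[1], [2, 3]]))
    except ValueError:
        print("ragged input raises ValueError")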
torch_rechub/data/dataset.py
@@ -0,0 +1,120 @@
+"""Dataset implementations providing streaming, batch-wise data access for PyTorch."""
+
+import os
+import typing as ty
+
+import pyarrow.dataset as pd
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+
+from .convert import pa_array_to_tensor
+
+# Type for path to a file
+_FilePath = ty.Union[str, os.PathLike]
+
+# The default batch size when reading a Parquet dataset
+_DEFAULT_BATCH_SIZE = 1024
+
+
+class ParquetIterableDataset(IterableDataset):
+    """
+    IterableDataset that streams data from one or more Parquet files.
+
+    Parameters
+    ----------
+    file_paths : list[_FilePath]
+        Paths to Parquet files.
+    columns : list[str], optional
+        Column names to select. If ``None``, all columns are read.
+    batch_size : int, default _DEFAULT_BATCH_SIZE
+        Number of rows per streamed batch.
+
+    Notes
+    -----
+    This dataset reads data lazily and never loads the entire Parquet dataset into memory.
+    The current worker receives a partition of ``file_paths`` and builds its own PyArrow
+    Dataset and Scanner. Iteration yields dictionaries mapping column names to PyTorch
+    tensors created via NumPy, one batch at a time.
+
+    Examples
+    --------
+    >>> ds = ParquetIterableDataset(
+    ...     ["/data/train1.parquet", "/data/train2.parquet"],
+    ...     columns=["x", "y", "label"],
+    ...     batch_size=1024,
+    ... )
+    >>> loader = DataLoader(ds, batch_size=None)
+    >>> # Now iterate over batches.
+    >>> for batch in loader:
+    ...     x, y, label = batch["x"], batch["y"], batch["label"]
+    ...     # Do some work.
+    ...     ...
+    """
+
+    def __init__(
+        self,
+        file_paths: ty.Sequence[_FilePath],
+        /,
+        columns: ty.Optional[ty.Sequence[str]] = None,
+        batch_size: int = _DEFAULT_BATCH_SIZE,
+    ) -> None:
+        """Initialize this instance."""
+        self._file_paths = tuple(map(str, file_paths))
+        self._columns = None if columns is None else tuple(columns)
+        self._batch_size = batch_size
+
+    def __iter__(self) -> ty.Iterator[dict[str, torch.Tensor]]:
+        """
+        Stream Parquet data as mapped PyTorch tensors.
+
+        Build a PyArrow Dataset from the current worker's assigned file partition, then
+        create a Scanner to lazily read batches of the selected columns. Each batch is
+        converted to a dict mapping column names to PyTorch tensors (via NumPy).
+
+        Returns
+        -------
+        Iterator[dict[str, torch.Tensor]]
+            An iterator that yields one converted batch at a time.
+        """
+        if not (partition := self._get_partition()):
+            return
+
+        # Build the dataset for the current worker.
+        ds = pd.dataset(partition, format="parquet")
+
+        # Create a scanner. This does not read data.
+        columns = None if self._columns is None else list(self._columns)
+        scanner = ds.scanner(columns=columns, batch_size=self._batch_size)
+
+        for batch in scanner.to_batches():
+            data_dict: dict[str, torch.Tensor] = {}
+            for name, array in zip(batch.column_names, batch.columns):
+                data_dict[name] = pa_array_to_tensor(array)
+            yield data_dict
+
+    # private interfaces
+
+    def _get_partition(self) -> tuple[str, ...]:
+        """
+        Get the partition of file paths for the current worker.
+
+        This method splits the full list of file paths into contiguous partitions with
+        a nearly equal size by the total number of workers and the current worker ID.
+
+        If running in the main process (i.e., no worker information is available), the
+        entire list of file paths is returned.
+
+        Returns
+        -------
+        tuple[str, ...]
+            The partition of file paths for the current worker.
+        """
+        if (info := get_worker_info()) is None:
+            return self._file_paths
+
+        n = len(self._file_paths)
+        per_worker = (n + info.num_workers - 1) // info.num_workers
+
+        start = info.id * per_worker
+        end = n if (end := start + per_worker) > n else end
+        return self._file_paths[start:end]
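A sketch of how this dataset is typically wired into a `DataLoader`; the file paths, column names, and worker count are illustrative. `batch_size=None` is used because the dataset already yields batched tensors:

    from torch.utils.data import DataLoader

    from torch_rechub.data.dataset import ParquetIterableDataset

    ds = ParquetIterableDataset(
        ["part-0.parquet", "part-1.parquet", "part-2.parquet"],  # illustrative paths
        columns=["user_id", "item_id", "label"],
        batch_size=2048,
    )

    # With num_workers=2, _get_partition() hands each worker a contiguous
    # slice of the file list, so no file is read twice.
    loader = DataLoader(ds, batch_size=None, num_workers=2)
    for batch in loader:
        labels = batch["label"]  # 1-D float32 tensor, at most 2048 rows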
torch_rechub/trainers/ctr_trainer.py
@@ -43,6 +43,7 @@ class CTRTrainer(object):
         gpus=None,
         loss_mode=True,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
@@ -70,10 +71,13 @@ class CTRTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
@@ -93,27 +97,62 @@ class CTRTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.loss_mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                 self.scheduler.step()  # update lr in epoch level by scheduler
+
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
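Putting the pieces together, a sketch of the new `model_logger` hook; it assumes `model`, `train_dl`, and `val_dl` are already defined, and the wandb project name is illustrative. A single logger or a list/tuple of loggers is accepted, since `_iter_loggers()` normalizes both forms:

    from torch_rechub.basic.tracking import WandbLogger
    from torch_rechub.trainers.ctr_trainer import CTRTrainer

    logger = WandbLogger(project="rechub-demo", name="widedeep-run")

    trainer = CTRTrainer(model, model_logger=logger)  # other trainer arguments omitted
    trainer.fit(train_dl, val_dl)
    # fit() logs the hyperparameters once, train/loss and the learning rate per epoch,
    # val/auc when a validation loader is given, and calls logger.finish() at the end.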
@@ -189,3 +228,100 @@ class CTRTrainer(object):
 
         exporter = ONNXExporter(model, device=export_device)
         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default Display Behavior:
+            When `save_path` is None (default):
+            - In Jupyter/IPython: automatically displays the graph inline
+            - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = CTRTrainer(model, ...)
+        >>> trainer.fit(train_dl, val_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                "  - macOS: brew install graphviz\n"
+                "  - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )