torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
"""Sequence Generation Model Trainer."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import tqdm
|
|
8
|
+
|
|
9
|
+
from ..basic.callback import EarlyStopper
|
|
10
|
+
from ..basic.loss_func import NCELoss
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SeqTrainer(object):
    """Trainer for sequence generation models.

    Used for training sequence generative models such as HSTU.
    Supports a CrossEntropyLoss (default) or NCE loss and generative
    evaluation metrics (next-item loss and top-1 accuracy).

    Args:
        model (nn.Module): the model to train.
        optimizer_fn (torch.optim): optimizer class, default ``torch.optim.Adam``.
        optimizer_params (dict): optimizer parameters.
        scheduler_fn (torch.optim.lr_scheduler): torch scheduler class.
        scheduler_params (dict): scheduler parameters.
        n_epoch (int): number of training epochs, default 10.
        earlystop_patience (int): early-stopping patience, default 10.
        device (str): 'cpu' or 'cuda', default 'cpu'.
        gpus (list): gpu ids for multi-GPU training, default [].
        model_path (str): directory where the model is saved, default './'.
        loss_type (str): 'nce' for NCELoss; anything else uses CrossEntropyLoss.
        loss_params (dict): keyword arguments forwarded to the loss constructor.

    Methods:
        fit: train the model.
        train_one_epoch: run a single training epoch.
        evaluate: evaluate the model (loss and top-1 accuracy).
        export_onnx: export the trained model to ONNX format.
        visualization: render the model's computation graph.

    Example:
        >>> trainer = SeqTrainer(
        ...     model=model,
        ...     optimizer_fn=torch.optim.Adam,
        ...     optimizer_params={'lr': 1e-3, 'weight_decay': 1e-5},
        ...     device='cuda'
        ... )
        >>> trainer.fit(
        ...     train_loader=train_loader,
        ...     val_loader=val_loader
        ... )
    """
|
|
48
|
+
|
|
49
|
+
def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
    """Set up model placement, optimizer, scheduler, loss and early stopping."""
    # Keep a handle on the (possibly wrapped) model so weight saving is
    # uniform for single- and multi-GPU runs.
    self.model = model
    self.gpus = [] if gpus is None else gpus
    if len(self.gpus) > 1:
        print('parallel running on these gpus:', self.gpus)
        self.model = torch.nn.DataParallel(self.model, device_ids=self.gpus)
    # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.device = torch.device(device)
    self.model.to(self.device)

    opt_params = {"lr": 1e-3, "weight_decay": 1e-5} if optimizer_params is None else optimizer_params
    self.optimizer = optimizer_fn(self.model.parameters(), **opt_params)  # default optimizer
    self.scheduler = scheduler_fn(self.optimizer, **scheduler_params) if scheduler_fn is not None else None

    # Loss selection: 'nce' uses NCELoss; any other value falls back to
    # cross-entropy (matching the original permissive behavior).
    if loss_type == 'nce':
        nce_params = {"temperature": 0.1, "ignore_index": 0} if loss_params is None else loss_params
        self.loss_fn = NCELoss(**nce_params)
    else:
        ce_params = {"ignore_index": 0} if loss_params is None else loss_params
        self.loss_fn = nn.CrossEntropyLoss(**ce_params)

    self.n_epoch = n_epoch
    self.early_stopper = EarlyStopper(patience=earlystop_patience)
    self.model_path = model_path
|
|
80
|
+
|
|
81
|
+
def fit(self, train_dataloader, val_dataloader=None):
    """Train the model.

    Args:
        train_dataloader (DataLoader): training data loader.
        val_dataloader (DataLoader): validation data loader (optional).

    Returns:
        dict: training history with keys ``train_loss``, ``val_loss``
            and ``val_accuracy`` (lists, one entry per epoch; ``train_loss``
            is only recorded when ``train_one_epoch`` reports a loss).
    """
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    for epoch_i in range(self.n_epoch):
        print('epoch:', epoch_i)
        # Training phase. Record the epoch's average loss when the epoch
        # runner reports one (fixes 'train_loss' never being populated).
        train_loss = self.train_one_epoch(train_dataloader)
        if train_loss is not None:
            history['train_loss'].append(train_loss)

        if self.scheduler is not None:
            # NOTE(review): assumes a StepLR-like scheduler exposing
            # `step_size` — confirm against callers.
            if epoch_i % self.scheduler.step_size == 0:
                print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
            self.scheduler.step()  # update lr at epoch level via scheduler

        # Validation phase
        if val_dataloader:
            val_loss, val_accuracy = self.evaluate(val_dataloader)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)

            print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")

            # Early stopping on validation accuracy; restore best weights.
            if self.early_stopper.stop_training(val_accuracy, self.model.state_dict()):
                print(f'validation: best accuracy: {self.early_stopper.best_auc}')
                self.model.load_state_dict(self.early_stopper.best_weights)
                break

    torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
    return history
|
|
118
|
+
|
|
119
|
+
def train_one_epoch(self, data_loader, log_interval=10):
    """Train the model for a single epoch.

    Args:
        data_loader (DataLoader): training data loader yielding
            ``(seq_tokens, seq_positions, seq_time_diffs, targets)``.
        log_interval (int): interval (in steps) for logging the running loss.

    Returns:
        float: average training loss over the epoch (0.0 for an empty
            loader). The original implementation documented this return
            value but never returned it; fixed here.
    """
    self.model.train()
    running_loss = 0.0  # reset every `log_interval` steps, display only
    epoch_loss = 0.0  # accumulated across the whole epoch
    n_batches = 0
    tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
    for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
        # Move tensors to the target device
        seq_tokens = seq_tokens.to(self.device)
        seq_positions = seq_positions.to(self.device)
        seq_time_diffs = seq_time_diffs.to(self.device)
        targets = targets.to(self.device).squeeze(-1)

        # Forward pass
        logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)

        # For next-item prediction we only use the last position in the
        # sequence: logits[:, -1, :] is the prediction at the final step.
        last_logits = logits[:, -1, :]  # (B, V)

        loss = self.loss_fn(last_logits, targets)

        # Backward pass
        self.model.zero_grad()
        loss.backward()
        self.optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        n_batches += 1
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=running_loss / log_interval)
            running_loss = 0.0

    return epoch_loss / max(n_batches, 1)
|
|
158
|
+
|
|
159
|
+
def evaluate(self, data_loader):
    """Evaluate the model on a validation/test data loader.

    Args:
        data_loader (DataLoader): validation or test data loader yielding
            ``(seq_tokens, seq_positions, seq_time_diffs, targets)``.

    Returns:
        tuple: ``(avg_loss, top1_accuracy)``. Both are 0.0 for an empty
            loader (previously this raised ZeroDivisionError and required
            the loader to implement ``__len__``).
    """
    self.model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    n_batches = 0  # counted explicitly so len(data_loader) is not required

    with torch.no_grad():
        for seq_tokens, seq_positions, seq_time_diffs, targets in tqdm.tqdm(data_loader, desc="evaluating", smoothing=0, mininterval=1.0):
            # Move tensors to the target device
            seq_tokens = seq_tokens.to(self.device)
            seq_positions = seq_positions.to(self.device)
            seq_time_diffs = seq_time_diffs.to(self.device)
            targets = targets.to(self.device).squeeze(-1)

            # Forward pass
            logits = self.model(seq_tokens, seq_time_diffs)  # (B, L, V)

            # Compute loss using only the last position (next-item prediction)
            last_logits = logits[:, -1, :]  # (B, V)

            loss = self.loss_fn(last_logits, targets)
            total_loss += loss.item()
            n_batches += 1

            # Compute top-1 accuracy
            predictions = torch.argmax(last_logits, dim=-1)  # (B,)
            correct = (predictions == targets).sum().item()
            total_correct += correct
            total_samples += targets.numel()

    avg_loss = total_loss / max(n_batches, 1)
    accuracy = total_correct / max(total_samples, 1)

    return avg_loss, accuracy
|
|
200
|
+
|
|
201
|
+
def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
202
|
+
"""Export the trained sequence generation model to ONNX format.
|
|
203
|
+
|
|
204
|
+
This method exports sequence generation models (e.g., HSTU) to ONNX format.
|
|
205
|
+
Unlike other trainers, sequence models use positional arguments (seq_tokens, seq_time_diffs)
|
|
206
|
+
instead of dict input, making ONNX export more straightforward.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
output_path (str): Path to save the ONNX model file.
|
|
210
|
+
batch_size (int): Batch size for dummy input (default: 2).
|
|
211
|
+
seq_length (int): Sequence length for dummy input (default: 50).
|
|
212
|
+
vocab_size (int, optional): Vocabulary size for generating dummy tokens.
|
|
213
|
+
If None, will try to get from model.vocab_size.
|
|
214
|
+
opset_version (int): ONNX opset version (default: 14).
|
|
215
|
+
dynamic_batch (bool): Enable dynamic batch size (default: True).
|
|
216
|
+
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
217
|
+
If None, defaults to 'cpu' for maximum compatibility.
|
|
218
|
+
verbose (bool): Print export details (default: False).
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
bool: True if export succeeded, False otherwise.
|
|
222
|
+
|
|
223
|
+
Example:
|
|
224
|
+
>>> trainer = SeqTrainer(hstu_model, ...)
|
|
225
|
+
>>> trainer.fit(train_dl, val_dl)
|
|
226
|
+
>>> trainer.export_onnx("hstu.onnx", vocab_size=10000)
|
|
227
|
+
|
|
228
|
+
>>> # Export on specific device
|
|
229
|
+
>>> trainer.export_onnx("hstu.onnx", vocab_size=10000, device="cpu")
|
|
230
|
+
"""
|
|
231
|
+
import warnings
|
|
232
|
+
|
|
233
|
+
# Use provided device or default to 'cpu'
|
|
234
|
+
export_device = device if device is not None else 'cpu'
|
|
235
|
+
|
|
236
|
+
# Handle DataParallel wrapped model
|
|
237
|
+
model = self.model.module if hasattr(self.model, 'module') else self.model
|
|
238
|
+
model.eval()
|
|
239
|
+
model.to(export_device)
|
|
240
|
+
|
|
241
|
+
# Get vocab_size from model if not provided
|
|
242
|
+
if vocab_size is None:
|
|
243
|
+
if hasattr(model, 'vocab_size'):
|
|
244
|
+
vocab_size = model.vocab_size
|
|
245
|
+
elif hasattr(model, 'item_num'):
|
|
246
|
+
vocab_size = model.item_num
|
|
247
|
+
else:
|
|
248
|
+
raise ValueError("vocab_size must be provided or model must have 'vocab_size' or 'item_num' attribute")
|
|
249
|
+
|
|
250
|
+
# Generate dummy inputs on the export device
|
|
251
|
+
dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=export_device)
|
|
252
|
+
dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=export_device)
|
|
253
|
+
|
|
254
|
+
# Configure dynamic axes
|
|
255
|
+
dynamic_axes = None
|
|
256
|
+
if dynamic_batch:
|
|
257
|
+
dynamic_axes = {"seq_tokens": {0: "batch_size", 1: "seq_length"}, "seq_time_diffs": {0: "batch_size", 1: "seq_length"}, "output": {0: "batch_size", 1: "seq_length"}}
|
|
258
|
+
|
|
259
|
+
# Ensure output directory exists
|
|
260
|
+
import os
|
|
261
|
+
output_dir = os.path.dirname(output_path)
|
|
262
|
+
if output_dir and not os.path.exists(output_dir):
|
|
263
|
+
os.makedirs(output_dir)
|
|
264
|
+
|
|
265
|
+
try:
|
|
266
|
+
with torch.no_grad():
|
|
267
|
+
torch.onnx.export(
|
|
268
|
+
model,
|
|
269
|
+
(dummy_seq_tokens,
|
|
270
|
+
dummy_seq_time_diffs),
|
|
271
|
+
output_path,
|
|
272
|
+
input_names=["seq_tokens",
|
|
273
|
+
"seq_time_diffs"],
|
|
274
|
+
output_names=["output"],
|
|
275
|
+
dynamic_axes=dynamic_axes,
|
|
276
|
+
opset_version=opset_version,
|
|
277
|
+
do_constant_folding=True,
|
|
278
|
+
verbose=verbose,
|
|
279
|
+
dynamo=False # Use legacy exporter for dynamic_axes support
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
if verbose:
|
|
283
|
+
print(f"Successfully exported ONNX model to: {output_path}")
|
|
284
|
+
print(" Input names: ['seq_tokens', 'seq_time_diffs']")
|
|
285
|
+
print(f" Vocab size: {vocab_size}")
|
|
286
|
+
print(f" Opset version: {opset_version}")
|
|
287
|
+
print(f" Dynamic batch: {dynamic_batch}")
|
|
288
|
+
|
|
289
|
+
return True
|
|
290
|
+
|
|
291
|
+
except Exception as e:
|
|
292
|
+
warnings.warn(f"ONNX export failed: {str(e)}")
|
|
293
|
+
raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
|
|
294
|
+
|
|
295
|
+
def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
    """Visualize the model's computation graph.

    This method generates a visual representation of the sequence model
    architecture, showing layer connections, tensor shapes, and nested
    module structures. After rendering, the model is moved back to the
    trainer's device (previously it was left on the visualization device).

    Parameters
    ----------
    seq_length : int, default=50
        Sequence length for dummy input.
    vocab_size : int, optional
        Vocabulary size for generating dummy tokens.
        If None, will try to get from model.vocab_size or model.item_num.
    batch_size : int, default=2
        Batch size for dummy input.
    depth : int, default=3
        Visualization depth, higher values show more detail.
        Set to -1 to show all layers.
    show_shapes : bool, default=True
        Whether to display tensor shapes.
    expand_nested : bool, default=True
        Whether to expand nested modules.
    save_path : str, optional
        Path to save the graph image (.pdf, .svg, .png).
        If None, displays in Jupyter or opens system viewer.
    graph_name : str, default="model"
        Name for the graph.
    device : str, optional
        Device for model execution. If None, defaults to 'cpu'.
    dpi : int, default=300
        Resolution in dots per inch for output image.
    **kwargs : dict
        Additional arguments passed to torchview.draw_graph().

    Returns
    -------
    ComputationGraph
        A torchview ComputationGraph object.

    Raises
    ------
    ImportError
        If torchview or graphviz is not installed.
    ValueError
        If vocab_size is not provided and cannot be inferred from model.

    Examples
    --------
    >>> trainer = SeqTrainer(hstu_model, ...)
    >>> trainer.fit(train_dl, val_dl)
    >>> trainer.visualization(depth=4, vocab_size=10000)
    >>> trainer.visualization(save_path="model.png", dpi=300)
    """
    # Fail fast with installation guidance when torchview is missing.
    try:
        from torchview import draw_graph
    except ImportError:
        raise ImportError(
            "Visualization requires torchview. "
            "Install with: pip install torch-rechub[visualization]\n"
            "Also ensure graphviz is installed on your system:\n"
            " - Ubuntu/Debian: sudo apt-get install graphviz\n"
            " - macOS: brew install graphviz\n"
            " - Windows: choco install graphviz"
        ) from None

    from ..utils.visualization import _is_jupyter_environment, display_graph

    # Handle DataParallel wrapped model
    model = self.model.module if hasattr(self.model, 'module') else self.model

    # Use provided device or default to 'cpu'
    viz_device = device if device is not None else 'cpu'

    # Get vocab_size from model if not provided
    if vocab_size is None:
        if hasattr(model, 'vocab_size'):
            vocab_size = model.vocab_size
        elif hasattr(model, 'item_num'):
            vocab_size = model.item_num
        else:
            raise ValueError("vocab_size must be provided or model must have "
                             "'vocab_size' or 'item_num' attribute")

    # Generate dummy inputs for sequence model
    dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
    dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)

    # Move model to the visualization device for the traced forward pass
    model = model.to(viz_device)
    model.eval()

    try:
        graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)
    finally:
        # Restore the model to the trainer's device so the trainer
        # remains usable after visualization.
        model.to(self.device)

    # Set DPI for high-quality output
    graph.visual_graph.graph_attr['dpi'] = str(dpi)

    # Save with DPI applied, or fall back to the default display behavior.
    # Note: save_path="" intentionally does neither (matches prior behavior).
    if save_path:
        directory = os.path.dirname(save_path) or "."
        filename = os.path.splitext(os.path.basename(save_path))[0]
        ext = os.path.splitext(save_path)[1].lstrip('.')
        output_format = ext if ext else 'pdf'
        if directory != "." and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
        graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
    elif save_path is None:
        if _is_jupyter_environment():
            display_graph(graph)  # inline display in Jupyter/IPython
        else:
            graph.visual_graph.view(cleanup=True)  # open with system viewer

    return graph
|