torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,6 +39,7 @@ class MatchTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
@@ -73,10 +74,13 @@ class MatchTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
@@ -114,14 +118,26 @@ class MatchTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -130,12 +146,34 @@ class MatchTrainer(object):
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
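For context, _iter_loggers only assumes a duck-typed interface: fit() calls log_hyperparams(dict) once, log_metrics(dict, step=...) each epoch, and finish() at the end. A minimal sketch of a compatible logger follows; the PrintLogger name is hypothetical and not part of torch-rechub, but any object exposing these three methods works, singly or in a list:

# Hypothetical minimal logger satisfying the model_logger interface
# used by MatchTrainer / MTLTrainer / SeqTrainer below.
class PrintLogger:
    def log_hyperparams(self, params):
        print("hparams:", params)

    def log_metrics(self, metrics, step=None):
        print(f"step {step}:", metrics)

    def finish(self):
        print("run finished")

# model_logger accepts a single instance or a list/tuple of them:
# MatchTrainer(model, ..., model_logger=PrintLogger())
# MatchTrainer(model, ..., model_logger=[PrintLogger(), another_logger])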
@@ -237,3 +275,100 @@ class MatchTrainer(object):
         # Restore original mode
         if hasattr(model, 'mode'):
             model.mode = original_mode
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = MatchTrainer(model, ...)
+        >>> trainer.fit(train_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                " - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                " - macOS: brew install graphviz\n"
+                " - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )
@@ -47,6 +47,7 @@ class MTLTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model
         if gpus is None:
@@ -104,6 +105,7 @@ class MTLTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader):
         self.model.train()
@@ -163,21 +165,42 @@ class MTLTrainer(object):
     def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
         total_log = []
 
+        # Log hyperparameters once
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
+
         for epoch_i in range(self.n_epoch):
             _log_per_epoch = self.train_one_epoch(train_dataloader)
 
+            # Collect metrics
+            logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
+            lr_value = self._current_lr()
+            if lr_value is not None:
+                logs['learning_rate'] = lr_value
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                 self.scheduler.step() # update lr in epoch level by scheduler
+
             scores = self.evaluate(self.model, val_dataloader)
             print('epoch:', epoch_i, 'validation scores: ', scores)
 
-            for score in scores:
+            for task_id, score in enumerate(scores):
+                logs[f'val/task_{task_id}_score'] = score
                 _log_per_epoch.append(score)
+            logs['auc'] = scores[self.earlystop_taskid]
+
+            if self.loss_weight:
+                for task_id, weight in enumerate(self.loss_weight):
+                    logs[f'loss_weight/task_{task_id}'] = weight.item()
 
             total_log.append(_log_per_epoch)
 
+            # Log metrics once per epoch
+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
             if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
                 print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
                 self.model.load_state_dict(self.early_stopper.best_weights)
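To make the metric layout concrete, here is the shape of the per-epoch logs dict that fit() assembles for a hypothetical two-task model before handing it to each logger (all values below are illustrative, not real results):

# Sketch of the logs dict MTLTrainer.fit() sends via log_metrics(logs, step=epoch_i):
logs = {
    'train/task_0_loss': 0.412,   # one entry per task from _log_per_epoch
    'train/task_1_loss': 0.387,
    'learning_rate': 1e-3,        # omitted when _current_lr() returns None
    'val/task_0_score': 0.781,
    'val/task_1_score': 0.765,
    'auc': 0.781,                 # scores[self.earlystop_taskid]
    # only present when an adaptive method populates self.loss_weight:
    'loss_weight/task_0': 0.52,
    'loss_weight/task_1': 0.48,
}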
@@ -185,8 +208,33 @@ class MTLTrainer(object):
 
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
         return total_log
 
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
+    def _current_lr(self):
+        """Fetch current learning rate regardless of adaptive method."""
+        if self.adaptive_method == "metabalance":
+            return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
+        if hasattr(self, 'optimizer'):
+            return self.optimizer.param_groups[0]['lr']
+        return None
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
@@ -257,3 +305,100 @@ class MTLTrainer(object):
 
         exporter = ONNXExporter(model, device=export_device)
         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
+        >>> trainer.fit(train_dl, val_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                " - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                " - macOS: brew install graphviz\n"
+                " - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )
@@ -46,7 +46,22 @@ class SeqTrainer(object):
     ... )
     """
 
-    def __init__(self, model, optimizer_fn=torch.optim.Adam, optimizer_params=None, scheduler_fn=None, scheduler_params=None, n_epoch=10, earlystop_patience=10, device='cpu', gpus=None, model_path='./', loss_type='cross_entropy', loss_params=None):
+    def __init__(
+        self,
+        model,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        n_epoch=10,
+        earlystop_patience=10,
+        device='cpu',
+        gpus=None,
+        model_path='./',
+        loss_type='cross_entropy',
+        loss_params=None,
+        model_logger=None
+    ):
         self.model = model # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
             gpus = []
@@ -74,9 +89,11 @@ class SeqTrainer(object):
             loss_params = {"ignore_index": 0}
         self.loss_fn = nn.CrossEntropyLoss(**loss_params)
 
+        self.loss_type = loss_type
         self.n_epoch = n_epoch
         self.early_stopper = EarlyStopper(patience=earlystop_patience)
         self.model_path = model_path
+        self.model_logger = model_logger
 
     def fit(self, train_dataloader, val_dataloader=None):
         """Train the model.
@@ -90,10 +107,18 @@ class SeqTrainer(object):
         """
         history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
 
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
             # Training phase
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+            history['train_loss'].append(train_loss)
+
+            # Collect metrics
+            logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -105,6 +130,10 @@ class SeqTrainer(object):
                 history['val_loss'].append(val_loss)
                 history['val_accuracy'].append(val_accuracy)
 
+                logs['val/loss'] = val_loss
+                logs['val/accuracy'] = val_accuracy
+                logs['auc'] = val_accuracy # For compatibility with EarlyStopper
+
                 print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
 
                 # Early stopping
@@ -113,9 +142,30 @@ class SeqTrainer(object):
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
 
+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) # save best model
+
+        for logger in self._iter_loggers():
+            logger.finish()
+
         return history
 
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def train_one_epoch(self, data_loader, log_interval=10):
         """Train the model for a single epoch.
 
@@ -128,6 +178,8 @@ class SeqTrainer(object):
         """
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
             # Move tensors to the target device
@@ -152,10 +204,15 @@ class SeqTrainer(object):
             self.optimizer.step()
 
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def evaluate(self, data_loader):
         """Evaluate the model on a validation/test data loader.
 
@@ -291,3 +348,137 @@ class SeqTrainer(object):
         except Exception as e:
             warnings.warn(f"ONNX export failed: {str(e)}")
             raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
+
+    def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the sequence model
+        architecture, showing layer connections, tensor shapes, and nested
+        module structures.
+
+        Parameters
+        ----------
+        seq_length : int, default=50
+            Sequence length for dummy input.
+        vocab_size : int, optional
+            Vocabulary size for generating dummy tokens.
+            If None, will try to get from model.vocab_size or model.item_num.
+        batch_size : int, default=2
+            Batch size for dummy input.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+        ValueError
+            If vocab_size is not provided and cannot be inferred from model.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = SeqTrainer(hstu_model, ...)
+        >>> trainer.fit(train_dl, val_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4, vocab_size=10000)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        try:
+            from torchview import draw_graph
+            TORCHVIEW_AVAILABLE = True
+        except ImportError:
+            TORCHVIEW_AVAILABLE = False
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                " - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                " - macOS: brew install graphviz\n"
+                " - Windows: choco install graphviz"
+            )
+
+        from ..utils.visualization import _is_jupyter_environment, display_graph
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        # Get vocab_size from model if not provided
+        if vocab_size is None:
+            if hasattr(model, 'vocab_size'):
+                vocab_size = model.vocab_size
+            elif hasattr(model, 'item_num'):
+                vocab_size = model.item_num
+            else:
+                raise ValueError("vocab_size must be provided or model must have "
+                                 "'vocab_size' or 'item_num' attribute")
+
+        # Generate dummy inputs for sequence model
+        dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
+        dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)
+
+        # Move model to device
+        model = model.to(viz_device)
+        model.eval()
+
+        # Call torchview.draw_graph
+        graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)
+
+        # Set DPI for high-quality output
+        graph.visual_graph.graph_attr['dpi'] = str(dpi)
+
+        # Handle save_path: manually save with DPI applied
+        if save_path:
+            import os
+            directory = os.path.dirname(save_path) or "."
+            filename = os.path.splitext(os.path.basename(save_path))[0]
+            ext = os.path.splitext(save_path)[1].lstrip('.')
+            output_format = ext if ext else 'pdf'
+            if directory != "." and not os.path.exists(directory):
+                os.makedirs(directory, exist_ok=True)
+            graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
+
+        # Handle default display behavior when save_path is None
+        if save_path is None:
+            if _is_jupyter_environment():
+                display_graph(graph)
+            else:
+                graph.visual_graph.view(cleanup=True)
+
+        return graph
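As a worked example of the save_path handling above (the path figs/hstu.svg is hypothetical), the os.path calls decompose it as follows before graphviz renders the file:

import os

save_path = "figs/hstu.svg"                                   # hypothetical path
directory = os.path.dirname(save_path) or "."                 # -> "figs"
filename = os.path.splitext(os.path.basename(save_path))[0]   # -> "hstu"
ext = os.path.splitext(save_path)[1].lstrip('.')              # -> "svg"
output_format = ext if ext else 'pdf'                         # extension-less paths fall back to PDF
# graphviz then writes figs/hstu.svg via
# render(filename="hstu", directory="figs", format="svg", cleanup=True)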