torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,233 @@
+ """Common model utility functions for Torch-RecHub.
+
+ This module provides shared utilities for model introspection and input generation,
+ used by both ONNX export and visualization features.
+
+ Examples
+ --------
+ >>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
+ >>> feature_info = extract_feature_info(model)
+ >>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
+ """
+
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ # Import feature types for type checking
+ try:
+     from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
+ except ImportError:
+     # Fallback for standalone usage
+     SparseFeature = None
+     DenseFeature = None
+     SequenceFeature = None
+
+
+ def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
+     """Extract feature information from a torch-rechub model using reflection.
+
+     This function inspects model attributes to find feature lists without
+     modifying the model code. Supports various model architectures.
+
+     Parameters
+     ----------
+     model : nn.Module
+         The recommendation model to inspect.
+
+     Returns
+     -------
+     dict
+         Dictionary containing:
+
+         - 'features': List of unique Feature objects
+         - 'input_names': List of feature names in order
+         - 'input_types': Dict mapping feature name to feature type
+         - 'user_features': List of user-side features (for dual-tower models)
+         - 'item_features': List of item-side features (for dual-tower models)
+
+     Examples
+     --------
+     >>> from torch_rechub.models.ranking import DeepFM
+     >>> model = DeepFM(deep_features, fm_features, mlp_params)
+     >>> info = extract_feature_info(model)
+     >>> print(info['input_names'])  # ['user_id', 'item_id', ...]
+     """
+     # Common feature attribute names across different model types
+     feature_attrs = [
+         'features',  # MMOE, DCN, etc.
+         'deep_features',  # DeepFM, WideDeep
+         'fm_features',  # DeepFM
+         'wide_features',  # WideDeep
+         'linear_features',  # DeepFFM
+         'cross_features',  # DeepFFM
+         'user_features',  # DSSM, YoutubeDNN, MIND
+         'item_features',  # DSSM, YoutubeDNN, MIND
+         'history_features',  # DIN, MIND
+         'target_features',  # DIN
+         'neg_item_feature',  # YoutubeDNN, MIND
+     ]
+
+     all_features = []
+     user_features = []
+     item_features = []
+
+     for attr in feature_attrs:
+         if hasattr(model, attr):
+             feat_list = getattr(model, attr)
+             if isinstance(feat_list, list) and len(feat_list) > 0:
+                 all_features.extend(feat_list)
+                 # Track user/item features for dual-tower models
+                 if 'user' in attr or 'history' in attr:
+                     user_features.extend(feat_list)
+                 elif 'item' in attr:
+                     item_features.extend(feat_list)
+
+     # Deduplicate features by name while preserving order
+     seen = set()
+     unique_features = []
+     for f in all_features:
+         if hasattr(f, 'name') and f.name not in seen:
+             seen.add(f.name)
+             unique_features.append(f)
+
+     # Deduplicate user/item features the same way. Note: set.add() returns
+     # None, so `not seen.add(name)` is always True and just records the name.
+     seen_user = set()
+     unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
+     seen_item = set()
+     unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
+
+     # Build input names and types
+     input_names = [f.name for f in unique_features if hasattr(f, 'name')]
+     input_types = {f.name: type(f).__name__ for f in unique_features if hasattr(f, 'name')}
+
+     return {
+         'features': unique_features,
+         'input_names': input_names,
+         'input_types': input_types,
+         'user_features': unique_user,
+         'item_features': unique_item,
+     }
+
+
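For dual-tower models, the split lists are often more useful than the flat 'features' list, since each tower can be exported or visualized on its own. A brief illustration (hedged: `dssm_model` is a stand-in for any retrieval model with user/item feature attributes):

>>> info = extract_feature_info(dssm_model)
>>> [f.name for f in info['user_features']]  # inputs feeding the user tower
>>> [f.name for f in info['item_features']]  # inputs feeding the item tower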
+ def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
+     """Generate dummy input tensors based on feature definitions.
+
+     Parameters
+     ----------
+     features : list
+         List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
+     batch_size : int, default=2
+         Batch size for dummy input.
+     seq_length : int, default=10
+         Sequence length for SequenceFeature.
+     device : str, default='cpu'
+         Device to create tensors on.
+
+     Returns
+     -------
+     tuple of Tensor
+         Tuple of tensors in the order of input features.
+
+     Examples
+     --------
+     >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
+     >>> dummy = generate_dummy_input(features, batch_size=4)
+     >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
+     """
+     # Dynamic import to handle feature types
+     from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
+
+     inputs = []
+     for feat in features:
+         if isinstance(feat, SequenceFeature):
+             # Sequence features have shape [batch_size, seq_length]
+             tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
+         elif isinstance(feat, SparseFeature):
+             # Sparse features have shape [batch_size]
+             tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
+         elif isinstance(feat, DenseFeature):
+             # Dense features always have shape [batch_size, embed_dim]
+             tensor = torch.randn(batch_size, feat.embed_dim, device=device)
+         else:
+             raise TypeError(f"Unsupported feature type: {type(feat)}")
+         inputs.append(tensor)
+     return tuple(inputs)
+
+
+ def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
+     """Generate a dummy input dict based on feature definitions.
+
+     Similar to generate_dummy_input but returns a dict mapping feature names
+     to tensors, which is the input format torch-rechub models expect.
+
+     Parameters
+     ----------
+     features : list
+         List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
+     batch_size : int, default=2
+         Batch size for dummy input.
+     seq_length : int, default=10
+         Sequence length for SequenceFeature.
+     device : str, default='cpu'
+         Device to create tensors on.
+
+     Returns
+     -------
+     dict
+         Dict mapping feature names to tensors.
+
+     Examples
+     --------
+     >>> features = [SparseFeature("user_id", 1000)]
+     >>> dummy = generate_dummy_input_dict(features, batch_size=4)
+     >>> # Returns {"user_id": tensor[4]}
+     """
+     dummy_tuple = generate_dummy_input(features, batch_size, seq_length, device)
+     input_names = [f.name for f in features if hasattr(f, 'name')]
+     return {name: tensor for name, tensor in zip(input_names, dummy_tuple)}
+
+
+ def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
+     """Generate dynamic axes configuration for ONNX export.
+
+     Parameters
+     ----------
+     input_names : list of str
+         List of input tensor names.
+     output_names : list of str, optional
+         List of output tensor names. Default is ["output"].
+     batch_dim : int, default=0
+         Dimension index for batch size.
+     include_seq_dim : bool, default=True
+         Whether to include the sequence dimension as dynamic.
+     seq_features : list of str, optional
+         List of feature names that are sequences.
+
+     Returns
+     -------
+     dict
+         Dynamic axes dict for torch.onnx.export.
+
+     Examples
+     --------
+     >>> axes = generate_dynamic_axes(["user_id", "item_id"], seq_features=["hist"])
+     >>> # Returns {"user_id": {0: "batch_size"}, "item_id": {0: "batch_size"}, ...}
+     """
+     if output_names is None:
+         output_names = ["output"]
+
+     dynamic_axes = {}
+
+     # Input axes
+     for name in input_names:
+         dynamic_axes[name] = {batch_dim: "batch_size"}
+         # Add sequence dimension for sequence features
+         if include_seq_dim and seq_features and name in seq_features:
+             dynamic_axes[name][1] = "seq_length"
+
+     # Output axes
+     for name in output_names:
+         dynamic_axes[name] = {batch_dim: "batch_size"}
+
+     return dynamic_axes
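Together, these helpers cover the input side of an ONNX export. A minimal sketch of how they might feed torch.onnx.export (an assumption for illustration: `model` here must accept positional tensor inputs, which in the package itself is arranged by the ONNXWrapper shown in the next hunk):

>>> info = extract_feature_info(model)
>>> dummy = generate_dummy_input(info['features'], batch_size=2)
>>> seq_names = [n for n, t in info['input_types'].items() if t == 'SequenceFeature']
>>> axes = generate_dynamic_axes(info['input_names'], seq_features=seq_names)
>>> torch.onnx.export(model, dummy, "model.onnx", input_names=info['input_names'], output_names=["output"], dynamic_axes=axes)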
@@ -62,142 +62,9 @@ class ONNXWrapper(nn.Module):
          self.model.mode = self._original_mode


- def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
-     """Extract feature information from a model using reflection.
-
-     This function inspects model attributes to find feature lists without
-     modifying the model code. Supports various model architectures.
-
-     Args:
-         model: The recommendation model to inspect.
-
-     Returns:
-         Dict containing:
-         - 'features': List of unique Feature objects
-         - 'input_names': List of feature names in order
-         - 'input_types': Dict mapping feature name to feature type
-         - 'user_features': List of user-side features (for dual-tower models)
-         - 'item_features': List of item-side features (for dual-tower models)
-     """
-     # Common feature attribute names across different model types
-     feature_attrs = [
-         'features',  # MMOE, DCN, etc.
-         'deep_features',  # DeepFM, WideDeep
-         'fm_features',  # DeepFM
-         'wide_features',  # WideDeep
-         'linear_features',  # DeepFFM
-         'cross_features',  # DeepFFM
-         'user_features',  # DSSM, YoutubeDNN, MIND
-         'item_features',  # DSSM, YoutubeDNN, MIND
-         'history_features',  # DIN, MIND
-         'target_features',  # DIN
-         'neg_item_feature',  # YoutubeDNN, MIND
-     ]
-
-     all_features = []
-     user_features = []
-     item_features = []
-
-     for attr in feature_attrs:
-         if hasattr(model, attr):
-             feat_list = getattr(model, attr)
-             if isinstance(feat_list, list) and len(feat_list) > 0:
-                 all_features.extend(feat_list)
-                 # Track user/item features for dual-tower models
-                 if 'user' in attr or 'history' in attr:
-                     user_features.extend(feat_list)
-                 elif 'item' in attr:
-                     item_features.extend(feat_list)
-
-     # Deduplicate features by name while preserving order
-     seen = set()
-     unique_features = []
-     for f in all_features:
-         if hasattr(f, 'name') and f.name not in seen:
-             seen.add(f.name)
-             unique_features.append(f)
-
-     # Deduplicate user/item features
-     seen_user = set()
-     unique_user = [f for f in user_features if hasattr(f, 'name') and f.name not in seen_user and not seen_user.add(f.name)]
-     seen_item = set()
-     unique_item = [f for f in item_features if hasattr(f, 'name') and f.name not in seen_item and not seen_item.add(f.name)]
-
-     return {
-         'features': unique_features,
-         'input_names': [f.name for f in unique_features],
-         'input_types': {
-             f.name: type(f).__name__ for f in unique_features
-         },
-         'user_features': unique_user,
-         'item_features': unique_item,
-     }
-
-
- def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
-     """Generate dummy input tensors for ONNX export based on feature definitions.
-
-     Args:
-         features: List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
-         batch_size: Batch size for dummy input (default: 2).
-         seq_length: Sequence length for SequenceFeature (default: 10).
-         device: Device to create tensors on (default: 'cpu').
-
-     Returns:
-         Tuple of tensors in the order of input features.
-
-     Example:
-         >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
-         >>> dummy = generate_dummy_input(features, batch_size=4)
-         >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
-     """
-     inputs = []
-     for feat in features:
-         if isinstance(feat, SequenceFeature):
-             # Sequence features have shape [batch_size, seq_length]
-             tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
-         elif isinstance(feat, SparseFeature):
-             # Sparse features have shape [batch_size]
-             tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
-         elif isinstance(feat, DenseFeature):
-             # Dense features have shape [batch_size, embed_dim]
-             tensor = torch.randn(batch_size, feat.embed_dim, device=device)
-         else:
-             raise TypeError(f"Unsupported feature type: {type(feat)}")
-         inputs.append(tensor)
-     return tuple(inputs)
-
-
- def generate_dynamic_axes(input_names: List[str], output_names: List[str] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: List[str] = None) -> Dict[str, Dict[int, str]]:
-     """Generate dynamic axes configuration for ONNX export.
-
-     Args:
-         input_names: List of input tensor names.
-         output_names: List of output tensor names (default: ["output"]).
-         batch_dim: Dimension index for batch size (default: 0).
-         include_seq_dim: Whether to include sequence dimension as dynamic (default: True).
-         seq_features: List of feature names that are sequences (default: auto-detect).
-
-     Returns:
-         Dynamic axes dict for torch.onnx.export.
-     """
-     if output_names is None:
-         output_names = ["output"]
-
-     dynamic_axes = {}
-
-     # Input axes
-     for name in input_names:
-         dynamic_axes[name] = {batch_dim: "batch_size"}
-         # Add sequence dimension for sequence features
-         if include_seq_dim and seq_features and name in seq_features:
-             dynamic_axes[name][1] = "seq_length"
-
-     # Output axes
-     for name in output_names:
-         dynamic_axes[name] = {batch_dim: "batch_size"}
-
-     return dynamic_axes
+ # Re-export from model_utils for backward compatibility
+ # The actual implementations are now in model_utils.py
+ from .model_utils import extract_feature_info, generate_dummy_input, generate_dummy_input_dict, generate_dynamic_axes


  class ONNXExporter:
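The re-export keeps old import sites working while the implementations move to model_utils.py. A sanity check might look like the following (hedged: the diff does not show this file's path, so torch_rechub.utils.onnx_export is an assumed module name):

>>> from torch_rechub.utils.model_utils import generate_dynamic_axes  # new canonical location
>>> from torch_rechub.utils.onnx_export import generate_dynamic_axes  # assumed old path; still resolves via the re-export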
@@ -0,0 +1,271 @@
+ """
+ Model Visualization Utilities for Torch-RecHub.
+
+ This module provides model structure visualization using the torchview library.
+ Requires optional dependencies: pip install torch-rechub[visualization]
+
+ Example:
+     >>> from torch_rechub.utils.visualization import visualize_model, display_graph
+     >>> graph = visualize_model(model, depth=4)
+     >>> display_graph(graph)  # Display in a Jupyter notebook
+
+     >>> # Save to file
+     >>> visualize_model(model, save_path="model_arch.pdf")
+ """
+
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+
+ # Check for optional dependencies
+ TORCHVIEW_AVAILABLE = False
+ TORCHVIEW_SKIP_REASON = "torchview not installed"
+
+ try:
+     from torchview import draw_graph
+     TORCHVIEW_AVAILABLE = True
+ except ImportError as e:
+     TORCHVIEW_SKIP_REASON = f"torchview not available: {e}"
+
+
+ def _is_jupyter_environment() -> bool:
+     """Check if running in a Jupyter/IPython environment."""
+     try:
+         from IPython import get_ipython
+         shell = get_ipython()
+         if shell is None:
+             return False
+         # ZMQInteractiveShell covers Jupyter notebook/qtconsole;
+         # TerminalInteractiveShell covers IPython in a terminal
+         shell_class = shell.__class__.__name__
+         return shell_class in ('ZMQInteractiveShell', 'TerminalInteractiveShell')
+     except (ImportError, NameError):
+         return False
+
+
+ def display_graph(graph: Any, format: str = 'png') -> Any:
+     """Display a torchview ComputationGraph in a Jupyter notebook.
+
+     This function provides a reliable way to display visualization graphs
+     in Jupyter environments, especially VSCode Jupyter.
+
+     Parameters
+     ----------
+     graph : ComputationGraph
+         A torchview ComputationGraph object returned by visualize_model().
+     format : str, default='png'
+         Output format; 'png' is recommended for VSCode.
+
+     Returns
+     -------
+     graphviz.Digraph or None
+         The displayed graph object, or None if display fails.
+
+     Examples
+     --------
+     >>> graph = visualize_model(model, depth=4)
+     >>> display_graph(graph)  # Works in VSCode Jupyter
+     """
+     if not TORCHVIEW_AVAILABLE:
+         raise ImportError(f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
+                           "Install with: pip install torch-rechub[visualization]")
+
+     try:
+         import graphviz
+
+         # Set format for VSCode compatibility
+         graphviz.set_jupyter_format(format)
+     except ImportError:
+         pass
+
+     # Get the visual_graph (graphviz.Digraph object)
+     visual = graph.visual_graph
+
+     # Try to use IPython display for explicit rendering
+     try:
+         from IPython.display import display
+         display(visual)
+         return visual
+     except ImportError:
+         # Not in an IPython/Jupyter environment; return the graph as-is
+         return visual
+
+
+ def visualize_model(
+     model: nn.Module,
+     input_data: Optional[Dict[str, torch.Tensor]] = None,
+     batch_size: int = 2,
+     seq_length: int = 10,
+     depth: int = 3,
+     show_shapes: bool = True,
+     expand_nested: bool = True,
+     save_path: Optional[str] = None,
+     graph_name: str = "model",
+     device: str = "cpu",
+     dpi: int = 300,
+     **kwargs
+ ) -> Any:
+     """Visualize a Torch-RecHub model's computation graph.
+
+     This function generates a visual representation of the model architecture,
+     showing layer connections, tensor shapes, and nested module structures.
+     It automatically extracts feature information from the model to generate
+     appropriate dummy inputs.
+
+     Parameters
+     ----------
+     model : nn.Module
+         PyTorch model to visualize. Should be a Torch-RecHub model
+         with feature attributes (e.g., DeepFM, DSSM, MMOE).
+     input_data : dict, optional
+         Dict of example inputs {feature_name: tensor}.
+         If None, inputs are auto-generated based on the model's features.
+     batch_size : int, default=2
+         Batch size for auto-generated inputs.
+     seq_length : int, default=10
+         Sequence length for SequenceFeature inputs.
+     depth : int, default=3
+         Visualization depth; higher values show more detail.
+         Set to -1 to show all layers.
+     show_shapes : bool, default=True
+         Whether to display tensor shapes on edges.
+     expand_nested : bool, default=True
+         Whether to expand nested nn.Modules with dashed borders.
+     save_path : str, optional
+         Path to save the graph image. Supports .pdf, .svg, and .png formats.
+         If None, displays in Jupyter or opens the system viewer.
+     graph_name : str, default="model"
+         Name for the computation graph.
+     device : str, default="cpu"
+         Device for model execution during tracing.
+     dpi : int, default=300
+         Resolution of the output image in dots per inch.
+         Higher values produce sharper images suitable for papers.
+     **kwargs : dict
+         Additional arguments passed to torchview.draw_graph().
+
+     Returns
+     -------
+     ComputationGraph
+         A torchview ComputationGraph object.
+
+         - Use the `.visual_graph` property to get the graphviz.Digraph
+         - Use `.resize_graph(scale=1.5)` to adjust the graph size
+
+     Raises
+     ------
+     ImportError
+         If torchview or graphviz is not installed.
+     ValueError
+         If the model has no recognizable feature attributes.
+
+     Notes
+     -----
+     Default display behavior when `save_path` is None:
+
+     - In Jupyter/IPython: automatically displays the graph inline
+     - In a Python script: opens the graph with the system default viewer
+
+     Requires the graphviz system package: apt/brew/choco install graphviz.
+     For Jupyter display issues, try: graphviz.set_jupyter_format('png').
+
+     Examples
+     --------
+     >>> from torch_rechub.models.ranking import DeepFM
+     >>> from torch_rechub.utils.visualization import visualize_model
+     >>>
+     >>> # Auto-display in Jupyter or open in a viewer
+     >>> visualize_model(model, depth=4)  # No save_path needed
+     >>>
+     >>> # Save to a high-DPI PNG for a paper
+     >>> visualize_model(model, save_path="model.png", dpi=300)
+     """
+     if not TORCHVIEW_AVAILABLE:
+         raise ImportError(
+             f"Visualization requires torchview. {TORCHVIEW_SKIP_REASON}\n"
+             "Install with: pip install torch-rechub[visualization]\n"
+             "Also ensure graphviz is installed on your system:\n"
+             "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+             "  - macOS: brew install graphviz\n"
+             "  - Windows: choco install graphviz"
+         )
+
+     # Import feature extraction utilities from model_utils
+     from .model_utils import extract_feature_info, generate_dummy_input_dict
+
+     model.eval()
+     model.to(device)
+
+     # Auto-generate input data if not provided
+     if input_data is None:
+         feature_info = extract_feature_info(model)
+         features = feature_info['features']
+
+         if not features:
+             raise ValueError("Could not extract feature information from the model. "
+                              "Please provide the input_data parameter manually.")
+
+         # Generate dummy input dict
+         input_data = generate_dummy_input_dict(features, batch_size=batch_size, seq_length=seq_length, device=device)
+     else:
+         # Ensure input tensors are on the correct device
+         input_data = {k: v.to(device) for k, v in input_data.items()}
+
+     # IMPORTANT: Wrap the input_data dict in a tuple to work around torchview's behavior.
+     #
+     # torchview's forward_prop function dispatches on the type of input_data:
+     #   - If isinstance(x, (list, tuple)): model(*x)
+     #   - If isinstance(x, Mapping): model(**x)  <- unpacks the dict as kwargs!
+     #   - Else: model(x)
+     #
+     # torch-rechub models expect forward(self, x) where x is the complete dict.
+     # By wrapping the dict in a tuple, torchview calls
+     # model(*(input_dict,)) == model(input_dict), which is exactly what our models expect.
+     input_data_wrapped = (input_data,)
+
+     # Call torchview.draw_graph without saving (we save manually so the DPI setting applies)
+     graph = draw_graph(
+         model,
+         input_data=input_data_wrapped,
+         graph_name=graph_name,
+         depth=depth,
+         device=device,
+         expand_nested=expand_nested,
+         show_shapes=show_shapes,
+         save_graph=False,  # Don't save here; we save manually with DPI
+         **kwargs
+     )
+
+     # Set DPI for high-quality output (must be set BEFORE rendering/saving)
+     graph.visual_graph.graph_attr['dpi'] = str(dpi)
+
+     # Handle save_path: save manually with the DPI setting applied
+     if save_path:
+         import os
+         directory = os.path.dirname(save_path) or "."
+         filename = os.path.splitext(os.path.basename(save_path))[0]
+         ext = os.path.splitext(save_path)[1].lstrip('.')
+         # Default to pdf if no extension is given
+         output_format = ext if ext else 'pdf'
+         # Create the directory if it doesn't exist
+         if directory != "." and not os.path.exists(directory):
+             os.makedirs(directory, exist_ok=True)
+         # Render and save with the DPI setting applied
+         graph.visual_graph.render(
+             filename=filename,
+             directory=directory,
+             format=output_format,
+             cleanup=True  # Remove the intermediate .gv file
+         )
+
+     # Handle the default display behavior when save_path is None
+     if save_path is None:
+         if _is_jupyter_environment():
+             # In Jupyter: display inline
+             display_graph(graph)
+         else:
+             # In a script: open with the system viewer
+             graph.visual_graph.view(cleanup=True)
+
+     return graph
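When extract_feature_info cannot find feature attributes (the ValueError case above), input_data can be supplied by hand. A minimal sketch with invented feature names and vocabulary sizes, purely for illustration:

>>> import torch
>>> manual_inputs = {
...     'user_id': torch.randint(0, 1000, (2,)),       # sparse feature: [batch_size]
...     'hist_item': torch.randint(0, 5000, (2, 10)),  # sequence feature: [batch_size, seq_length]
... }
>>> visualize_model(model, input_data=manual_inputs, depth=2)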