torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/trainers/match_trainer.py
@@ -1,170 +1,336 @@
-import os
-import torch
-import tqdm
-from sklearn.metrics import roc_auc_score
-from ..basic.callback import EarlyStopper
-from ..basic.loss_func import BPRLoss
-
-
-class MatchTrainer(object):
-    """A general trainer for Matching/Retrieval
-
-    Args:
-        model (nn.Module): any matching model.
-        mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
-        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
-        optimizer_params (dict): parameters of optimizer_fn.
-        scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
-        scheduler_params (dict): parameters of optimizer scheduler_fn.
-        n_epoch (int): epoch number of training.
-        earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
-        device (str): `"cpu"` or `"cuda:0"`
-        gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
-        model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
-    """
-
-    def __init__(
-        self,
-        model,
-        mode=0,
-        optimizer_fn=torch.optim.Adam,
-        optimizer_params=None,
-        scheduler_fn=None,
-        scheduler_params=None,
-        n_epoch=10,
-        earlystop_patience=10,
-        device="cpu",
-        gpus=None,
-        model_path="./",
-    ):
-        self.model = model  # for uniform weights save method in one gpu or multi gpu
-        if gpus is None:
-            gpus = []
-        self.gpus = gpus
-        if len(gpus) > 1:
-            print('parallel running on these gpus:', gpus)
-            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
-        self.device = torch.device(device)  #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-        if optimizer_params is None:
-            optimizer_params = {
-                "lr": 1e-3,
-                "weight_decay": 1e-5
-            }
-        self.mode = mode
-        if mode == 0:  #point-wise loss, binary cross_entropy
-            self.criterion = torch.nn.BCELoss()  #default loss binary cross_entropy
-        elif mode == 1:  #pair-wise loss
-            self.criterion = BPRLoss()
-        elif mode == 2:  #list-wise loss, softmax
-            self.criterion = torch.nn.CrossEntropyLoss()
-        else:
-            raise ValueError("mode only contain value in %s, but got %s" % ([0, 1, 2], mode))
-        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  #default optimizer
-        self.scheduler = None
-        if scheduler_fn is not None:
-            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
-        self.evaluate_fn = roc_auc_score  #default evaluate function
-        self.n_epoch = n_epoch
-        self.early_stopper = EarlyStopper(patience=earlystop_patience)
-        self.model_path = model_path
-
-    def train_one_epoch(self, data_loader, log_interval=10):
-        self.model.train()
-        total_loss = 0
-        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
-        for i, (x_dict, y) in enumerate(tk0):
-            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  #tensor to GPU
-            y = y.to(self.device)
-            if self.mode == 0:
-                y = y.float()  #torch._C._nn.binary_cross_entropy expected Float
-            else:
-                y = y.long()  #
-            if self.mode == 1:  #pair_wise
-                pos_score, neg_score = self.model(x_dict)
-                loss = self.criterion(pos_score, neg_score)
-            else:
-                y_pred = self.model(x_dict)
-                loss = self.criterion(y_pred, y)
-            # used for debug
-            # if i == 0:
-            #     print()
-            #     if self.mode == 0:
-            #         print('pred: ', [f'{float(each):5.2g}' for each in y_pred.detach().cpu().tolist()])
-            #         print('truth:', [f'{float(each):5.2g}' for each in y.detach().cpu().tolist()])
-            #     elif self.mode == 2:
-            #         pred = y_pred.detach().cpu().mean(0)
-            #         pred = torch.softmax(pred, dim=0).tolist()
-            #         print('pred: ', [f'{float(each):4.2g}' for each in pred])
-            #     elif self.mode == 1:
-            #         print('pos:', [f'{float(each):5.2g}' for each in pos_score.detach().cpu().tolist()])
-            #         print('neg: ', [f'{float(each):5.2g}' for each in neg_score.detach().cpu().tolist()])
-
-            self.model.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-            total_loss += loss.item()
-            if (i + 1) % log_interval == 0:
-                tk0.set_postfix(loss=total_loss / log_interval)
-                total_loss = 0
-
-    def fit(self, train_dataloader, val_dataloader=None):
-        for epoch_i in range(self.n_epoch):
-            print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
-            if self.scheduler is not None:
-                if epoch_i % self.scheduler.step_size == 0:
-                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
-                self.scheduler.step()  #update lr in epoch level by scheduler
-
-            if val_dataloader:
-                auc = self.evaluate(self.model, val_dataloader)
-                print('epoch:', epoch_i, 'validation: auc:', auc)
-                if self.early_stopper.stop_training(auc, self.model.state_dict()):
-                    print(f'validation: best auc: {self.early_stopper.best_auc}')
-                    self.model.load_state_dict(self.early_stopper.best_weights)
-                    break
-        torch.save(self.model.state_dict(), os.path.join(self.model_path,
-                                                         "model.pth"))  #save best auc model
-
-
-    def evaluate(self, model, data_loader):
-        model.eval()
-        targets, predicts = list(), list()
-        with torch.no_grad():
-            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
-            for i, (x_dict, y) in enumerate(tk0):
-                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                y = y.to(self.device)
-                y_pred = model(x_dict)
-                targets.extend(y.tolist())
-                predicts.extend(y_pred.tolist())
-        return self.evaluate_fn(targets, predicts)
-
-    def predict(self, model, data_loader):
-        model.eval()
-        predicts = list()
-        with torch.no_grad():
-            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
-            for i, (x_dict, y) in enumerate(tk0):
-                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                y = y.to(self.device)
-                y_pred = model(x_dict)
-                predicts.extend(y_pred.tolist())
-        return predicts
-
-    def inference_embedding(self, model, mode, data_loader, model_path):
-        #inference
-        assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
-        model.mode = mode
-        model.load_state_dict(torch.load(os.path.join(model_path, "model.pth")))
-        model = model.to(self.device)
-        model.eval()
-        predicts = []
-        with torch.no_grad():
-            tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
-            for i, x_dict in enumerate(tk0):
-                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                y_pred = model(x_dict)
-                predicts.append(y_pred.data)
-        return torch.cat(predicts, dim=0)
+import os
+
+import torch
+import tqdm
+from sklearn.metrics import roc_auc_score
+
+from ..basic.callback import EarlyStopper
+from ..basic.loss_func import BPRLoss, RegularizationLoss
+
+
+class MatchTrainer(object):
+    """A general trainer for Matching/Retrieval.
+
+    Args:
+        model (nn.Module): any matching model.
+        mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
+        optimizer_fn (torch.optim): optimizer class of pytorch (default = `torch.optim.Adam`).
+        optimizer_params (dict): parameters of optimizer_fn.
+        regularization_params (dict): parameters of RegularizationLoss, e.g. `{"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}`.
+        scheduler_fn (torch.optim.lr_scheduler): torch scheduler class, e.g. `torch.optim.lr_scheduler.StepLR`.
+        scheduler_params (dict): parameters of scheduler_fn.
+        n_epoch (int): number of training epochs.
+        earlystop_patience (int): how many epochs to wait after the last validation auc improvement (default=10).
+        device (str): `"cpu"` or `"cuda:0"`.
+        gpus (list): ids of multiple gpus (default=[]). If the length is >= 1, the model will be wrapped by nn.DataParallel.
+        model_path (str): the directory to save the model in (default="./"). Note that only the best weights on the validation data are saved.
+    """
+
+    def __init__(
+        self,
+        model,
+        mode=0,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        regularization_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        n_epoch=10,
+        earlystop_patience=10,
+        device="cpu",
+        gpus=None,
+        model_path="./",
+    ):
+        self.model = model  # keep a handle for a uniform weight-saving path on one gpu or multi gpu
+        if gpus is None:
+            gpus = []
+        self.gpus = gpus
+        if len(gpus) > 1:
+            print('parallel running on these gpus:', gpus)
+            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+        # e.g. torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(device)
+        self.model.to(self.device)
+        if optimizer_params is None:
+            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+        if regularization_params is None:
+            regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
+        self.mode = mode
+        if mode == 0:  # point-wise loss, binary cross-entropy
+            self.criterion = torch.nn.BCELoss()  # default loss: binary cross-entropy
+        elif mode == 1:  # pair-wise loss
+            self.criterion = BPRLoss()
+        elif mode == 2:  # list-wise loss, softmax
+            self.criterion = torch.nn.CrossEntropyLoss()
+        else:
+            raise ValueError("mode must be one of %s, but got %s" % ([0, 1, 2], mode))
+        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+        self.scheduler = None
+        if scheduler_fn is not None:
+            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+        self.evaluate_fn = roc_auc_score  # default evaluate function
+        self.n_epoch = n_epoch
+        self.early_stopper = EarlyStopper(patience=earlystop_patience)
+        self.model_path = model_path
+        # Initialize regularization loss
+        self.reg_loss_fn = RegularizationLoss(**regularization_params)
+
+    def train_one_epoch(self, data_loader, log_interval=10):
+        self.model.train()
+        total_loss = 0
+        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+        for i, (x_dict, y) in enumerate(tk0):
+            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # move tensors to device
+            y = y.to(self.device)
+            if self.mode == 0:
+                y = y.float()  # torch._C._nn.binary_cross_entropy expects Float targets
+            else:
+                y = y.long()  # list-wise CrossEntropyLoss expects Long targets
+            if self.mode == 1:  # pair-wise
+                pos_score, neg_score = self.model(x_dict)
+                loss = self.criterion(pos_score, neg_score)
+            else:
+                y_pred = self.model(x_dict)
+                loss = self.criterion(y_pred, y)
+
+            # Add regularization loss
+            reg_loss = self.reg_loss_fn(self.model)
+            loss = loss + reg_loss
+
+            # used for debug
+            # if i == 0:
+            #     print()
+            #     if self.mode == 0:
+            #         print('pred: ', [f'{float(each):5.2g}' for each in y_pred.detach().cpu().tolist()])
+            #         print('truth:', [f'{float(each):5.2g}' for each in y.detach().cpu().tolist()])
+            #     elif self.mode == 2:
+            #         pred = y_pred.detach().cpu().mean(0)
+            #         pred = torch.softmax(pred, dim=0).tolist()
+            #         print('pred: ', [f'{float(each):4.2g}' for each in pred])
+            #     elif self.mode == 1:
+            #         print('pos:', [f'{float(each):5.2g}' for each in pos_score.detach().cpu().tolist()])
+            #         print('neg: ', [f'{float(each):5.2g}' for each in neg_score.detach().cpu().tolist()])
+
+            self.model.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+            total_loss += loss.item()
+            if (i + 1) % log_interval == 0:
+                tk0.set_postfix(loss=total_loss / log_interval)
+                total_loss = 0
+
+    def fit(self, train_dataloader, val_dataloader=None):
+        for epoch_i in range(self.n_epoch):
+            print('epoch:', epoch_i)
+            self.train_one_epoch(train_dataloader)
+            if self.scheduler is not None:
+                if epoch_i % self.scheduler.step_size == 0:
+                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                self.scheduler.step()  # update lr at epoch level by scheduler
+
+            if val_dataloader:
+                auc = self.evaluate(self.model, val_dataloader)
+                print('epoch:', epoch_i, 'validation: auc:', auc)
+                if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                    print(f'validation: best auc: {self.early_stopper.best_auc}')
+                    self.model.load_state_dict(self.early_stopper.best_weights)
+                    break
+        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model
+
+    def evaluate(self, model, data_loader):
+        model.eval()
+        targets, predicts = list(), list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+            for i, (x_dict, y) in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y = y.to(self.device)
+                y_pred = model(x_dict)
+                targets.extend(y.tolist())
+                predicts.extend(y_pred.tolist())
+        return self.evaluate_fn(targets, predicts)
+
+    def predict(self, model, data_loader):
+        model.eval()
+        predicts = list()
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+            for i, (x_dict, y) in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y = y.to(self.device)
+                y_pred = model(x_dict)
+                predicts.extend(y_pred.tolist())
+        return predicts
+
+    def inference_embedding(self, model, mode, data_loader, model_path):
+        # inference
+        assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
+        model.mode = mode
+        model.load_state_dict(torch.load(os.path.join(model_path, "model.pth"), map_location=self.device, weights_only=True))
+        model = model.to(self.device)
+        model.eval()
+        predicts = []
+        with torch.no_grad():
+            tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
+            for i, x_dict in enumerate(tk0):
+                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                y_pred = model(x_dict)
+                predicts.append(y_pred.data)
+        return torch.cat(predicts, dim=0)
+
+    def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+        """Export the trained matching model to ONNX format.
+
+        This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
+        to ONNX format. For dual-tower models, you can export the user tower and item
+        tower separately for efficient online serving.
+
+        Args:
+            output_path (str): Path to save the ONNX model file.
+            mode (str, optional): Export mode for dual-tower models:
+                - "user": Export only the user tower (for user embedding inference)
+                - "item": Export only the item tower (for item embedding inference)
+                - None: Export the full model (default)
+            dummy_input (dict, optional): Example input dict {feature_name: tensor}.
+                If not provided, dummy inputs will be generated automatically.
+            batch_size (int): Batch size for auto-generated dummy input (default: 2).
+            seq_length (int): Sequence length for SequenceFeature (default: 10).
+            opset_version (int): ONNX opset version (default: 14).
+            dynamic_batch (bool): Enable dynamic batch size (default: True).
+            device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                If None, defaults to 'cpu' for maximum compatibility.
+            verbose (bool): Print export details (default: False).
+
+        Returns:
+            bool: True if export succeeded, False otherwise.
+
+        Example:
+            >>> trainer = MatchTrainer(dssm_model, mode=0, ...)
+            >>> trainer.fit(train_dl)
+
+            >>> # Export user tower for user embedding inference
+            >>> trainer.export_onnx("user_tower.onnx", mode="user")
+
+            >>> # Export item tower for item embedding inference
+            >>> trainer.export_onnx("item_tower.onnx", mode="item")
+
+            >>> # Export full model (for online similarity computation)
+            >>> trainer.export_onnx("full_model.onnx")
+
+            >>> # Export on a specific device
+            >>> trainer.export_onnx("user_tower.onnx", mode="user", device="cpu")
+        """
+        from ..utils.onnx_export import ONNXExporter
+
+        # Handle DataParallel-wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Store original mode
+        original_mode = getattr(model, 'mode', None)
+
+        # Use provided device or default to 'cpu'
+        export_device = device if device is not None else 'cpu'
+
+        try:
+            exporter = ONNXExporter(model, device=export_device)
+            return exporter.export(output_path=output_path, mode=mode, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+        finally:
+            # Restore original mode
+            if hasattr(model, 'mode'):
+                model.mode = original_mode
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens the system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for the output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default display behavior when `save_path` is None:
+            - In Jupyter/IPython: automatically displays the graph inline
+            - In a Python script: opens the graph with the system default viewer
+
+        Examples
+        --------
+        >>> trainer = MatchTrainer(model, ...)
+        >>> trainer.fit(train_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                "  - macOS: brew install graphviz\n"
+                "  - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel-wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )
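
For orientation, a minimal usage sketch of the 0.0.5 `MatchTrainer` API as it appears in the hunk above. `build_model_and_loaders` is a hypothetical helper standing in for model and DataLoader construction, and the regularization strengths are illustrative values, not package defaults:

```python
from torch_rechub.trainers import MatchTrainer

# Hypothetical placeholders: any torch-rechub matching model (e.g. DSSM)
# plus point-wise (mode=0) train/validation DataLoaders.
model, train_dl, val_dl = build_model_and_loaders()

trainer = MatchTrainer(
    model,
    mode=0,  # 0: point-wise (BCELoss), 1: pair-wise (BPRLoss), 2: list-wise (CrossEntropyLoss)
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    # New in 0.0.5: L1/L2 penalties forwarded to RegularizationLoss and
    # added onto the training loss every step (values here are illustrative).
    regularization_params={"embedding_l1": 0.0, "embedding_l2": 1e-6, "dense_l1": 0.0, "dense_l2": 1e-5},
    n_epoch=10,
    device="cpu",
    model_path="./",
)
trainer.fit(train_dl, val_dataloader=val_dl)

# Also new in 0.0.5: tower-wise ONNX export and computation-graph visualization.
trainer.export_onnx("user_tower.onnx", mode="user")
trainer.export_onnx("item_tower.onnx", mode="item")
trainer.visualization(save_path="model.png", dpi=300)
```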
torch_rechub/trainers/matching.md
@@ -0,0 +1,3 @@
+# Matching
+
+Retrieval usage documentation (召回使用文档)
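
The `matching.md` stub above anchors the retrieval usage docs. As a hedged sketch of the retrieval step that `inference_embedding` (in the match_trainer.py hunk) enables, continuing from the previous example's `trainer` and `model`: after `fit` has saved `model.pth`, the dual-tower model is switched into "user"/"item" mode to dump embeddings, which are then matched by inner product. `user_dl` and `item_dl` are hypothetical inference DataLoaders, and the brute-force `topk` stands in for a real ANN index (e.g. Annoy or Faiss):

```python
import torch

# Dump embeddings from each tower of the trained dual-tower model.
user_emb = trainer.inference_embedding(model, mode="user", data_loader=user_dl, model_path="./")
item_emb = trainer.inference_embedding(model, mode="item", data_loader=item_dl, model_path="./")

# Exact top-k retrieval by inner product; a production system would swap in
# an approximate-nearest-neighbour index built over item_emb instead.
scores = user_emb @ item_emb.T  # [n_users, n_items]
topk_scores, topk_item_idx = torch.topk(scores, k=10, dim=1)
```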