torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff shows the content changes between these publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/trainers/mtl_trainer.py
@@ -1,207 +1,356 @@
- [0.0.3 version: all 207 lines removed. The left column of the diff viewer survives only as fragments (import os, import numpy as np, import torch, import torch.nn as nn, two truncated relative imports, scattered "self." assignments); the removed text is not recoverable from this extraction.]
import os

import numpy as np
import torch
import torch.nn as nn
import tqdm

from ..basic.callback import EarlyStopper
from ..basic.loss_func import RegularizationLoss
from ..models.multi_task import ESMM
from ..utils.data import get_loss_func, get_metric_func
from ..utils.mtl import MetaBalance, gradnorm, shared_task_layers


class MTLTrainer(object):
    """A trainer for multi-task learning.

    Args:
        model (nn.Module): any multi-task learning model.
        task_types (list): types of tasks, only supports ["classification", "regression"].
        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
        optimizer_params (dict): parameters of optimizer_fn.
        regularization_params (dict): parameters of RegularizationLoss (default = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}).
        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, e.g. `torch.optim.lr_scheduler.StepLR`.
        scheduler_params (dict): parameters of scheduler_fn.
        adaptive_params (dict): parameters of the adaptive loss weighting method; supports `{"method": "uwl"}`, `{"method": "metabalance"}` and `{"method": "gradnorm", "alpha": 0.16}`.
        n_epoch (int): number of training epochs.
        earlystop_taskid (int): id of the task whose validation metric drives early stopping (default = 0).
        earlystop_patience (int): how long to wait after the last improvement of the validation auc (default = 10).
        device (str): `"cpu"` or `"cuda:0"`.
        gpus (list): ids for multi-gpu training (default = []). If the length is > 1, the model will be wrapped by nn.DataParallel.
        model_path (str): path to save the model (default = "./"). Note that only the best weights on the validation data are saved.
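
    Example:
        A minimal sketch; the model and dataloaders are assumed to be built elsewhere:

        >>> trainer = MTLTrainer(model, task_types=["classification", "classification"],
        ...                      adaptive_params={"method": "uwl"}, n_epoch=10, device="cpu")
        >>> trainer.fit(train_dl, val_dl)
        >>> scores = trainer.evaluate(trainer.model, val_dl)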
    """

    def __init__(
        self,
        model,
        task_types,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=None,
        regularization_params=None,
        scheduler_fn=None,
        scheduler_params=None,
        adaptive_params=None,
        n_epoch=10,
        earlystop_taskid=0,
        earlystop_patience=10,
        device="cpu",
        gpus=None,
        model_path="./",
    ):
        self.model = model
        if gpus is None:
            gpus = []
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        if regularization_params is None:
            regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
        self.task_types = task_types
        self.n_task = len(task_types)
        self.loss_weight = None
        self.adaptive_method = None
        if adaptive_params is not None:
            if adaptive_params["method"] == "uwl":
                self.adaptive_method = "uwl"
                self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
                self.model.add_module("loss weight", self.loss_weight)
            elif adaptive_params["method"] == "metabalance":
                self.adaptive_method = "metabalance"
                share_layers, task_layers = shared_task_layers(self.model)
                self.meta_optimizer = MetaBalance(share_layers)
                self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
                self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
            elif adaptive_params["method"] == "gradnorm":
                self.adaptive_method = "gradnorm"
                self.alpha = adaptive_params.get("alpha", 0.16)
                share_layers = shared_task_layers(self.model)[0]
                # GradNorm computes the gradient of each task loss w.r.t. the last
                # fully connected shared layer weight (a parameter with ndim == 2)
                for i in range(len(share_layers)):
                    if share_layers[-i].ndim == 2:
                        self.last_share_layer = share_layers[-i]
                        break
                self.initial_task_loss = None
                self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
                self.model.add_module("loss weight", self.loss_weight)
        if self.adaptive_method != "metabalance":
            self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default Adam optimizer
        self.scheduler = None
        if scheduler_fn is not None:
            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
        self.loss_fns = [get_loss_func(task_type) for task_type in task_types]
        self.evaluate_fns = [get_metric_func(task_type) for task_type in task_types]
        self.n_epoch = n_epoch
        self.earlystop_taskid = earlystop_taskid
        self.early_stopper = EarlyStopper(patience=earlystop_patience)

        self.gpus = gpus
        if len(gpus) > 1:
            print('parallel running on these gpus:', gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
        # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model.to(self.device)
        self.model_path = model_path
        # Initialize regularization loss
        self.reg_loss_fn = RegularizationLoss(**regularization_params)

    def train_one_epoch(self, data_loader):
        self.model.train()
        total_loss = np.zeros(self.n_task)
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for iter_i, (x_dict, ys) in enumerate(tk0):
            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
            ys = ys.to(self.device)
            y_preds = self.model(x_dict)
            loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
            if isinstance(self.model, ESMM):
                # ESMM only computes loss for the ctr and ctcvr tasks
                loss = sum(loss_list[1:])
            else:
                if self.adaptive_method is not None:
                    if self.adaptive_method == "uwl":
                        loss = 0
                        for loss_i, w_i in zip(loss_list, self.loss_weight):
                            w_i = torch.clamp(w_i, min=0)
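                            # uncertainty weighting (Kendall et al., 2018): w_i acts as
                            # log(sigma_i^2), so exp(-w_i) down-weights high-uncertainty
                            # tasks while the trailing + w_i term keeps the weights from
                            # collapsing to the trivial solution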
                            loss += 2 * loss_i * torch.exp(-w_i) + w_i
                else:
                    loss = sum(loss_list) / self.n_task

            # Add regularization loss
            reg_loss = self.reg_loss_fn(self.model)
            loss = loss + reg_loss
            if self.adaptive_method == 'metabalance':
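                # MetaBalance (from utils/mtl.py) rescales each task's gradient on the
                # shared parameters before the shared and task optimizers step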
                self.share_optimizer.zero_grad()
                self.task_optimizer.zero_grad()
                self.meta_optimizer.step(loss_list)
                self.share_optimizer.step()
                self.task_optimizer.step()
            elif self.adaptive_method == "gradnorm":
                self.optimizer.zero_grad()
                if self.initial_task_loss is None:
                    self.initial_task_loss = [l.item() for l in loss_list]
                gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
                self.optimizer.step()
                # renormalize
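                # rescale so the weights sum to the task count, i.e. the mean
                # task weight stays at 1 after each GradNorm update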
                loss_weight_sum = sum([w.item() for w in self.loss_weight])
                normalize_coeff = len(self.loss_weight) / loss_weight_sum
                for w in self.loss_weight:
                    w.data = w.data * normalize_coeff
            else:
                self.model.zero_grad()
                loss.backward()
                self.optimizer.step()
            total_loss += np.array([l.item() for l in loss_list])
        log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
        loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
        print("train loss: ", log_dict)
        if self.loss_weight:
            print("loss weight: ", [w.item() for w in self.loss_weight])

        return loss_list

    def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
        total_log = []

        for epoch_i in range(self.n_epoch):
            _log_per_epoch = self.train_one_epoch(train_dataloader)

            if self.scheduler is not None:
                if epoch_i % self.scheduler.step_size == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler
            scores = self.evaluate(self.model, val_dataloader)
            print('epoch:', epoch_i, 'validation scores: ', scores)

            for score in scores:
                _log_per_epoch.append(score)

            total_log.append(_log_per_epoch)

            if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
                print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
                self.model.load_state_dict(self.early_stopper.best_weights)
                break

        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  # save best auc model

        return total_log

    def evaluate(self, model, data_loader):
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
            for i, (x_dict, ys) in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
                ys = ys.to(self.device)
                y_preds = self.model(x_dict)
                targets.extend(ys.tolist())
                predicts.extend(y_preds.tolist())
        targets, predicts = np.array(targets), np.array(predicts)
        scores = [self.evaluate_fns[i](targets[:, i], predicts[:, i]) for i in range(self.n_task)]
        return scores

    def predict(self, model, data_loader):
        model.eval()
        predicts = list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
            for i, x_dict in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y_preds = model(x_dict)
                predicts.extend(y_preds.tolist())
        return predicts

    def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
        """Export the trained multi-task model to ONNX format.

        This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
        to ONNX format for deployment. The exported model will have multiple outputs
        corresponding to each task.

        Note:
            The ONNX model will output a tensor of shape [batch_size, n_task] where
            n_task is the number of tasks in the multi-task model.

        Args:
            output_path (str): Path to save the ONNX model file.
            dummy_input (dict, optional): Example input dict {feature_name: tensor}.
                If not provided, dummy inputs will be generated automatically.
            batch_size (int): Batch size for auto-generated dummy input (default: 2).
            seq_length (int): Sequence length for SequenceFeature (default: 10).
            opset_version (int): ONNX opset version (default: 14).
            dynamic_batch (bool): Enable dynamic batch size (default: True).
            device (str, optional): Device for export ('cpu', 'cuda', etc.).
                If None, defaults to 'cpu' for maximum compatibility.
            verbose (bool): Print export details (default: False).

        Returns:
            bool: True if export succeeded, False otherwise.

        Example:
            >>> trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"], ...)
            >>> trainer.fit(train_dl, val_dl)
            >>> trainer.export_onnx("mmoe.onnx")

            >>> # Export on specific device
            >>> trainer.export_onnx("mmoe.onnx", device="cpu")
        """
        from ..utils.onnx_export import ONNXExporter

        # Handle DataParallel wrapped model
        model = self.model.module if hasattr(self.model, 'module') else self.model

        # Use provided device or default to 'cpu'
        export_device = device if device is not None else 'cpu'

        exporter = ONNXExporter(model, device=export_device)
        return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)

    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
        """Visualize the model's computation graph.

        This method generates a visual representation of the model architecture,
        showing layer connections, tensor shapes, and nested module structures.
        It automatically extracts feature information from the model.

        Parameters
        ----------
        input_data : dict, optional
            Example input dict {feature_name: tensor}.
            If not provided, dummy inputs will be generated automatically.
        batch_size : int, default=2
            Batch size for auto-generated dummy input.
        seq_length : int, default=10
            Sequence length for SequenceFeature.
        depth : int, default=3
            Visualization depth, higher values show more detail.
            Set to -1 to show all layers.
        show_shapes : bool, default=True
            Whether to display tensor shapes.
        expand_nested : bool, default=True
            Whether to expand nested modules.
        save_path : str, optional
            Path to save the graph image (.pdf, .svg, .png).
            If None, displays in Jupyter or opens system viewer.
        graph_name : str, default="model"
            Name for the graph.
        device : str, optional
            Device for model execution. If None, defaults to 'cpu'.
        dpi : int, default=300
            Resolution in dots per inch for output image.
            Higher values produce sharper images suitable for papers.
        **kwargs : dict
            Additional arguments passed to torchview.draw_graph().

        Returns
        -------
        ComputationGraph
            A torchview ComputationGraph object.

        Raises
        ------
        ImportError
            If torchview or graphviz is not installed.

        Notes
        -----
        Default display behavior when `save_path` is None (default):
        - In Jupyter/IPython: automatically displays the graph inline
        - In a Python script: opens the graph with the system default viewer

        Examples
        --------
        >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
        >>> trainer.fit(train_dl, val_dl)
        >>>
        >>> # Auto-display in Jupyter (no save_path needed)
        >>> trainer.visualization(depth=4)
        >>>
        >>> # Save to high-DPI PNG for papers
        >>> trainer.visualization(save_path="model.png", dpi=300)
        """
        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model

        if not TORCHVIEW_AVAILABLE:
            raise ImportError(
                "Visualization requires torchview. "
                "Install with: pip install torch-rechub[visualization]\n"
                "Also ensure graphviz is installed on your system:\n"
                " - Ubuntu/Debian: sudo apt-get install graphviz\n"
                " - macOS: brew install graphviz\n"
                " - Windows: choco install graphviz"
            )

        # Handle DataParallel wrapped model
        model = self.model.module if hasattr(self.model, 'module') else self.model

        # Use provided device or default to 'cpu'
        viz_device = device if device is not None else 'cpu'

        return visualize_model(
            model,
            input_data=input_data,
            batch_size=batch_size,
            seq_length=seq_length,
            depth=depth,
            show_shapes=show_shapes,
            expand_nested=expand_nested,
            save_path=save_path,
            graph_name=graph_name,
            device=viz_device,
            dpi=dpi,
            **kwargs
        )
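
Read end to end, the 0.0.5 trainer can be driven roughly as follows. This is a minimal illustrative sketch, not part of the diff: it assumes `model` is an already-built torch-rechub multi-task model (e.g. MMOE), that `train_dl` and `val_dl` are DataLoaders yielding `(x_dict, ys)` batches, and that `MTLTrainer` is importable from `torch_rechub.trainers`.

    import torch
    from torch_rechub.trainers import MTLTrainer

    # model, train_dl and val_dl are assumed to exist (see note above)
    trainer = MTLTrainer(
        model,
        task_types=["classification", "classification"],
        optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
        adaptive_params={"method": "uwl"},  # or "metabalance" / "gradnorm"
        n_epoch=10,
        earlystop_patience=10,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        model_path="./",
    )
    trainer.fit(train_dl, val_dl)  # saves model_base_0.pth under model_path
    scores = trainer.evaluate(trainer.model, val_dl)
    print("validation metrics per task:", scores)

    # new in 0.0.5: ONNX export and computation-graph visualization
    trainer.export_onnx("mtl_model.onnx", dynamic_batch=True)
    trainer.visualization(save_path="mtl_model.png", dpi=300)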