torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/basic/loss_func.py
CHANGED
@@ -1,34 +1,223 @@
-import torch
-import torch.functional as F
[removed lines 3-34 of the 0.0.3 file are not recoverable from this rendering]
+import torch
+import torch.functional as F
+import torch.nn as nn
+
+
+class RegularizationLoss(nn.Module):
+    """Unified L1/L2 Regularization Loss for embedding and dense parameters.
+
+    Example:
+        >>> reg_loss_fn = RegularizationLoss(embedding_l2=1e-5, dense_l2=1e-5)
+        >>> # In model's forward or trainer
+        >>> reg_loss = reg_loss_fn(model)
+        >>> total_loss = task_loss + reg_loss
+    """
+
+    def __init__(self, embedding_l1=0.0, embedding_l2=0.0, dense_l1=0.0, dense_l2=0.0):
+        super(RegularizationLoss, self).__init__()
+        self.embedding_l1 = embedding_l1
+        self.embedding_l2 = embedding_l2
+        self.dense_l1 = dense_l1
+        self.dense_l2 = dense_l2
+
+    def forward(self, model):
+        reg_loss = 0.0
+
+        # Register normalization layers
+        norm_params = set()
+        for module in model.modules():
+            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+                for param in module.parameters():
+                    norm_params.add(id(param))
+
+        # Register embedding layers
+        embedding_params = set()
+        for module in model.modules():
+            if isinstance(module, (nn.Embedding, nn.EmbeddingBag)):
+                for param in module.parameters():
+                    embedding_params.add(id(param))
+
+        for param in model.parameters():
+            if param.requires_grad:
+                # Skip normalization layer parameters
+                if id(param) in norm_params:
+                    continue
+
+                if id(param) in embedding_params:
+                    if self.embedding_l1 > 0:
+                        reg_loss += self.embedding_l1 * torch.sum(torch.abs(param))
+                    if self.embedding_l2 > 0:
+                        reg_loss += self.embedding_l2 * torch.sum(param**2)
+                else:
+                    if self.dense_l1 > 0:
+                        reg_loss += self.dense_l1 * torch.sum(torch.abs(param))
+                    if self.dense_l2 > 0:
+                        reg_loss += self.dense_l2 * torch.sum(param**2)
+
+        return reg_loss
+
+
+class HingeLoss(torch.nn.Module):
+    """Hinge Loss for pairwise learning.
+    reference: https://github.com/ustcml/RecStudio/blob/main/recstudio/model/loss_func.py
+
+    """
+
+    def __init__(self, margin=2, num_items=None):
+        super().__init__()
+        self.margin = margin
+        self.n_items = num_items
+
+    def forward(self, pos_score, neg_score):
+        loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
+        if self.n_items is not None:
+            impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
+            rank = torch.mean(impostors, -1) * self.n_items
+            return torch.mean(loss * torch.log(rank + 1))
+        else:
+            return torch.mean(loss)
+
+
+class BPRLoss(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, pos_score, neg_score):
+        loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
+        return loss
+
+
+class NCELoss(torch.nn.Module):
+    """Noise Contrastive Estimation (NCE) Loss for recommendation systems.
+
+    NCE Loss is more efficient than CrossEntropyLoss for large-scale recommendation
+    scenarios. It uses in-batch negatives to reduce computational complexity.
+
+    Reference:
+        - Noise-contrastive estimation: A new estimation principle for unnormalized
+          statistical models (Gutmann & Hyvärinen, 2010)
+        - HLLM: Hierarchical Large Language Model for Recommendation
+
+    Args:
+        temperature (float): Temperature parameter for scaling logits. Default: 1.0
+        ignore_index (int): Index to ignore in loss computation. Default: 0
+        reduction (str): Specifies the reduction to apply to the output.
+            Options: 'mean', 'sum', 'none'. Default: 'mean'
+
+    Example:
+        >>> nce_loss = NCELoss(temperature=0.1)
+        >>> logits = torch.randn(32, 1000)  # (batch_size, vocab_size)
+        >>> targets = torch.randint(0, 1000, (32,))
+        >>> loss = nce_loss(logits, targets)
+    """
+
+    def __init__(self, temperature=1.0, ignore_index=0, reduction='mean'):
+        super().__init__()
+        self.temperature = temperature
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+
+    def forward(self, logits, targets):
+        """Compute NCE loss.
+
+        Args:
+            logits (torch.Tensor): Model output logits of shape (batch_size, vocab_size)
+            targets (torch.Tensor): Target indices of shape (batch_size,)
+
+        Returns:
+            torch.Tensor: NCE loss value
+        """
+        # Scale logits by temperature
+        logits = logits / self.temperature
+
+        # Compute log softmax
+        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Get log probability of target class
+        batch_size = targets.shape[0]
+        target_log_probs = log_probs[torch.arange(batch_size), targets]
+
+        # Create mask for ignore_index
+        mask = targets != self.ignore_index
+
+        # Compute loss
+        loss = -target_log_probs
+
+        # Apply mask
+        if mask.any():
+            loss = loss[mask]
+
+        # Apply reduction
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
+
+
+class InBatchNCELoss(torch.nn.Module):
+    """In-Batch NCE Loss with explicit negative sampling.
+
+    This loss function uses other samples in the batch as negative samples,
+    which is more efficient than sampling random negatives.
+
+    Args:
+        temperature (float): Temperature parameter for scaling logits. Default: 0.1
+        ignore_index (int): Index to ignore in loss computation. Default: 0
+        reduction (str): Specifies the reduction to apply to the output.
+            Options: 'mean', 'sum', 'none'. Default: 'mean'
+
+    Example:
+        >>> loss_fn = InBatchNCELoss(temperature=0.1)
+        >>> embeddings = torch.randn(32, 256)  # (batch_size, embedding_dim)
+        >>> item_embeddings = torch.randn(1000, 256)  # (vocab_size, embedding_dim)
+        >>> targets = torch.randint(0, 1000, (32,))
+        >>> loss = loss_fn(embeddings, item_embeddings, targets)
+    """
+
+    def __init__(self, temperature=0.1, ignore_index=0, reduction='mean'):
+        super().__init__()
+        self.temperature = temperature
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+
+    def forward(self, embeddings, item_embeddings, targets):
+        """Compute in-batch NCE loss.
+
+        Args:
+            embeddings (torch.Tensor): User/query embeddings of shape (batch_size, embedding_dim)
+            item_embeddings (torch.Tensor): Item embeddings of shape (vocab_size, embedding_dim)
+            targets (torch.Tensor): Target item indices of shape (batch_size,)
+
+        Returns:
+            torch.Tensor: In-batch NCE loss value
+        """
+        # Compute logits: (batch_size, vocab_size)
+        logits = torch.matmul(embeddings, item_embeddings.t()) / self.temperature
+
+        # Compute log softmax
+        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Get log probability of target class
+        batch_size = targets.shape[0]
+        target_log_probs = log_probs[torch.arange(batch_size), targets]
+
+        # Create mask for ignore_index
+        mask = targets != self.ignore_index
+
+        # Compute loss
+        loss = -target_log_probs
+
+        # Apply mask
+        if mask.any():
+            loss = loss[mask]
+
+        # Apply reduction
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
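The 0.0.5 loss module thus bundles parameter regularization (RegularizationLoss), pairwise ranking losses (HingeLoss, BPRLoss) and sampled-softmax style losses (NCELoss, InBatchNCELoss). The sketch below shows one way these classes could be combined in a retrieval-style training step; it is illustrative only, and the toy two-tower model, tensor shapes and hyper-parameters are assumptions, not part of the package.

import torch
import torch.nn as nn

# Illustrative only: assumes the new module is importable from the path listed in the file table above.
from torch_rechub.basic.loss_func import RegularizationLoss, BPRLoss, InBatchNCELoss


class ToyTwoTower(nn.Module):
    """Hypothetical two-tower retrieval model, used only for this sketch."""

    def __init__(self, n_users=100, n_items=1000, dim=64):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, dim)
        self.item_emb = nn.Embedding(n_items, dim)
        self.dense = nn.Linear(dim, dim)

    def forward(self, user_ids):
        return self.dense(self.user_emb(user_ids))  # (batch_size, dim)


model = ToyTwoTower()
reg_loss_fn = RegularizationLoss(embedding_l2=1e-5, dense_l2=1e-5)
bpr_loss_fn = BPRLoss()
nce_loss_fn = InBatchNCELoss(temperature=0.1)  # item index 0 is treated as padding (ignore_index)

user_ids = torch.randint(0, 100, (32,))
pos_items = torch.randint(1, 1000, (32,))
neg_items = torch.randint(1, 1000, (32, 5))  # 5 sampled negatives per positive

user_vec = model(user_ids)                                                    # (32, 64)
pos_score = (user_vec * model.item_emb(pos_items)).sum(-1)                    # (32,)
neg_score = torch.einsum('bd,bnd->bn', user_vec, model.item_emb(neg_items))   # (32, 5)

# Pairwise objective over the sampled negatives (BPRLoss returns a per-sample loss).
pair_loss = bpr_loss_fn(pos_score.unsqueeze(-1), neg_score).mean()

# Sampled-softmax objective against the full item table.
nce_loss = nce_loss_fn(user_vec, model.item_emb.weight, pos_items)

# L2 regularization over embedding and dense parameters, added to the task losses.
total_loss = pair_loss + nce_loss + reg_loss_fn(model)
total_loss.backward()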
torch_rechub/basic/metaoptimizer.py
CHANGED
@@ -1,72 +1,76 @@
-"""The metaoptimizer module, it provides a class MetaBalance
-MetaBalance is used to scale the gradient and balance the gradient of each task
-Authors: Qida Dong, dongjidan@126.com
-"""
-import torch
-from torch.optim.optimizer import Optimizer
-
-
-class MetaBalance(Optimizer):
-    """MetaBalance Optimizer
-    This method is used to scale the gradient and balance the gradient of each task
-
-    Args:
-        parameters (list): the parameters of model
-        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
-        beta (float, optional): the coefficient of moving average (default: 0.9)
-    """
-
-    def __init__(self, parameters, relax_factor=0.7, beta=0.9):
-
-        if relax_factor < 0. or relax_factor >= 1.:
-            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
-        if beta < 0. or beta >= 1.:
-            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
-        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
-        super(MetaBalance, self).__init__(parameters, rel_beta_dict)
-
-    @torch.no_grad()
-    def step(self, losses):
-        """_summary_
-        Args:
-            losses (_type_): _description_
-
-        Raises:
-            RuntimeError: _description_
-        """
-
-        for idx, loss in enumerate(losses):
-            loss.backward(retain_graph=True)
-            for group in self.param_groups:
-                for gp in group['params']:
-                    if gp.grad is None:
-                        # print('breaking')
-                        break
-                    if gp.grad.is_sparse:
-                        raise RuntimeError('MetaBalance does not support sparse gradients')
-
-                    state = self.state[gp]
-                    if len(state) == 0:
-                        for i in range(len(losses)):
-                            if i == 0:
-                                gp.norms = [0]
-                            else:
-                                gp.norms.append(0)
[removed lines 55-72 of the 0.0.3 file are not recoverable from this rendering]
+"""The metaoptimizer module, it provides a class MetaBalance
+MetaBalance is used to scale the gradient and balance the gradient of each task
+Authors: Qida Dong, dongjidan@126.com
+"""
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class MetaBalance(Optimizer):
+    """MetaBalance Optimizer
+    This method is used to scale the gradient and balance the gradient of each task
+
+    Args:
+        parameters (list): the parameters of model
+        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
+        beta (float, optional): the coefficient of moving average (default: 0.9)
+    """
+
+    def __init__(self, parameters, relax_factor=0.7, beta=0.9):
+
+        if relax_factor < 0. or relax_factor >= 1.:
+            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
+        if beta < 0. or beta >= 1.:
+            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
+        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
+        super(MetaBalance, self).__init__(parameters, rel_beta_dict)
+
+    @torch.no_grad()
+    def step(self, losses):
+        """_summary_
+        Args:
+            losses (_type_): _description_
+
+        Raises:
+            RuntimeError: _description_
+        """
+
+        for idx, loss in enumerate(losses):
+            loss.backward(retain_graph=True)
+            for group in self.param_groups:
+                for gp in group['params']:
+                    if gp.grad is None:
+                        # print('breaking')
+                        break
+                    if gp.grad.is_sparse:
+                        raise RuntimeError('MetaBalance does not support sparse gradients')
+                    # store the result of moving average
+                    state = self.state[gp]
+                    if len(state) == 0:
+                        for i in range(len(losses)):
+                            if i == 0:
+                                gp.norms = [0]
+                            else:
+                                gp.norms.append(0)
+
+
+                    # calculate the moving average
+                    beta = group['beta']
+                    gp.norms[idx] = gp.norms[idx] * beta + \
+                        (1 - beta) * torch.norm(gp.grad)
+                    # scale the auxiliary gradient
+                    relax_factor = group['relax_factor']
+                    gp.grad = gp.grad * \
+                        gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
+                    # store the gradient of each auxiliary task in state
+                    if idx == 0:
+                        state['sum_gradient'] = torch.zeros_like(gp.data)
+                        state['sum_gradient'] += gp.grad
+                    else:
+                        state['sum_gradient'] += gp.grad
+
+                    if gp.grad is not None:
+                        gp.grad.detach_()
+                        gp.grad.zero_()
+                    if idx == len(losses) - 1:
+                        gp.grad = state['sum_gradient']