torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""Common model utility functions for Torch-RecHub.
|
|
2
|
+
|
|
3
|
+
This module provides shared utilities for model introspection and input generation,
|
|
4
|
+
used by both ONNX export and visualization features.
|
|
5
|
+
|
|
6
|
+
Examples
|
|
7
|
+
--------
|
|
8
|
+
>>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
|
|
9
|
+
>>> feature_info = extract_feature_info(model)
|
|
10
|
+
>>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
# Import feature types for type checking
|
|
19
|
+
try:
|
|
20
|
+
from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
21
|
+
except ImportError:
|
|
22
|
+
# Fallback for standalone usage
|
|
23
|
+
SparseFeature = None
|
|
24
|
+
DenseFeature = None
|
|
25
|
+
SequenceFeature = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
    """Extract feature information from a torch-rechub model using reflection.

    This function inspects model attributes to find feature lists without
    modifying the model code. Supports various model architectures.

    Parameters
    ----------
    model : nn.Module
        The recommendation model to inspect.

    Returns
    -------
    dict
        Dictionary containing:
        - 'features': List of unique Feature objects
        - 'input_names': List of feature names in order
        - 'input_types': Dict mapping feature name to feature type
        - 'user_features': List of user-side features (for dual-tower models)
        - 'item_features': List of item-side features (for dual-tower models)

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> model = DeepFM(deep_features, fm_features, mlp_params)
    >>> info = extract_feature_info(model)
    >>> print(info['input_names'])  # ['user_id', 'item_id', ...]
    """
    # Common feature attribute names across different model types
    feature_attrs = [
        'features',  # MMOE, DCN, etc.
        'deep_features',  # DeepFM, WideDeep
        'fm_features',  # DeepFM
        'wide_features',  # WideDeep
        'linear_features',  # DeepFFM
        'cross_features',  # DeepFFM
        'user_features',  # DSSM, YoutubeDNN, MIND
        'item_features',  # DSSM, YoutubeDNN, MIND
        'history_features',  # DIN, MIND
        'target_features',  # DIN
        'neg_item_feature',  # YoutubeDNN, MIND
    ]

    def _dedup_by_name(feats):
        # Keep the first feature seen for each name, preserving order.
        # Objects without a `name` attribute are dropped, matching the
        # hasattr guards used when building the output below.
        seen = set()
        unique = []
        for f in feats:
            if hasattr(f, 'name') and f.name not in seen:
                seen.add(f.name)
                unique.append(f)
        return unique

    all_features = []
    user_features = []
    item_features = []

    for attr in feature_attrs:
        feat_list = getattr(model, attr, None)
        if isinstance(feat_list, list) and len(feat_list) > 0:
            all_features.extend(feat_list)
            # Track user/item features for dual-tower models
            if 'user' in attr or 'history' in attr:
                user_features.extend(feat_list)
            elif 'item' in attr:
                item_features.extend(feat_list)

    # Deduplicate features by name while preserving order
    unique_features = _dedup_by_name(all_features)

    # Build input names and types (every entry has a name after dedup)
    input_names = [f.name for f in unique_features]
    input_types = {f.name: type(f).__name__ for f in unique_features}

    return {
        'features': unique_features,
        'input_names': input_names,
        'input_types': input_types,
        'user_features': _dedup_by_name(user_features),
        'item_features': _dedup_by_name(item_features),
    }
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
    """Create random input tensors matching a list of feature definitions.

    Parameters
    ----------
    features : list
        Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Number of rows in each generated tensor.
    seq_length : int, default=10
        Length of the second dimension for SequenceFeature tensors.
    device : str, default='cpu'
        Device on which tensors are allocated.

    Returns
    -------
    tuple of Tensor
        One tensor per feature, in the same order as ``features``.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
    >>> dummy = generate_dummy_input(features, batch_size=4)
    >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
    """
    # Imported lazily so the module can be loaded without the feature classes.
    from ..basic.features import DenseFeature, SequenceFeature, SparseFeature

    def _dummy_tensor(feat):
        # Build one random tensor shaped according to the feature kind.
        if isinstance(feat, SequenceFeature):
            # Token-id matrix: [batch_size, seq_length]
            return torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
        if isinstance(feat, SparseFeature):
            # Id vector: [batch_size]
            return torch.randint(0, feat.vocab_size, (batch_size,), device=device)
        if isinstance(feat, DenseFeature):
            # Float matrix: [batch_size, embed_dim]
            return torch.randn(batch_size, feat.embed_dim, device=device)
        raise TypeError(f"Unsupported feature type: {type(feat)}")

    return tuple(_dummy_tensor(feat) for feat in features)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
    """Create a name-to-tensor dict of random inputs for the given features.

    Dict-shaped counterpart of ``generate_dummy_input``; most torch-rechub
    models expect their forward input in this format.

    Parameters
    ----------
    features : list
        Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Number of rows in each generated tensor.
    seq_length : int, default=10
        Length of the sequence dimension for SequenceFeature tensors.
    device : str, default='cpu'
        Device on which tensors are allocated.

    Returns
    -------
    dict
        Mapping from feature name to its dummy tensor.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000)]
    >>> dummy = generate_dummy_input_dict(features, batch_size=4)
    >>> # Returns {"user_id": tensor[4]}
    """
    tensors = generate_dummy_input(features, batch_size, seq_length, device)
    names = (feat.name for feat in features if hasattr(feat, 'name'))
    return dict(zip(names, tensors))
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
    """Build the ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Every input and output gets a dynamic batch dimension; inputs listed in
    ``seq_features`` additionally get a dynamic sequence dimension (axis 1).

    Parameters
    ----------
    input_names : list of str
        Names of the model's input tensors.
    output_names : list of str, optional
        Names of the output tensors; defaults to ``["output"]``.
    batch_dim : int, default=0
        Axis index to mark as the dynamic batch dimension.
    include_seq_dim : bool, default=True
        Whether sequence inputs get a dynamic axis 1.
    seq_features : list of str, optional
        Input names that carry a sequence dimension.

    Returns
    -------
    dict
        ``{tensor_name: {axis: symbolic_name}}`` suitable for torch.onnx.export.

    Examples
    --------
    >>> axes = generate_dynamic_axes(["user_id", "item_id"], seq_features=["hist"])
    >>> # Returns {"user_id": {0: "batch_size"}, "item_id": {0: "batch_size"}, ...}
    """
    out_names = ["output"] if output_names is None else output_names

    # Inputs marked as sequences only matter when the seq dimension is dynamic.
    seq_set = set(seq_features) if (include_seq_dim and seq_features) else set()

    axes: Dict[str, Dict[int, str]] = {}
    for name in input_names:
        spec = {batch_dim: "batch_size"}
        if name in seq_set:
            spec[1] = "seq_length"
        axes[name] = spec

    # Outputs are added last, so a name collision resolves in the output's favor.
    axes.update({name: {batch_dim: "batch_size"} for name in out_names})
    return axes
|
torch_rechub/utils/mtl.py
CHANGED
|
@@ -1,126 +1,136 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch.optim.optimizer import Optimizer
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
task_layers
|
|
28
|
-
elif isinstance(model,
|
|
29
|
-
shared_layers += list(model.
|
|
30
|
-
task_layers = list(model.
|
|
31
|
-
model.
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
if gp.grad
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
if
|
|
91
|
-
gp.
|
|
92
|
-
gp.grad
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
1
|
+
import torch
|
|
2
|
+
from torch.optim.optimizer import Optimizer
|
|
3
|
+
|
|
4
|
+
from ..models.multi_task import AITM, MMOE, PLE, SharedBottom
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def shared_task_layers(model):
    """Split a multi-task model's parameters into shared and task-specific lists.

    Authors: Qida Dong, dongjidan@126.com

    Args:
        model (torch.nn.Module): only support `[MMOE, SharedBottom, PLE, AITM]`

    Returns:
        list[torch.nn.parameter]: parameters split to shared list and task list.

    Raises:
        ValueError: if `model` is not one of the supported architectures.
    """
    # The embedding table is shared across all tasks in every architecture.
    shared_layers = list(model.embedding.parameters())
    if isinstance(model, SharedBottom):
        shared_layers.extend(model.bottom_mlp.parameters())
        task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, MMOE):
        shared_layers.extend(model.experts.parameters())
        # Per-task gates belong to the task side alongside towers/predictors.
        task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters()) + list(model.gates.parameters())
    elif isinstance(model, PLE):
        shared_layers.extend(model.cgc_layers.parameters())
        task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, AITM):
        shared_layers.extend(model.bottoms.parameters())
        task_layers = list(model.info_gates.parameters()) + list(model.towers.parameters()) + list(model.aits.parameters())
    else:
        raise ValueError(f'this model {model} is not suitable for MetaBalance Optimizer')
    return shared_layers, task_layers
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MetaBalance(Optimizer):
    """MetaBalance Optimizer
    This method is used to scale the gradient and balance the gradient of each task.
    Authors: Qida Dong, dongjidan@126.com

    Args:
        parameters (list): the parameters of model
        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
        beta (float, optional): the coefficient of moving average (default: 0.9)
    """

    def __init__(self, parameters, relax_factor=0.7, beta=0.9):
        # Validate hyper-parameters before handing them to the base Optimizer.
        if relax_factor < 0. or relax_factor >= 1.:
            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
        if beta < 0. or beta >= 1.:
            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
        # Passed as the per-group `defaults` dict of torch.optim.Optimizer,
        # so they are readable as group['relax_factor'] / group['beta'] in step().
        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
        super(MetaBalance, self).__init__(parameters, rel_beta_dict)

    @torch.no_grad()
    def step(self, losses):
        """Accumulate balanced gradients from `losses` into each parameter's .grad.

        Unlike a typical optimizer step, this does not update parameter values.
        It runs backward once per task loss, rescales each task's gradient so
        its (moving-average) norm tracks that of the first task, and leaves the
        summed result in `p.grad` for a subsequent optimizer to apply.

        Args:
            losses (list[torch.Tensor]): per-task scalar losses; losses[0]
                provides the reference gradient norm.
        """
        for idx, loss in enumerate(losses):
            loss.backward(retain_graph=True)
            for group in self.param_groups:
                for gp in group['params']:
                    if gp.grad is None:
                        # print('breaking')
                        # NOTE(review): this `break` skips the REMAINING params
                        # of the group once one has no grad — presumably relies
                        # on grad-less params being ordered last; confirm.
                        break
                    if gp.grad.is_sparse:
                        raise RuntimeError('MetaBalance does not support sparse gradients')
                    # store the result of moving average
                    state = self.state[gp]
                    if len(state) == 0:
                        # First time this parameter is seen: allocate one
                        # moving-average slot per task, kept on the parameter.
                        for i in range(len(losses)):
                            if i == 0:
                                gp.norms = [0]
                            else:
                                gp.norms.append(0)

                    # calculate the moving average of this task's gradient norm
                    beta = group['beta']
                    gp.norms[idx] = gp.norms[idx] * beta + \
                        (1 - beta) * torch.norm(gp.grad)
                    # scale the auxiliary gradient toward the main task's norm
                    # (norms[0]); relax_factor blends scaled and raw gradients
                    relax_factor = group['relax_factor']
                    gp.grad = gp.grad * \
                        gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
                    # store the gradient of each auxiliary task in state
                    if idx == 0:
                        state['sum_gradient'] = torch.zeros_like(gp.data)
                        state['sum_gradient'] += gp.grad
                    else:
                        state['sum_gradient'] += gp.grad

                    if gp.grad is not None:
                        # Clear this task's gradient so the next loss.backward
                        # starts from zero for this parameter.
                        gp.grad.detach_()
                        gp.grad.zero_()
                    if idx == len(losses) - 1:
                        # After the last task, expose the accumulated balanced
                        # gradient to whatever optimizer runs next.
                        gp.grad = state['sum_gradient']
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha):
    """Compute GradNorm gradients for the task-loss weights.

    Backpropagates the weighted total loss, then replaces each weight's
    gradient with the gradient of the GradNorm balancing loss, so a
    subsequent optimizer step adapts the task weights.

    Args:
        loss_list (list[torch.Tensor]): per-task scalar losses.
        loss_weight (list[torch.Tensor]): learnable weight per task.
        share_layer (torch.Tensor): shared parameter tensor used to measure
            per-task gradient norms.
        initial_task_loss (list[float]): task losses at training start,
            used for the inverse training rate.
        alpha (float): strength of the balancing force.
    """
    # Weighted total loss; backward fills grads of both the shared layer
    # and the task weights.
    weighted_total = sum(task_loss * weight for task_loss, weight in zip(loss_list, loss_weight))
    weighted_total.backward(retain_graph=True)
    # Zero the weight grads — they are re-derived from the GradNorm loss below.
    for weight in loss_weight:
        weight.grad.data = weight.grad.data * 0.0

    # Per-task gradient norms G^{(i)}_w(t) on the shared layer, and the
    # inverse training rates r_i(t).
    grad_norms = []
    inverse_train_rates = []
    for i, task_loss in enumerate(loss_list):
        shared_grad = torch.autograd.grad(task_loss, share_layer, retain_graph=True)
        grad_norms.append(torch.norm(torch.mul(loss_weight[i], shared_grad[0])))
        inverse_train_rates.append(task_loss.item() / initial_task_loss[i])

    stacked_norms = torch.stack(grad_norms)
    avg_norm = stacked_norms.detach().mean()
    avg_rate = sum(inverse_train_rates) / len(inverse_train_rates)
    # Target norm is treated as a constant w.r.t. the weights.
    target_norm = avg_norm * (avg_rate**alpha)
    gradnorm_loss = (stacked_norms - target_norm).abs().sum()
    # print('GradNorm loss {}'.format(gradnorm_loss))

    # Write the GradNorm gradient directly into each weight's .grad.
    for weight in loss_weight:
        weight.grad = torch.autograd.grad(gradnorm_loss, weight, retain_graph=True)[0]
|