torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
@@ -0,0 +1,233 @@
1
+ """Common model utility functions for Torch-RecHub.
2
+
3
+ This module provides shared utilities for model introspection and input generation,
4
+ used by both ONNX export and visualization features.
5
+
6
+ Examples
7
+ --------
8
+ >>> from torch_rechub.utils.model_utils import extract_feature_info, generate_dummy_input
9
+ >>> feature_info = extract_feature_info(model)
10
+ >>> dummy_input = generate_dummy_input(feature_info['features'], batch_size=2)
11
+ """
12
+
13
+ from typing import Any, Dict, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ # Import feature types for type checking
19
+ try:
20
+ from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
21
+ except ImportError:
22
+ # Fallback for standalone usage
23
+ SparseFeature = None
24
+ DenseFeature = None
25
+ SequenceFeature = None
26
+
27
+
28
def extract_feature_info(model: nn.Module) -> Dict[str, Any]:
    """Extract feature information from a torch-rechub model using reflection.

    This function inspects model attributes to find feature lists without
    modifying the model code. Supports various model architectures.

    Parameters
    ----------
    model : nn.Module
        The recommendation model to inspect.

    Returns
    -------
    dict
        Dictionary containing:
        - 'features': List of unique Feature objects
        - 'input_names': List of feature names in order
        - 'input_types': Dict mapping feature name to feature type
        - 'user_features': List of user-side features (for dual-tower models)
        - 'item_features': List of item-side features (for dual-tower models)

    Examples
    --------
    >>> from torch_rechub.models.ranking import DeepFM
    >>> model = DeepFM(deep_features, fm_features, mlp_params)
    >>> info = extract_feature_info(model)
    >>> print(info['input_names'])  # ['user_id', 'item_id', ...]
    """
    # Common feature attribute names across different model types
    feature_attrs = [
        'features',  # MMOE, DCN, etc.
        'deep_features',  # DeepFM, WideDeep
        'fm_features',  # DeepFM
        'wide_features',  # WideDeep
        'linear_features',  # DeepFFM
        'cross_features',  # DeepFFM
        'user_features',  # DSSM, YoutubeDNN, MIND
        'item_features',  # DSSM, YoutubeDNN, MIND
        'history_features',  # DIN, MIND
        'target_features',  # DIN
        'neg_item_feature',  # YoutubeDNN, MIND
    ]

    all_features = []
    user_features = []
    item_features = []

    for attr in feature_attrs:
        # getattr with default also covers nn.Module's custom __getattr__.
        feat_list = getattr(model, attr, None)
        if isinstance(feat_list, list) and feat_list:
            all_features.extend(feat_list)
            # Track user/item features for dual-tower models
            if 'user' in attr or 'history' in attr:
                user_features.extend(feat_list)
            elif 'item' in attr:
                item_features.extend(feat_list)

    def _dedup(feats):
        # Keep the first occurrence of each named feature, preserving order;
        # entries without a `name` attribute are dropped.
        seen = set()
        unique = []
        for f in feats:
            if hasattr(f, 'name') and f.name not in seen:
                seen.add(f.name)
                unique.append(f)
        return unique

    unique_features = _dedup(all_features)
    unique_user = _dedup(user_features)
    unique_item = _dedup(item_features)

    # Build input names and types (all deduped features are guaranteed
    # to carry a `name` attribute).
    input_names = [f.name for f in unique_features]
    input_types = {f.name: type(f).__name__ for f in unique_features}

    return {
        'features': unique_features,
        'input_names': input_names,
        'input_types': input_types,
        'user_features': unique_user,
        'item_features': unique_item,
    }
111
+
112
+
113
def generate_dummy_input(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Tuple[torch.Tensor, ...]:
    """Generate dummy input tensors based on feature definitions.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Batch size for dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    device : str, default='cpu'
        Device to create tensors on.

    Returns
    -------
    tuple of Tensor
        Tuple of tensors in the order of input features.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000), SequenceFeature("hist", 500)]
    >>> dummy = generate_dummy_input(features, batch_size=4)
    >>> # Returns (user_id_tensor[4], hist_tensor[4, 10])
    """
    # Dynamic import to handle feature types; fall back to class-name
    # dispatch for standalone usage (matching the module-level fallback,
    # which previously made this function unusable standalone).
    try:
        from ..basic.features import DenseFeature, SequenceFeature, SparseFeature
    except ImportError:
        DenseFeature = SequenceFeature = SparseFeature = None

    def _is(feat, cls, cls_name):
        # isinstance when the real class is importable, duck-typed
        # class-name match otherwise.
        if cls is not None:
            return isinstance(feat, cls)
        return type(feat).__name__ == cls_name

    inputs = []
    for feat in features:
        if _is(feat, SequenceFeature, 'SequenceFeature'):
            # Sequence features have shape [batch_size, seq_length]
            tensor = torch.randint(0, feat.vocab_size, (batch_size, seq_length), device=device)
        elif _is(feat, SparseFeature, 'SparseFeature'):
            # Sparse features have shape [batch_size]
            tensor = torch.randint(0, feat.vocab_size, (batch_size,), device=device)
        elif _is(feat, DenseFeature, 'DenseFeature'):
            # Dense features have shape [batch_size, embed_dim]
            tensor = torch.randn(batch_size, feat.embed_dim, device=device)
        else:
            raise TypeError(f"Unsupported feature type: {type(feat)}")
        inputs.append(tensor)
    return tuple(inputs)
156
+
157
+
158
def generate_dummy_input_dict(features: List[Any], batch_size: int = 2, seq_length: int = 10, device: str = 'cpu') -> Dict[str, torch.Tensor]:
    """Generate dummy input dict based on feature definitions.

    Similar to generate_dummy_input but returns a dict mapping feature names
    to tensors. This is the expected input format for torch-rechub models.

    Parameters
    ----------
    features : list
        List of Feature objects (SparseFeature, DenseFeature, SequenceFeature).
    batch_size : int, default=2
        Batch size for dummy input.
    seq_length : int, default=10
        Sequence length for SequenceFeature.
    device : str, default='cpu'
        Device to create tensors on.

    Returns
    -------
    dict
        Dict mapping feature names to tensors.

    Examples
    --------
    >>> features = [SparseFeature("user_id", 1000)]
    >>> dummy = generate_dummy_input_dict(features, batch_size=4)
    >>> # Returns {"user_id": tensor[4]}
    """
    dummy_tuple = generate_dummy_input(features, batch_size, seq_length, device)
    # Pair each tensor with its own feature object. Zipping a *filtered* name
    # list against the full tensor tuple (as before) would silently shift the
    # name->tensor pairing if any feature lacked a `name` attribute.
    return {feat.name: tensor for feat, tensor in zip(features, dummy_tuple) if hasattr(feat, 'name')}
189
+
190
+
191
def generate_dynamic_axes(input_names: List[str], output_names: Optional[List[str]] = None, batch_dim: int = 0, include_seq_dim: bool = True, seq_features: Optional[List[str]] = None) -> Dict[str, Dict[int, str]]:
    """Build the ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Every input and output gets a dynamic batch dimension; inputs listed in
    ``seq_features`` additionally get a dynamic sequence dimension (dim 1).

    Parameters
    ----------
    input_names : list of str
        List of input tensor names.
    output_names : list of str, optional
        List of output tensor names. Default is ["output"].
    batch_dim : int, default=0
        Dimension index for batch size.
    include_seq_dim : bool, default=True
        Whether to include sequence dimension as dynamic.
    seq_features : list of str, optional
        List of feature names that are sequences.

    Returns
    -------
    dict
        Dynamic axes dict for torch.onnx.export.

    Examples
    --------
    >>> axes = generate_dynamic_axes(["user_id", "item_id"], seq_features=["hist"])
    >>> # Returns {"user_id": {0: "batch_size"}, "item_id": {0: "batch_size"}, ...}
    """
    outputs = ["output"] if output_names is None else output_names
    seq_names = set(seq_features) if seq_features else set()

    axes: Dict[str, Dict[int, str]] = {}
    for name in input_names:
        spec = {batch_dim: "batch_size"}
        # Sequence inputs also vary along dim 1.
        if include_seq_dim and name in seq_names:
            spec[1] = "seq_length"
        axes[name] = spec

    for name in outputs:
        axes[name] = {batch_dim: "batch_size"}

    return axes
torch_rechub/utils/mtl.py CHANGED
@@ -1,126 +1,136 @@
1
- import torch
2
- from torch.optim.optimizer import Optimizer
3
- from ..models.multi_task import MMOE, SharedBottom, PLE, AITM
4
-
5
-
6
- def shared_task_layers(model):
7
- """get shared layers and task layers in multi-task model
8
- Authors: Qida Dong, dongjidan@126.com
9
-
10
- Args:
11
- model (torch.nn.Module): only support `[MMOE, SharedBottom, PLE, AITM]`
12
-
13
- Returns:
14
- list[torch.nn.parameter]: parameters split to shared list and task list.
15
- """
16
- shared_layers = list(model.embedding.parameters())
17
- task_layers = None
18
- if isinstance(model, SharedBottom):
19
- shared_layers += list(model.bottom_mlp.parameters())
20
- task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
21
- elif isinstance(model, MMOE):
22
- shared_layers += list(model.experts.parameters())
23
- task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
24
- task_layers += list(model.gates.parameters())
25
- elif isinstance(model, PLE):
26
- shared_layers += list(model.cgc_layers.parameters())
27
- task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
28
- elif isinstance(model, AITM):
29
- shared_layers += list(model.bottoms.parameters())
30
- task_layers = list(model.info_gates.parameters()) + list(model.towers.parameters()) + list(
31
- model.aits.parameters())
32
- else:
33
- raise ValueError(f'this model {model} is not suitable for MetaBalance Optimizer')
34
- return shared_layers, task_layers
35
-
36
-
37
- class MetaBalance(Optimizer):
38
- """MetaBalance Optimizer
39
- This method is used to scale the gradient and balance the gradient of each task.
40
- Authors: Qida Dong, dongjidan@126.com
41
-
42
- Args:
43
- parameters (list): the parameters of model
44
- relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
45
- beta (float, optional): the coefficient of moving average (default: 0.9)
46
- """
47
-
48
- def __init__(self, parameters, relax_factor=0.7, beta=0.9):
49
-
50
- if relax_factor < 0. or relax_factor >= 1.:
51
- raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
52
- if beta < 0. or beta >= 1.:
53
- raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
54
- rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
55
- super(MetaBalance, self).__init__(parameters, rel_beta_dict)
56
-
57
- @torch.no_grad()
58
- def step(self, losses):
59
- for idx, loss in enumerate(losses):
60
- loss.backward(retain_graph=True)
61
- for group in self.param_groups:
62
- for gp in group['params']:
63
- if gp.grad is None:
64
- # print('breaking')
65
- break
66
- if gp.grad.is_sparse:
67
- raise RuntimeError('MetaBalance does not support sparse gradients')
68
- # store the result of moving average
69
- state = self.state[gp]
70
- if len(state) == 0:
71
- for i in range(len(losses)):
72
- if i == 0:
73
- gp.norms = [0]
74
- else:
75
- gp.norms.append(0)
76
- # calculate the moving average
77
- beta = group['beta']
78
- gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
79
- # scale the auxiliary gradient
80
- relax_factor = group['relax_factor']
81
- gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. -
82
- relax_factor)
83
- # store the gradient of each auxiliary task in state
84
- if idx == 0:
85
- state['sum_gradient'] = torch.zeros_like(gp.data)
86
- state['sum_gradient'] += gp.grad
87
- else:
88
- state['sum_gradient'] += gp.grad
89
-
90
- if gp.grad is not None:
91
- gp.grad.detach_()
92
- gp.grad.zero_()
93
- if idx == len(losses) - 1:
94
- gp.grad = state['sum_gradient']
95
-
96
-
97
- def gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha):
98
- loss = 0
99
- for loss_i, w_i in zip(loss_list, loss_weight):
100
- loss += loss_i * w_i
101
- loss.backward(retain_graph=True)
102
- # set the gradients of w_i(t) to zero because these gradients have to be updated using the GradNorm loss
103
- for w_i in loss_weight:
104
- w_i.grad.data = w_i.grad.data * 0.0
105
- # get the gradient norms for each of the tasks
106
- # G^{(i)}_w(t)
107
- norms, loss_ratio = [], []
108
- for i in range(len(loss_list)):
109
- # get the gradient of this task loss with respect to the shared parameters
110
- gygw = torch.autograd.grad(loss_list[i], share_layer, retain_graph=True)
111
- # compute the norm
112
- norms.append(torch.norm(torch.mul(loss_weight[i], gygw[0])))
113
- # compute the inverse training rate r_i(t)
114
- loss_ratio.append(loss_list[i].item() / initial_task_loss[i])
115
- norms = torch.stack(norms)
116
- mean_norm = torch.mean(norms.detach())
117
- mean_loss_ratio = sum(loss_ratio) / len(loss_ratio)
118
- # compute the GradNorm loss
119
- # this term has to remain constant
120
- constant_term = mean_norm * (mean_loss_ratio**alpha)
121
- grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
122
- #print('GradNorm loss {}'.format(grad_norm_loss))
123
-
124
- # compute the gradient for the weights
125
- for w_i in loss_weight:
126
- w_i.grad = torch.autograd.grad(grad_norm_loss, w_i, retain_graph=True)[0]
1
+ import torch
2
+ from torch.optim.optimizer import Optimizer
3
+
4
+ from ..models.multi_task import AITM, MMOE, PLE, SharedBottom
5
+
6
+
7
def shared_task_layers(model):
    """get shared layers and task layers in multi-task model
    Authors: Qida Dong, dongjidan@126.com

    Args:
        model (torch.nn.Module): only support `[MMOE, SharedBottom, PLE, AITM]`

    Returns:
        list[torch.nn.parameter]: parameters split to shared list and task list.
    """
    # The embedding table is shared by every supported architecture.
    shared = list(model.embedding.parameters())
    if isinstance(model, SharedBottom):
        shared.extend(model.bottom_mlp.parameters())
        task = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, MMOE):
        shared.extend(model.experts.parameters())
        # Gates are task-specific in MMOE, so they go with the task layers.
        task = (list(model.towers.parameters()) + list(model.predict_layers.parameters()) + list(model.gates.parameters()))
    elif isinstance(model, PLE):
        shared.extend(model.cgc_layers.parameters())
        task = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, AITM):
        shared.extend(model.bottoms.parameters())
        task = (list(model.info_gates.parameters()) + list(model.towers.parameters()) + list(model.aits.parameters()))
    else:
        raise ValueError(f'this model {model} is not suitable for MetaBalance Optimizer')
    return shared, task
38
+
39
+
40
class MetaBalance(Optimizer):
    """MetaBalance Optimizer
    This method is used to scale the gradient and balance the gradient of each task.
    Authors: Qida Dong, dongjidan@126.com

    Args:
        parameters (list): the parameters of model
        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
        beta (float, optional): the coefficient of moving average (default: 0.9)
    """

    def __init__(self, parameters, relax_factor=0.7, beta=0.9):

        if relax_factor < 0. or relax_factor >= 1.:
            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
        if beta < 0. or beta >= 1.:
            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
        # Stored as optimizer "defaults", so every param group carries its own
        # relax_factor/beta copy (standard torch.optim.Optimizer mechanism).
        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
        super(MetaBalance, self).__init__(parameters, rel_beta_dict)

    @torch.no_grad()
    def step(self, losses):
        # One backward pass per task loss; losses[0] is treated as the main
        # task whose gradient norm the auxiliary gradients are scaled toward.
        for idx, loss in enumerate(losses):
            loss.backward(retain_graph=True)
            for group in self.param_groups:
                for gp in group['params']:
                    if gp.grad is None:
                        # NOTE(review): `break` (not `continue`) abandons the
                        # rest of this param group once one param has no grad
                        # for the current loss — confirm this matches intent.
                        # print('breaking')
                        break
                    if gp.grad.is_sparse:
                        raise RuntimeError('MetaBalance does not support sparse gradients')
                    # store the result of moving average
                    state = self.state[gp]
                    if len(state) == 0:
                        # First time this param is seen: lazily create one
                        # moving-average slot per task. NOTE(review): the
                        # averages live on the parameter object (gp.norms),
                        # not in the optimizer state dict, so they are not
                        # saved/loaded with state_dict().
                        for i in range(len(losses)):
                            if i == 0:
                                gp.norms = [0]
                            else:
                                gp.norms.append(0)

                    # calculate the moving average of this task's grad norm
                    beta = group['beta']
                    gp.norms[idx] = gp.norms[idx] * beta + \
                        (1 - beta) * torch.norm(gp.grad)
                    # scale the auxiliary gradient toward the main task's
                    # (norms[0]) magnitude, blended by relax_factor;
                    # 1e-5 guards against division by zero
                    relax_factor = group['relax_factor']
                    gp.grad = gp.grad * \
                        gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
                    # store the gradient of each auxiliary task in state
                    if idx == 0:
                        state['sum_gradient'] = torch.zeros_like(gp.data)
                        state['sum_gradient'] += gp.grad
                    else:
                        state['sum_gradient'] += gp.grad

                    # clear this task's grad so the next loss.backward()
                    # starts from zero instead of accumulating
                    if gp.grad is not None:
                        gp.grad.detach_()
                        gp.grad.zero_()
                    # after the last task, expose the balanced sum as the
                    # final gradient for the outer optimizer to consume
                    if idx == len(losses) - 1:
                        gp.grad = state['sum_gradient']
101
+
102
+
103
def gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha):
    """Perform one GradNorm step: write new gradients into each task weight.

    Implements the loss-weight update of GradNorm (Chen et al., 2018):
    task weights are pulled toward making each task's weighted gradient
    norm match the mean norm scaled by its inverse training rate.

    Args:
        loss_list (list[torch.Tensor]): current per-task scalar losses.
        loss_weight (list[torch.Tensor]): learnable per-task weights w_i(t);
            their ``.grad`` fields are overwritten in place by this function.
        share_layer: shared parameter(s) used to measure gradient norms;
            passed straight to ``torch.autograd.grad`` (only the first
            returned gradient is used).
        initial_task_loss (list[float]): per-task losses recorded at step 0,
            used to compute the inverse training rate r_i(t).
        alpha (float): GradNorm asymmetry hyper-parameter; larger values push
            slower-training tasks harder.

    Returns:
        None. Side effect only: ``w_i.grad`` is replaced for each weight.
    """
    loss = 0
    for loss_i, w_i in zip(loss_list, loss_weight):
        loss += loss_i * w_i
    # backward on the weighted sum populates w_i.grad (and model grads)
    loss.backward(retain_graph=True)
    # set the gradients of w_i(t) to zero because these gradients have to be
    # updated using the GradNorm loss
    for w_i in loss_weight:
        w_i.grad.data = w_i.grad.data * 0.0

    # get the gradient norms for each of the tasks
    # G^{(i)}_w(t)
    norms, loss_ratio = [], []
    for i in range(len(loss_list)):
        # get the gradient of this task loss with respect to the shared
        # parameters
        gygw = torch.autograd.grad(loss_list[i], share_layer, retain_graph=True)
        # compute the norm of the weighted gradient ||w_i * dL_i/dW||
        norms.append(torch.norm(torch.mul(loss_weight[i], gygw[0])))
        # compute the inverse training rate r_i(t)
        loss_ratio.append(loss_list[i].item() / initial_task_loss[i])
    norms = torch.stack(norms)
    mean_norm = torch.mean(norms.detach())
    mean_loss_ratio = sum(loss_ratio) / len(loss_ratio)
    # compute the GradNorm loss
    # this term has to remain constant (detached mean norm), acting as the
    # target each task's gradient norm is pulled toward
    constant_term = mean_norm * (mean_loss_ratio**alpha)
    grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
    # print('GradNorm loss {}'.format(grad_norm_loss))

    # compute the gradient for the weights
    for w_i in loss_weight:
        w_i.grad = torch.autograd.grad(grad_norm_loss, w_i, retain_graph=True)[0]