torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/utils/match.py
@@ -0,0 +1,457 @@
+ import copy
+ import random
+ from collections import Counter, OrderedDict
+
+ import numpy as np
+ import pandas as pd
+ import tqdm
+
+ from .data import df_to_dict, pad_sequences
+
+ # Optional imports with fallbacks
+ try:
+     from annoy import AnnoyIndex
+     ANNOY_AVAILABLE = True
+ except ImportError:
+     ANNOY_AVAILABLE = False
+
+ try:
+     import torch
+     from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
+     MILVUS_AVAILABLE = True
+ except ImportError:
+     MILVUS_AVAILABLE = False
+
+ try:
+     import faiss
+     FAISS_AVAILABLE = True
+ except ImportError:
+     FAISS_AVAILABLE = False
+
+
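As a usage illustration (not part of the diff): the three availability flags above let calling code pick whichever vector-search backend is installed. A minimal sketch, assuming the Annoy, Milvus and Faiss classes defined later in this file; the helper name build_ann_engine and its default dimension are invented for the example.

def build_ann_engine(dim=64):
    # Hypothetical helper: return whichever matching engine is importable.
    if FAISS_AVAILABLE:
        return Faiss(dim=dim, index_type='flat', metric='l2')
    if ANNOY_AVAILABLE:
        return Annoy(metric='angular', n_trees=10)
    if MILVUS_AVAILABLE:
        return Milvus(dim=dim)
    raise ImportError("Install one of: faiss-cpu, annoy, pymilvus")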
+ def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
+     """Merge user_profile and item_profile into df, then pad and truncate the history sequence features.
+
+     Args:
+         df (pd.DataFrame): data with history sequence feature
+         user_profile (pd.DataFrame): user data
+         user_col (str): user column name
+         item_profile (pd.DataFrame): item data
+         item_col (str): item column name
+         seq_max_len (int): maximum length of every sequence feature
+         padding (str, optional): padding style, {'pre', 'post'}. Defaults to 'pre'.
+         truncating (str, optional): truncating style, {'pre', 'post'}. Defaults to 'pre'.
+
+     Returns:
+         dict: the converted dict, which can be fed directly to the model
+     """
+     df = pd.merge(df, user_profile, on=user_col, how='left')  # how='left' keeps the sample order the same as the input
+     df = pd.merge(df, item_profile, on=item_col, how='left')
+     for col in df.columns.to_list():
+         if col.startswith("hist_"):
+             df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
+     for col in df.columns.to_list():
+         if col.startswith("tag_"):
+             df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
+
+     input_dict = df_to_dict(df)
+     return input_dict
+
+
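A short usage sketch for gen_model_input (not part of the diff), assuming the package import path torch_rechub.utils.match; the toy column names and values are made up:

import pandas as pd
from torch_rechub.utils.match import gen_model_input

df = pd.DataFrame({"user_id": [1, 2], "movie_id": [10, 20], "hist_movie_id": [[10, 11], [20]]})
user_profile = pd.DataFrame({"user_id": [1, 2], "age": [23, 35]})
item_profile = pd.DataFrame({"movie_id": [10, 20], "cate_id": [3, 7]})

# every "hist_*" (and "tag_*") column is padded/truncated to length 5
x = gen_model_input(df, user_profile, "user_id", item_profile, "movie_id", seq_max_len=5)
print(x["hist_movie_id"])  # histories left-padded with 0 because padding='pre'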
+ def negative_sample(items_cnt_order, ratio, method_id=0):
+     """Negative sampling method for matching models.
+
+     Reference: https://github.com/wangzhegeek/DSSM-Lookalike/blob/master/utils.py
+     Updated with more methods and a redesigned interface.
+
+     Args:
+         items_cnt_order (dict): the item count dict, whose keys (items) are sorted by value (count) in reverse order.
+         ratio (int): negative sample ratio, >= 1
+         method_id (int, optional):
+             `{
+             0: "random sampling",
+             1: "popularity sampling method used in word2vec",
+             2: "popularity sampling method by `log(count+1)+1e-6`",
+             3: "tencent RALM sampling"}`.
+             Defaults to 0.
+
+     Returns:
+         list: sampled negative item list
+     """
+     items_set = [item for item, count in items_cnt_order.items()]
+     if method_id == 0:
+         neg_items = np.random.choice(items_set, size=ratio, replace=True)
+     elif method_id == 1:
+         # items_cnt_freq = {item: count / len(items_cnt) for item, count in items_cnt_order.items()}
+         # p_sel = {item: np.sqrt(1e-5 / items_cnt_freq[item]) for item in items_cnt_order}
+         # The most common choice of parameter is count**0.75:
+         p_sel = {item: count**0.75 for item, count in items_cnt_order.items()}
+         p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+         neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
+     elif method_id == 2:
+         p_sel = {item: np.log(count + 1) + 1e-6 for item, count in items_cnt_order.items()}
+         p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+         neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
+     elif method_id == 3:
+         p_sel = {item: (np.log(k + 2) - np.log(k + 1)) / np.log(len(items_cnt_order) + 1) for item, k in items_cnt_order.items()}
+         p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+         neg_items = np.random.choice(items_set, size=ratio, replace=False, p=p_value)
+     else:
+         raise ValueError("method_id should be in (0, 1, 2, 3)")
+     return neg_items
+
+
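For illustration (not part of the diff), drawing negatives with the word2vec-style popularity sampling (method_id=1) from a toy click log:

from collections import Counter, OrderedDict
from torch_rechub.utils.match import negative_sample

clicks = ["a", "a", "a", "b", "b", "c"]  # toy interaction log
items_cnt = Counter(clicks)
items_cnt_order = OrderedDict(sorted(items_cnt.items(), key=lambda x: x[1], reverse=True))

neg_items = negative_sample(items_cnt_order, ratio=5, method_id=1)  # sampled proportionally to count**0.75
print(neg_items)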
+ def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
+     """Generate sequence features and negative samples for matching models.
+
+     Args:
+         data (pd.DataFrame): the raw data.
+         user_col (str): the col name of user_id
+         item_col (str): the col name of item_id
+         time_col (str): the col name of timestamp
+         item_attribute_cols (list[str], optional): the other item attribute cols for which to generate history sequence features. Defaults to None (treated as an empty list).
+         sample_method (int, optional): the negative sample method `{
+             0: "random sampling",
+             1: "popularity sampling method used in word2vec",
+             2: "popularity sampling method by `log(count+1)+1e-6`",
+             3: "tencent RALM sampling"}`.
+             Defaults to 0.
+         mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
+         neg_ratio (int, optional): negative sample ratio, >= 1. Defaults to 0.
+         min_item (int, optional): the minimum number of items each user must have. Defaults to 0.
+
+     Returns:
+         tuple[pd.DataFrame, pd.DataFrame]: train and test data with sequence features.
+     """
+     if item_attribute_cols is None:
+         item_attribute_cols = []
+     if mode == 2:  # list-wise learning
+         assert neg_ratio > 0, 'neg_ratio must be greater than 0 for list-wise learning'
+     elif mode == 1:  # pair-wise learning
+         neg_ratio = 1
+     print("preprocess data")
+     data.sort_values(time_col, inplace=True)  # sort by time from old to new
+     train_set, test_set = [], []
+     n_cold_user = 0
+     # the last col is the label col (mode 0) or the negative item col (modes 1 and 2)
+     last_col = "label" if mode == 0 else "neg_items"
+
+     items_cnt = Counter(data[item_col].tolist())
+     items_cnt_order = OrderedDict(sorted((items_cnt.items()), key=lambda x: x[1], reverse=True))  # item_id: item count
+     neg_list = negative_sample(items_cnt_order, ratio=data.shape[0] * neg_ratio, method_id=sample_method)
+     neg_idx = 0
+     for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
+         pos_list = hist[item_col].tolist()
+         if len(pos_list) < min_item:  # drop this user when his pos items < min_item
+             n_cold_user += 1
+             continue
+
+         for i in range(1, len(pos_list)):
+             hist_item = pos_list[:i]
+             sample = [uid, pos_list[i], hist_item, len(hist_item)]
+             if len(item_attribute_cols) > 0:
+                 for attr_col in item_attribute_cols:  # the history of item attribute features
+                     sample.append(hist[attr_col].tolist()[:i])
+             if i != len(pos_list) - 1:
+                 if mode == 0:  # point-wise, the label col contains labels 0 and 1
+                     train_set.append(sample + [1])
+                     for _ in range(neg_ratio):
+                         sample[1] = neg_list[neg_idx]
+                         neg_idx += 1
+                         train_set.append(sample + [0])
+                 elif mode == 1:  # pair-wise, the neg col contains one negative item
+                     for _ in range(neg_ratio):
+                         sample_copy = copy.deepcopy(sample)
+                         sample_copy.append(neg_list[neg_idx])
+                         neg_idx += 1
+                         train_set.append(sample_copy)
+                 elif mode == 2:  # list-wise, the neg col contains neg_ratio negative items
+                     sample.append(neg_list[neg_idx:neg_idx + neg_ratio])
+                     neg_idx += neg_ratio
+                     train_set.append(sample)
+                 else:
+                     raise ValueError("mode should be in (0, 1, 2)")
+             else:
+                 # Note: if mode is 1 or 2, the label col is unused.
+                 test_set.append(sample + [1])
+
+     random.shuffle(train_set)
+     random.shuffle(test_set)
+
+     print("n_train: %d, n_test: %d" % (len(train_set), len(test_set)))
+     print("%d cold-start users dropped" % n_cold_user)
+
+     attr_hist_col = ["hist_" + col for col in item_attribute_cols]
+     df_train = pd.DataFrame(train_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
+     df_test = pd.DataFrame(test_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
+
+     return df_train, df_test
+
+
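A list-wise usage sketch for generate_seq_feature_match (not part of the diff); the six-row interaction log is invented:

import pandas as pd
from torch_rechub.utils.match import generate_seq_feature_match

data = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2, 2],
    "item_id": [10, 11, 12, 20, 21, 22],
    "timestamp": [1, 2, 3, 1, 2, 3],
})
# mode=2 (list-wise): each train row carries neg_ratio negative items in "neg_items"
df_train, df_test = generate_seq_feature_match(data, "user_id", "item_id", "timestamp", mode=2, neg_ratio=3)
print(df_train.columns.tolist())
# ['user_id', 'item_id', 'hist_item_id', 'histlen_item_id', 'neg_items']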
+ class Annoy(object):
+     """A vector matching engine using the Annoy library."""
+
+     def __init__(self, metric='angular', n_trees=10, search_k=-1):
+         if not ANNOY_AVAILABLE:
+             raise ImportError("Annoy is not available. To use the Annoy engine, please install it first:\n"
+                               "pip install annoy\n"
+                               "Or use other available engines like Faiss or Milvus")
+         self._n_trees = n_trees
+         self._search_k = search_k
+         self._metric = metric
+
+     def fit(self, X):
+         """Build the Annoy index from input vectors.
+
+         Args:
+             X (np.ndarray): input vectors with shape (n_samples, n_features)
+         """
+         self._annoy = AnnoyIndex(X.shape[1], metric=self._metric)
+         for i, x in enumerate(X):
+             self._annoy.add_item(i, x.tolist())
+         self._annoy.build(self._n_trees)
+
+     def set_query_arguments(self, search_k):
+         """Set query parameters for searching.
+
+         Args:
+             search_k (int): number of nodes to inspect during searching
+         """
+         self._search_k = search_k
+
+     def query(self, v, n):
+         """Find the n nearest neighbors to vector v.
+
+         Args:
+             v (np.ndarray): query vector
+             n (int): number of nearest neighbors to return
+
+         Returns:
+             tuple: (indices, distances) - lists of nearest neighbor indices and their distances
+         """
+         return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k, include_distances=True)
+
+     def __str__(self):
+         return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
+
+
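A usage sketch for the Annoy engine (not part of the diff), assuming `pip install annoy`; the embeddings are random placeholders:

import numpy as np
from torch_rechub.utils.match import Annoy

item_emb = np.random.rand(1000, 64).astype(np.float32)  # e.g. item-tower output
user_emb = np.random.rand(64).astype(np.float32)         # one user vector

ann = Annoy(metric='angular', n_trees=10)
ann.fit(item_emb)                        # build the tree index
idx, dist = ann.query(user_emb, n=10)    # top-10 candidate indices and distances
print(idx, dist)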
+ class Milvus(object):
+     """A vector matching engine using the Milvus database."""
+
+     def __init__(self, dim=64, host="localhost", port="19530"):
+         if not MILVUS_AVAILABLE:
+             raise ImportError("Milvus is not available. To use the Milvus engine, please install it first:\n"
+                               "pip install pymilvus\n"
+                               "Or use other available engines like Annoy or Faiss")
+         self.dim = dim
+         # Connect to the Milvus server (host/port were previously accepted but unused)
+         connections.connect(alias="default", host=host, port=port)
+         has = utility.has_collection("rechub")
+         if has:
+             utility.drop_collection("rechub")
+
+         # Create collection with schema definition
+         fields = [
+             FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
+             FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
+         ]
+         schema = CollectionSchema(fields=fields)
+         self.milvus = Collection("rechub", schema=schema)
+
+     def fit(self, X):
+         """Insert vectors into the Milvus collection and build the index.
+
+         Args:
+             X (np.ndarray or torch.Tensor): input vectors with shape (n_samples, n_features)
+         """
+         if hasattr(X, 'cpu'):  # Handle PyTorch tensor
+             X = X.cpu().numpy()
+         self.milvus.release()
+         entities = [[i for i in range(len(X))], X]
+         self.milvus.insert(entities)
+         print(f"Number of entities in Milvus: {self.milvus.num_entities}")
+
+         # Create IVF_FLAT index for efficient search
+         index = {
+             "index_type": "IVF_FLAT",
+             "metric_type": "L2",
+             "params": {
+                 "nlist": 128
+             },
+         }
+         self.milvus.create_index("embeddings", index)
+
+     @staticmethod
+     def process_result(results):
+         """Process Milvus search results into a standard format.
+
+         Args:
+             results: raw search results from Milvus
+
+         Returns:
+             tuple: (indices_list, distances_list) - processed results
+         """
+         idx_list = []
+         score_list = []
+         for r in results:
+             temp_idx_list = []
+             temp_score_list = []
+             for i in range(len(r)):
+                 temp_idx_list.append(r[i].id)
+                 temp_score_list.append(r[i].distance)
+             idx_list.append(temp_idx_list)
+             score_list.append(temp_score_list)
+         return idx_list, score_list
+
+     def query(self, v, n):
+         """Query Milvus for the n nearest neighbors to vector v.
+
+         Args:
+             v (np.ndarray or torch.Tensor): query vector
+             n (int): number of nearest neighbors to return
+
+         Returns:
+             tuple: (indices, distances) - lists of nearest neighbor indices and their distances
+         """
+         if torch.is_tensor(v):
+             v = v.cpu().numpy()
+         self.milvus.load()
+         search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
+         results = self.milvus.search(v, "embeddings", search_params, n)
+         return self.process_result(results)
+
+
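A usage sketch for the Milvus engine (not part of the diff); it assumes pymilvus is installed and a Milvus server is reachable at localhost:19530, which the constructor connects to:

import numpy as np
from torch_rechub.utils.match import Milvus

item_emb = np.random.rand(1000, 64).astype(np.float32)
user_emb = np.random.rand(1, 64).astype(np.float32)  # Milvus expects a 2-D batch of queries

engine = Milvus(dim=64)
engine.fit(item_emb)                       # insert vectors and build an IVF_FLAT index
idx, dist = engine.query(user_emb, n=10)   # one list of ids / L2 distances per query row
print(idx[0], dist[0])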
+ class Faiss(object):
+     """A vector matching engine using the Faiss library."""
+
+     def __init__(self, dim, index_type='flat', nlist=100, m=32, metric='l2'):
+         self.dim = dim
+         self.index_type = index_type.lower()
+         self.nlist = nlist
+         self.m = m
+         self.metric = metric.lower()
+         self.index = None
+         self.is_trained = False
+
+         # Create index based on different index types and metrics
+         if self.metric == 'l2':
+             if self.index_type == 'flat':
+                 self.index = faiss.IndexFlatL2(dim)
+             elif self.index_type == 'ivf':
+                 quantizer = faiss.IndexFlatL2(dim)
+                 self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
+             elif self.index_type == 'hnsw':
+                 self.index = faiss.IndexHNSWFlat(dim, m)
+             else:
+                 raise ValueError(f"Unsupported index type: {index_type}")
+         elif self.metric == 'ip':
+             if self.index_type == 'flat':
+                 self.index = faiss.IndexFlatIP(dim)
+             elif self.index_type == 'ivf':
+                 quantizer = faiss.IndexFlatIP(dim)
+                 self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
+             elif self.index_type == 'hnsw':
+                 self.index = faiss.IndexHNSWFlat(dim, m)
+                 # HNSW defaults to L2, need to change to inner product
+                 self.index.metric_type = faiss.METRIC_INNER_PRODUCT
+             else:
+                 raise ValueError(f"Unsupported index type: {index_type}")
+         else:
+             raise ValueError(f"Unsupported metric: {metric}")
+
+     def fit(self, X):
+         """Train and build the index from input vectors.
+
+         Args:
+             X (np.ndarray): input vectors with shape (n_samples, dim)
+         """
+         X = np.ascontiguousarray(X, dtype=np.float32)  # Faiss expects contiguous float32 input
+
+         # For index types that require training (like IVF), train first
+         if self.index_type == 'ivf' and not self.is_trained:
+             print(f"Training {self.index_type.upper()} index with {X.shape[0]} vectors...")
+             self.index.train(X)
+             self.is_trained = True
+
+         # Add vectors to the index
+         print(f"Adding {X.shape[0]} vectors to index...")
+         self.index.add(X)
+         print(f"Index built successfully. Total vectors: {self.index.ntotal}")
+
+     def query(self, v, n):
+         """Query the nearest neighbors for a given vector.
+
+         Args:
+             v (np.ndarray or torch.Tensor): query vector
+             n (int): number of nearest neighbors to return
+
+         Returns:
+             tuple: (indices, distances) - lists of nearest neighbor indices and distances
+         """
+         if hasattr(v, 'cpu'):  # Handle PyTorch tensor
+             v = v.cpu().numpy()
+
+         # Ensure query vector has correct shape
+         if v.ndim == 1:
+             v = v.reshape(1, -1)
+
+         v = v.astype(np.float32)
+
+         # Set search parameters for IVF index
+         if self.index_type == 'ivf':
+             # Set number of clusters to search
+             nprobe = min(self.nlist, max(1, self.nlist // 4))
+             self.index.nprobe = nprobe
+
+         # Execute search
+         distances, indices = self.index.search(v, n)
+
+         return indices.tolist(), distances.tolist()
+
+     def set_query_arguments(self, nprobe=None, efSearch=None):
+         """Set query parameters for search.
+
+         Args:
+             nprobe (int): number of clusters to search for IVF index
+             efSearch (int): search parameter for HNSW index
+         """
+         if self.index_type == 'ivf' and nprobe is not None:
+             self.index.nprobe = min(nprobe, self.nlist)
+         elif self.index_type == 'hnsw' and efSearch is not None:
+             self.index.hnsw.efSearch = efSearch
+
+     def save_index(self, filepath):
+         """Save the index to a file for later use."""
+         faiss.write_index(self.index, filepath)
+
+     def load_index(self, filepath):
+         """Load the index from a file."""
+         self.index = faiss.read_index(filepath)
+         self.is_trained = True
+
+     def __str__(self):
+         return f'Faiss(index_type={self.index_type}, dim={self.dim}, metric={self.metric}, ntotal={self.index.ntotal if self.index else 0})'
+
+ if __name__ == '__main__':
+     # Generate random item embeddings (100 items, each with 64 dimensions)
+     item_embeddings = np.random.rand(100, 64).astype(np.float32)
+
+     # Generate random user embedding (1 user, 64 dimensions)
+     user_embedding = np.random.rand(1, 64).astype(np.float32)
+
+     # Create FAISS index
+     faiss_index = Faiss(dim=64, index_type='ivf', nlist=100, metric='l2')
+
+     # Train and build the index
+     faiss_index.fit(item_embeddings)
+
+     # Query nearest neighbors
+     indices, distances = faiss_index.query(user_embedding, n=10)
+
+     print("Top 10 nearest neighbors:")
+     print(indices)  # Output indices of nearest neighbors
+     print(distances)  # Output distances of nearest neighbors
torch_rechub/utils/mtl.py
@@ -0,0 +1,136 @@
+ import torch
+ from torch.optim.optimizer import Optimizer
+
+ from ..models.multi_task import AITM, MMOE, PLE, SharedBottom
+
+
+ def shared_task_layers(model):
+     """Get the shared layers and task-specific layers of a multi-task model.
+     Authors: Qida Dong, dongjidan@126.com
+
+     Args:
+         model (torch.nn.Module): only supports `[MMOE, SharedBottom, PLE, AITM]`
+
+     Returns:
+         tuple[list, list]: model parameters split into a shared list and a task list.
+     """
+     shared_layers = list(model.embedding.parameters())
+     task_layers = None
+     if isinstance(model, SharedBottom):
+         shared_layers += list(model.bottom_mlp.parameters())
+         task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
+     elif isinstance(model, MMOE):
+         shared_layers += list(model.experts.parameters())
+         task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
+         task_layers += list(model.gates.parameters())
+     elif isinstance(model, PLE):
+         shared_layers += list(model.cgc_layers.parameters())
+         task_layers = list(model.towers.parameters()) + list(model.predict_layers.parameters())
+     elif isinstance(model, AITM):
+         shared_layers += list(model.bottoms.parameters())
+         task_layers = list(model.info_gates.parameters()) + list(model.towers.parameters()) + list(model.aits.parameters())
+     else:
+         raise ValueError(f'this model {model} is not supported by the MetaBalance optimizer')
+     return shared_layers, task_layers
+
+
+ class MetaBalance(Optimizer):
+     """MetaBalance Optimizer.
+     This optimizer scales the gradients to balance the gradient magnitude of each task.
+     Authors: Qida Dong, dongjidan@126.com
+
+     Args:
+         parameters (list): the parameters of the model
+         relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
+         beta (float, optional): the coefficient of the moving average (default: 0.9)
+     """
+
+     def __init__(self, parameters, relax_factor=0.7, beta=0.9):
+         if relax_factor < 0. or relax_factor >= 1.:
+             raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
+         if beta < 0. or beta >= 1.:
+             raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
+         rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
+         super(MetaBalance, self).__init__(parameters, rel_beta_dict)
+
+     @torch.no_grad()
+     def step(self, losses):
+         for idx, loss in enumerate(losses):
+             loss.backward(retain_graph=True)
+             for group in self.param_groups:
+                 for gp in group['params']:
+                     if gp.grad is None:
+                         break
+                     if gp.grad.is_sparse:
+                         raise RuntimeError('MetaBalance does not support sparse gradients')
+                     # store the result of the moving average
+                     state = self.state[gp]
+                     if len(state) == 0:
+                         for i in range(len(losses)):
+                             if i == 0:
+                                 gp.norms = [0]
+                             else:
+                                 gp.norms.append(0)
+
+                     # calculate the moving average
+                     beta = group['beta']
+                     gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
+                     # scale the auxiliary gradient
+                     relax_factor = group['relax_factor']
+                     gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
+                     # store the gradient of each auxiliary task in state
+                     if idx == 0:
+                         state['sum_gradient'] = torch.zeros_like(gp.data)
+                         state['sum_gradient'] += gp.grad
+                     else:
+                         state['sum_gradient'] += gp.grad
+
+                     if gp.grad is not None:
+                         gp.grad.detach_()
+                         gp.grad.zero_()
+                     if idx == len(losses) - 1:
+                         gp.grad = state['sum_gradient']
+
+
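A hedged training-step sketch (not part of the diff) showing how shared_task_layers and MetaBalance are typically combined: MetaBalance backpropagates every task loss and leaves balanced gradients on the shared parameters, while ordinary optimizers apply the updates. The model construction and the per-task losses are placeholders.

import torch
from torch_rechub.utils.mtl import MetaBalance, shared_task_layers

# `model` is assumed to be a supported multi-task model (SharedBottom / MMOE / PLE / AITM)
shared_params, task_params = shared_task_layers(model)
meta_opt = MetaBalance(shared_params, relax_factor=0.7, beta=0.9)
shared_opt = torch.optim.Adam(shared_params, lr=1e-3)  # applies the balanced gradients
task_opt = torch.optim.Adam(task_params, lr=1e-3)

# one training step; `losses` is the list of per-task losses for the current batch
shared_opt.zero_grad()
task_opt.zero_grad()
meta_opt.step(losses)  # backprops each loss, rescales and sums the shared-layer grads
shared_opt.step()      # shared layers use the MetaBalance-adjusted gradients
task_opt.step()        # task layers keep their ordinary accumulated gradients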
+ def gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha):
+     """Compute GradNorm-style gradients for the task loss weights.
+
+     Args:
+         loss_list (list): per-task losses of the current batch
+         loss_weight (list): trainable task weights w_i(t)
+         share_layer: the shared parameter(s) used to measure gradient norms
+         initial_task_loss (list): task losses recorded at the first step, L_i(0)
+         alpha (float): the GradNorm asymmetry hyper-parameter
+     """
+     loss = 0
+     for loss_i, w_i in zip(loss_list, loss_weight):
+         loss += loss_i * w_i
+     loss.backward(retain_graph=True)
+     # set the gradients of w_i(t) to zero because these gradients have to be
+     # updated using the GradNorm loss
+     for w_i in loss_weight:
+         w_i.grad.data = w_i.grad.data * 0.0
+
+     # get the gradient norms for each of the tasks: G^{(i)}_w(t)
+     norms, loss_ratio = [], []
+     for i in range(len(loss_list)):
+         # get the gradient of this task loss with respect to the shared parameters
+         gygw = torch.autograd.grad(loss_list[i], share_layer, retain_graph=True)
+         # compute the norm
+         norms.append(torch.norm(torch.mul(loss_weight[i], gygw[0])))
+         # compute the inverse training rate r_i(t)
+         loss_ratio.append(loss_list[i].item() / initial_task_loss[i])
+     norms = torch.stack(norms)
+     mean_norm = torch.mean(norms.detach())
+     mean_loss_ratio = sum(loss_ratio) / len(loss_ratio)
+     # compute the GradNorm loss; this term has to remain constant
+     constant_term = mean_norm * (mean_loss_ratio**alpha)
+     grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
+
+     # compute the gradient for the weights
+     for w_i in loss_weight:
+         w_i.grad = torch.autograd.grad(grad_norm_loss, w_i, retain_graph=True)[0]
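A hedged sketch (not part of the diff) of a GradNorm-style loop built on this helper: the weighted sum of losses updates the model, while gradnorm overwrites the gradients of the task weights. model, loader, compute_task_losses and share_layer are placeholders; share_layer should be a shared parameter through which every task loss flows.

import torch
from torch_rechub.utils.mtl import gradnorm

n_task = 2
loss_weight = [torch.tensor(1.0, requires_grad=True) for _ in range(n_task)]
optimizer = torch.optim.Adam(list(model.parameters()) + loss_weight, lr=1e-3)
initial_task_loss = None

for batch in loader:
    loss_list = compute_task_losses(model, batch)          # placeholder: per-task losses
    if initial_task_loss is None:
        initial_task_loss = [l.item() for l in loss_list]  # L_i(0)
    optimizer.zero_grad()
    gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha=0.16)
    optimizer.step()
    with torch.no_grad():  # renormalize so the task weights keep summing to n_task
        coef = n_task / sum(w.item() for w in loss_weight)
        for w in loss_weight:
            w.mul_(coef)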