torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128 (shown in full below; see the usage sketch following this list)
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
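The hunk below is torch_rechub/trainers/ctr_trainer.py. Beyond a formatting pass, 0.0.5 adds an optional `regularization_params` argument (fed into the new `RegularizationLoss`) and two post-training helpers, `export_onnx` and `visualization`. A minimal usage sketch of the new surface (the keyword names come from the diff; `model`, `train_dl`, and `val_dl` are placeholders for your own model and DataLoaders of `({feature_name: tensor}, label)` batches):

from torch_rechub.trainers import CTRTrainer

trainer = CTRTrainer(
    model,  # placeholder: any single-task model, e.g. from torch_rechub.models.ranking
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    # new in 0.0.5: L1/L2 penalties for embedding vs. dense parameters (all default to 0.0)
    regularization_params={"embedding_l1": 0.0, "embedding_l2": 1e-5, "dense_l1": 0.0, "dense_l2": 1e-5},
    n_epoch=10,
    device="cpu",
)
trainer.fit(train_dl, val_dl)
trainer.export_onnx("model.onnx")  # new in 0.0.5; returns True on success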
@@ -1,128 +1,288 @@
- import os
- import torch
- import tqdm
- from sklearn.metrics import roc_auc_score
- from ..basic.callback import EarlyStopper
-
-
- class CTRTrainer(object):
-     """A general trainer for single task learning.
-
-     Args:
-         model (nn.Module): any single task learning model.
-         optimizer_fn (torch.optim): optimizer class of pytorch (default = `torch.optim.Adam`).
-         optimizer_params (dict): parameters of optimizer_fn.
-         scheduler_fn (torch.optim.lr_scheduler): torch scheduler class, e.g. `torch.optim.lr_scheduler.StepLR`.
-         scheduler_params (dict): parameters of scheduler_fn.
-         n_epoch (int): number of training epochs.
-         earlystop_patience (int): number of epochs to wait after the last validation auc improvement (default=10).
-         device (str): `"cpu"` or `"cuda:0"`.
-         gpus (list): ids of gpus for multi-gpu training (default=[]). If more than one id is given, the model will be wrapped by nn.DataParallel.
-         loss_mode (bool): if True, the model returns predictions only; if False, it returns `(predictions, auxiliary_loss)` and the auxiliary loss is added to the training loss. Defaults to True.
-         model_path (str): directory in which to save the model (default="./"). Note that only the weights with the best validation auc are saved.
-     """
-
-     def __init__(
-         self,
-         model,
-         optimizer_fn=torch.optim.Adam,
-         optimizer_params=None,
-         scheduler_fn=None,
-         scheduler_params=None,
-         n_epoch=10,
-         earlystop_patience=10,
-         device="cpu",
-         gpus=None,
-         loss_mode=True,
-         model_path="./",
-     ):
-         self.model = model # for uniform weights save method in one gpu or multi gpu
-         if gpus is None:
-             gpus = []
-         self.gpus = gpus
-         if len(gpus) > 1:
-             print('parallel running on these gpus:', gpus)
-             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
-         self.device = torch.device(device) #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         self.model.to(self.device)
-         if optimizer_params is None:
-             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
-         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params) #default optimizer
-         self.scheduler = None
-         if scheduler_fn is not None:
-             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
-         self.loss_mode = loss_mode
-         self.criterion = torch.nn.BCELoss() #default loss: binary cross-entropy
-         self.evaluate_fn = roc_auc_score #default evaluate function
-         self.n_epoch = n_epoch
-         self.early_stopper = EarlyStopper(patience=earlystop_patience)
-         self.model_path = model_path
-
-     def train_one_epoch(self, data_loader, log_interval=10):
-         self.model.train()
-         total_loss = 0
-         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
-         for i, (x_dict, y) in enumerate(tk0):
-             x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
-             y = y.to(self.device).float()
-             if self.loss_mode:
-                 y_pred = self.model(x_dict)
-                 loss = self.criterion(y_pred, y)
-             else:
-                 y_pred, other_loss = self.model(x_dict)
-                 loss = self.criterion(y_pred, y) + other_loss
-             self.model.zero_grad()
-             loss.backward()
-             self.optimizer.step()
-             total_loss += loss.item()
-             if (i + 1) % log_interval == 0:
-                 tk0.set_postfix(loss=total_loss / log_interval)
-                 total_loss = 0
-
-     def fit(self, train_dataloader, val_dataloader=None):
-         for epoch_i in range(self.n_epoch):
-             print('epoch:', epoch_i)
-             self.train_one_epoch(train_dataloader)
-             if self.scheduler is not None:
-                 if epoch_i % self.scheduler.step_size == 0:
-                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
-                 self.scheduler.step() #update lr in epoch level by scheduler
-             if val_dataloader:
-                 auc = self.evaluate(self.model, val_dataloader)
-                 print('epoch:', epoch_i, 'validation: auc:', auc)
-                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
-                     print(f'validation: best auc: {self.early_stopper.best_auc}')
-                     self.model.load_state_dict(self.early_stopper.best_weights)
-                     break
-         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth")) #save best auc model
-
-     def evaluate(self, model, data_loader):
-         model.eval()
-         targets, predicts = list(), list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
-             for i, (x_dict, y) in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                 y = y.to(self.device).float().view(-1, 1)  # ensure y is float with shape [batch_size, 1]
-                 if self.loss_mode:
-                     y_pred = model(x_dict)
-                 else:
-                     y_pred, _ = model(x_dict)
-                 targets.extend(y.tolist())
-                 predicts.extend(y_pred.tolist())
-         return self.evaluate_fn(targets, predicts)
-
-     def predict(self, model, data_loader):
-         model.eval()
-         predicts = list()
-         with torch.no_grad():
-             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
-             for i, (x_dict, y) in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
-                 y = y.to(self.device)
-                 if self.loss_mode:
-                     y_pred = model(x_dict)
-                 else:
-                     y_pred, _ = model(x_dict)
-                 predicts.extend(y_pred.tolist())
-         return predicts
+ import os
+
+ import torch
+ import tqdm
+ from sklearn.metrics import roc_auc_score
+
+ from ..basic.callback import EarlyStopper
+ from ..basic.loss_func import RegularizationLoss
+
+
+ class CTRTrainer(object):
+     """A general trainer for single task learning.
+
+     Args:
+         model (nn.Module): any single task learning model.
+         optimizer_fn (torch.optim): optimizer class of pytorch (default = `torch.optim.Adam`).
+         optimizer_params (dict): parameters of optimizer_fn.
+         scheduler_fn (torch.optim.lr_scheduler): torch scheduler class, e.g. `torch.optim.lr_scheduler.StepLR`.
+         scheduler_params (dict): parameters of scheduler_fn.
+         n_epoch (int): number of training epochs.
+         earlystop_patience (int): number of epochs to wait after the last validation auc improvement (default=10).
+         device (str): `"cpu"` or `"cuda:0"`.
+         gpus (list): ids of gpus for multi-gpu training (default=[]). If more than one id is given, the model will be wrapped by nn.DataParallel.
+         loss_mode (bool): if True, the model returns predictions only; if False, it returns `(predictions, auxiliary_loss)` and the auxiliary loss is added to the training loss. Defaults to True.
+         model_path (str): directory in which to save the model (default="./"). Note that only the weights with the best validation auc are saved.
+         regularization_params (dict): optional coefficients for the new RegularizationLoss, with keys:
+             `embedding_l1` (float): L1 regularization coefficient for embedding parameters (default=0.0).
+             `embedding_l2` (float): L2 regularization coefficient for embedding parameters (default=0.0).
+             `dense_l1` / `dense_l2` (float): L1/L2 regularization coefficients for dense parameters (default=0.0).
+     """
+
+     def __init__(
+         self,
+         model,
+         optimizer_fn=torch.optim.Adam,
+         optimizer_params=None,
+         regularization_params=None,
+         scheduler_fn=None,
+         scheduler_params=None,
+         n_epoch=10,
+         earlystop_patience=10,
+         device="cpu",
+         gpus=None,
+         loss_mode=True,
+         model_path="./",
+     ):
+         self.model = model  # for uniform weights save method in one gpu or multi gpu
+         if gpus is None:
+             gpus = []
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.device = torch.device(device)
+         self.model.to(self.device)
+         if optimizer_params is None:
+             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+         if regularization_params is None:
+             regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+         self.loss_mode = loss_mode
+         self.criterion = torch.nn.BCELoss()  # default loss: binary cross-entropy
+         self.evaluate_fn = roc_auc_score  # default evaluate function
+         self.n_epoch = n_epoch
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+         self.model_path = model_path
+         # Initialize regularization loss
+         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+
+     def train_one_epoch(self, data_loader, log_interval=10):
+         self.model.train()
+         total_loss = 0
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for i, (x_dict, y) in enumerate(tk0):
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
+             y = y.to(self.device).float()
+             if self.loss_mode:
+                 y_pred = self.model(x_dict)
+                 loss = self.criterion(y_pred, y)
+             else:
+                 y_pred, other_loss = self.model(x_dict)
+                 loss = self.criterion(y_pred, y) + other_loss
+
+             # Add regularization loss
+             reg_loss = self.reg_loss_fn(self.model)
+             loss = loss + reg_loss
+
+             self.model.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+             total_loss += loss.item()
+             if (i + 1) % log_interval == 0:
+                 tk0.set_postfix(loss=total_loss / log_interval)
+                 total_loss = 0
+
+     def fit(self, train_dataloader, val_dataloader=None):
+         for epoch_i in range(self.n_epoch):
+             print('epoch:', epoch_i)
+             self.train_one_epoch(train_dataloader)
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  # update lr at epoch level via the scheduler
+             if val_dataloader:
+                 auc = self.evaluate(self.model, val_dataloader)
+                 print('epoch:', epoch_i, 'validation: auc:', auc)
+                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                     print(f'validation: best auc: {self.early_stopper.best_auc}')
+                     self.model.load_state_dict(self.early_stopper.best_weights)
+                     break
+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save final weights (best validation auc if early stopping triggered)
+
+     def evaluate(self, model, data_loader):
+         model.eval()
+         targets, predicts = list(), list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 # ensure y is float with shape [batch_size, 1]
+                 y = y.to(self.device).float().view(-1, 1)
+                 if self.loss_mode:
+                     y_pred = model(x_dict)
+                 else:
+                     y_pred, _ = model(x_dict)
+                 targets.extend(y.tolist())
+                 predicts.extend(y_pred.tolist())
+         return self.evaluate_fn(targets, predicts)
+
+     def predict(self, model, data_loader):
+         model.eval()
+         predicts = list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device)
+                 if self.loss_mode:
+                     y_pred = model(x_dict)
+                 else:
+                     y_pred, _ = model(x_dict)
+                 predicts.extend(y_pred.tolist())
+         return predicts
+
+     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+         """Export the trained model to ONNX format.
+
+         This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
+         for deployment. The export is non-invasive and does not modify the model code.
+
+         Args:
+             output_path (str): Path to save the ONNX model file.
+             dummy_input (dict, optional): Example input dict {feature_name: tensor}.
+                 If not provided, dummy inputs will be generated automatically.
+             batch_size (int): Batch size for auto-generated dummy input (default: 2).
+             seq_length (int): Sequence length for SequenceFeature (default: 10).
+             opset_version (int): ONNX opset version (default: 14).
+             dynamic_batch (bool): Enable dynamic batch size (default: True).
+             device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                 If None, defaults to 'cpu' for maximum compatibility.
+             verbose (bool): Print export details (default: False).
+
+         Returns:
+             bool: True if export succeeded, False otherwise.
+
+         Example:
+             >>> trainer = CTRTrainer(model, ...)
+             >>> trainer.fit(train_dl, val_dl)
+             >>> trainer.export_onnx("deepfm.onnx")
+
+             >>> # With custom dummy input
+             >>> dummy = {"user_id": torch.tensor([1, 2]), "item_id": torch.tensor([10, 20])}
+             >>> trainer.export_onnx("model.onnx", dummy_input=dummy)
+
+             >>> # Export on specific device
+             >>> trainer.export_onnx("model.onnx", device="cpu")
+         """
+         from ..utils.onnx_export import ONNXExporter
+
+         # Handle DataParallel wrapped model
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Use provided device or default to 'cpu'
+         export_device = device if device is not None else 'cpu'
+
+         exporter = ONNXExporter(model, device=export_device)
+         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+
+     def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+         """Visualize the model's computation graph.
+
+         This method generates a visual representation of the model architecture,
+         showing layer connections, tensor shapes, and nested module structures.
+         It automatically extracts feature information from the model.
+
+         Parameters
+         ----------
+         input_data : dict, optional
+             Example input dict {feature_name: tensor}.
+             If not provided, dummy inputs will be generated automatically.
+         batch_size : int, default=2
+             Batch size for auto-generated dummy input.
+         seq_length : int, default=10
+             Sequence length for SequenceFeature.
+         depth : int, default=3
+             Visualization depth; higher values show more detail.
+             Set to -1 to show all layers.
+         show_shapes : bool, default=True
+             Whether to display tensor shapes.
+         expand_nested : bool, default=True
+             Whether to expand nested modules.
+         save_path : str, optional
+             Path to save the graph image (.pdf, .svg, .png).
+             If None, displays in Jupyter or opens the system viewer.
+         graph_name : str, default="model"
+             Name for the graph.
+         device : str, optional
+             Device for model execution. If None, defaults to 'cpu'.
+         dpi : int, default=300
+             Resolution in dots per inch for the output image.
+             Higher values produce sharper images suitable for papers.
+         **kwargs : dict
+             Additional arguments passed to torchview.draw_graph().
+
+         Returns
+         -------
+         ComputationGraph
+             A torchview ComputationGraph object.
+
+         Raises
+         ------
+         ImportError
+             If torchview or graphviz is not installed.
+
+         Notes
+         -----
+         Default display behavior:
+             When `save_path` is None (default):
+             - In Jupyter/IPython: automatically displays the graph inline
+             - In a Python script: opens the graph with the system default viewer
+
+         Examples
+         --------
+         >>> trainer = CTRTrainer(model, ...)
+         >>> trainer.fit(train_dl, val_dl)
+         >>>
+         >>> # Auto-display in Jupyter (no save_path needed)
+         >>> trainer.visualization(depth=4)
+         >>>
+         >>> # Save to high-DPI PNG for papers
+         >>> trainer.visualization(save_path="model.png", dpi=300)
+         """
+         from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+         if not TORCHVIEW_AVAILABLE:
+             raise ImportError(
+                 "Visualization requires torchview. "
+                 "Install with: pip install torch-rechub[visualization]\n"
+                 "Also ensure graphviz is installed on your system:\n"
+                 "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                 "  - macOS: brew install graphviz\n"
+                 "  - Windows: choco install graphviz"
+             )
+
+         # Handle DataParallel wrapped model
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Use provided device or default to 'cpu'
+         viz_device = device if device is not None else 'cpu'
+
+         return visualize_model(
+             model,
+             input_data=input_data,
+             batch_size=batch_size,
+             seq_length=seq_length,
+             depth=depth,
+             show_shapes=show_shapes,
+             expand_nested=expand_nested,
+             save_path=save_path,
+             graph_name=graph_name,
+             device=viz_device,
+             dpi=dpi,
+             **kwargs
+         )
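
Two hedged follow-up sketches. First, `RegularizationLoss` itself is not shown in this hunk (it lives in torch_rechub/basic/loss_func.py, +223 -34 in the file list above), but the trainer pins down its contract: it is built from the four coefficients and called as `reg_loss_fn(model)` to produce a scalar that is added to the training loss. The sketch below only illustrates a penalty with that shape; it is not the library's actual implementation, and the name-based embedding/dense split is an assumption:

import torch


def l1_l2_penalty(model, embedding_l1=0.0, embedding_l2=0.0, dense_l1=0.0, dense_l2=0.0):
    # Assumed split: parameters whose name mentions "embed" take the embedding
    # coefficients, all other parameters take the dense coefficients.
    total = 0.0
    for name, p in model.named_parameters():
        l1, l2 = (embedding_l1, embedding_l2) if "embed" in name.lower() else (dense_l1, dense_l2)
        total = total + l1 * p.abs().sum() + l2 * p.pow(2).sum()
    return total  # scalar tensor, added to the criterion loss as in train_one_epoch

Second, a quick sanity check of a file written by `trainer.export_onnx(...)`, assuming the standard onnxruntime package is installed ("model.onnx" is the path from the sketch after the file list):

import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    # with dynamic_batch=True the first dimension shows up as a symbolic name
    print(inp.name, inp.shape, inp.type)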