torch-rechub 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- torch_rechub/basic/tracking.py +198 -0
- torch_rechub/data/__init__.py +0 -0
- torch_rechub/data/convert.py +67 -0
- torch_rechub/data/dataset.py +120 -0
- torch_rechub/trainers/ctr_trainer.py +137 -1
- torch_rechub/trainers/match_trainer.py +136 -1
- torch_rechub/trainers/mtl_trainer.py +146 -1
- torch_rechub/trainers/seq_trainer.py +193 -2
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/onnx_export.py +3 -136
- torch_rechub/utils/visualization.py +271 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/METADATA +68 -49
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/RECORD +15 -9
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.4.dist-info → torch_rechub-0.0.6.dist-info}/licenses/LICENSE +0 -0
torch_rechub/trainers/match_trainer.py

```diff
@@ -39,6 +39,7 @@ class MatchTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
```
```diff
@@ -73,10 +74,13 @@ class MatchTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader, log_interval=10):
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (x_dict, y) in enumerate(tk0):
             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
```
```diff
@@ -114,14 +118,26 @@ class MatchTrainer(object):
             loss.backward()
             self.optimizer.step()
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def fit(self, train_dataloader, val_dataloader=None):
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_mode': self.mode})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+
+            for logger in self._iter_loggers():
+                logger.log_metrics({'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}, step=epoch_i)
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
```
```diff
@@ -130,12 +146,34 @@ class MatchTrainer(object):
             if val_dataloader:
                 auc = self.evaluate(self.model, val_dataloader)
                 print('epoch:', epoch_i, 'validation: auc:', auc)
+
+                for logger in self._iter_loggers():
+                    logger.log_metrics({'val/auc': auc}, step=epoch_i)
+
                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
                     print(f'validation: best auc: {self.early_stopper.best_auc}')
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
```
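The new `model_logger` plumbing is duck-typed: `_iter_loggers()` normalizes a single logger or a list/tuple of loggers, and `fit()` only ever calls `log_hyperparams()`, `log_metrics(..., step=...)`, and `finish()` on each one. A minimal sketch of a compatible logger, assuming only the interface exercised in this hunk (the `PrintLogger` name is hypothetical; the package's own implementations ship in the new `torch_rechub/basic/tracking.py`):

```python
# Minimal sketch of a logger compatible with the duck-typed interface these
# trainers call. The class name is illustrative, not part of torch-rechub.
class PrintLogger:
    def log_hyperparams(self, params):
        # Called once at the start of fit() with a dict of run settings.
        print("hyperparams:", params)

    def log_metrics(self, metrics, step=None):
        # Called once per epoch, e.g. {'train/loss': ..., 'val/auc': ...}.
        print(f"step {step}:", metrics)

    def finish(self):
        # Called after the final model checkpoint is saved.
        print("run finished")


# A single logger or a list/tuple both work, per _iter_loggers() above:
# trainer = MatchTrainer(model, ..., model_logger=PrintLogger())
# trainer = MatchTrainer(model, ..., model_logger=[PrintLogger(), another_logger])
```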
```diff
@@ -237,3 +275,100 @@ class MatchTrainer(object):
         # Restore original mode
         if hasattr(model, 'mode'):
             model.mode = original_mode
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = MatchTrainer(model, ...)
+        >>> trainer.fit(train_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                "  - macOS: brew install graphviz\n"
+                "  - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )
```
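Per the docstring above, `input_data` is an optional `{feature_name: tensor}` dict; when omitted, dummy inputs are generated from the model's feature definitions. A usage sketch with an explicit dict (the feature names and shapes below are hypothetical placeholders for whatever features the model was built with):

```python
import torch

# Hypothetical feature names; substitute the features your model was defined with.
input_data = {
    "user_id": torch.randint(0, 100, (2,)),           # sparse feature, batch_size=2
    "hist_item_id": torch.randint(0, 1000, (2, 10)),  # sequence feature, seq_length=10
}
trainer.visualization(input_data=input_data, depth=-1, save_path="model.svg")
```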
torch_rechub/trainers/mtl_trainer.py

```diff
@@ -47,6 +47,7 @@ class MTLTrainer(object):
         device="cpu",
         gpus=None,
         model_path="./",
+        model_logger=None,
     ):
         self.model = model
         if gpus is None:
@@ -104,6 +105,7 @@ class MTLTrainer(object):
         self.model_path = model_path
         # Initialize regularization loss
         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+        self.model_logger = model_logger
 
     def train_one_epoch(self, data_loader):
         self.model.train()
@@ -163,21 +165,42 @@ class MTLTrainer(object):
     def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
         total_log = []
 
+        # Log hyperparameters once
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self._current_lr(), 'adaptive_method': self.adaptive_method})
+
         for epoch_i in range(self.n_epoch):
             _log_per_epoch = self.train_one_epoch(train_dataloader)
 
+            # Collect metrics
+            logs = {f'train/task_{task_id}_loss': loss_val for task_id, loss_val in enumerate(_log_per_epoch)}
+            lr_value = self._current_lr()
+            if lr_value is not None:
+                logs['learning_rate'] = lr_value
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                 self.scheduler.step()  # update lr in epoch level by scheduler
+
             scores = self.evaluate(self.model, val_dataloader)
             print('epoch:', epoch_i, 'validation scores: ', scores)
 
-            for score in scores:
+            for task_id, score in enumerate(scores):
+                logs[f'val/task_{task_id}_score'] = score
                 _log_per_epoch.append(score)
+            logs['auc'] = scores[self.earlystop_taskid]
+
+            if self.loss_weight:
+                for task_id, weight in enumerate(self.loss_weight):
+                    logs[f'loss_weight/task_{task_id}'] = weight.item()
 
             total_log.append(_log_per_epoch)
 
+            # Log metrics once per epoch
+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
             if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
                 print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
                 self.model.load_state_dict(self.early_stopper.best_weights)
```
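For reference, the `logs` dict assembled in this hunk means a two-task run would emit per-epoch keys along these lines (a sketch with illustrative values, assuming two tasks, `earlystop_taskid=0`, and an adaptive method that sets `loss_weight`):

```python
# Illustrative per-epoch payload passed to logger.log_metrics(logs, step=epoch_i).
logs = {
    'train/task_0_loss': 0.42,   # dict comprehension over _log_per_epoch
    'train/task_1_loss': 0.37,
    'learning_rate': 1e-3,       # present only when _current_lr() is not None
    'val/task_0_score': 0.71,
    'val/task_1_score': 0.68,
    'auc': 0.71,                 # scores[self.earlystop_taskid]
    'loss_weight/task_0': 0.9,   # present only when self.loss_weight is set
    'loss_weight/task_1': 1.1,
}
```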
```diff
@@ -185,8 +208,33 @@ class MTLTrainer(object):
 
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  # save best auc model
 
+        for logger in self._iter_loggers():
+            logger.finish()
+
         return total_log
 
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
+    def _current_lr(self):
+        """Fetch current learning rate regardless of adaptive method."""
+        if self.adaptive_method == "metabalance":
+            return self.share_optimizer.param_groups[0]['lr'] if hasattr(self, 'share_optimizer') else None
+        if hasattr(self, 'optimizer'):
+            return self.optimizer.param_groups[0]['lr']
+        return None
+
     def evaluate(self, model, data_loader):
         model.eval()
         targets, predicts = list(), list()
```
```diff
@@ -257,3 +305,100 @@ class MTLTrainer(object):
 
         exporter = ONNXExporter(model, device=export_device)
         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+
+    def visualization(self, input_data=None, batch_size=2, seq_length=10, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the model architecture,
+        showing layer connections, tensor shapes, and nested module structures.
+        It automatically extracts feature information from the model.
+
+        Parameters
+        ----------
+        input_data : dict, optional
+            Example input dict {feature_name: tensor}.
+            If not provided, dummy inputs will be generated automatically.
+        batch_size : int, default=2
+            Batch size for auto-generated dummy input.
+        seq_length : int, default=10
+            Sequence length for SequenceFeature.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = MTLTrainer(model, task_types=["classification", "classification"])
+        >>> trainer.fit(train_dl, val_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        from ..utils.visualization import TORCHVIEW_AVAILABLE, visualize_model
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                "  - macOS: brew install graphviz\n"
+                "  - Windows: choco install graphviz"
+            )
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        return visualize_model(
+            model,
+            input_data=input_data,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            depth=depth,
+            show_shapes=show_shapes,
+            expand_nested=expand_nested,
+            save_path=save_path,
+            graph_name=graph_name,
+            device=viz_device,
+            dpi=dpi,
+            **kwargs
+        )
```
torch_rechub/trainers/seq_trainer.py

```diff
@@ -46,7 +46,22 @@ class SeqTrainer(object):
     ...     )
     """
 
-    def __init__(
+    def __init__(
+        self,
+        model,
+        optimizer_fn=torch.optim.Adam,
+        optimizer_params=None,
+        scheduler_fn=None,
+        scheduler_params=None,
+        n_epoch=10,
+        earlystop_patience=10,
+        device='cpu',
+        gpus=None,
+        model_path='./',
+        loss_type='cross_entropy',
+        loss_params=None,
+        model_logger=None
+    ):
         self.model = model  # for uniform weights save method in one gpu or multi gpu
         if gpus is None:
             gpus = []
@@ -74,9 +89,11 @@ class SeqTrainer(object):
             loss_params = {"ignore_index": 0}
         self.loss_fn = nn.CrossEntropyLoss(**loss_params)
 
+        self.loss_type = loss_type
         self.n_epoch = n_epoch
         self.early_stopper = EarlyStopper(patience=earlystop_patience)
         self.model_path = model_path
+        self.model_logger = model_logger
 
     def fit(self, train_dataloader, val_dataloader=None):
         """Train the model.
@@ -90,10 +107,18 @@ class SeqTrainer(object):
         """
         history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
 
+        for logger in self._iter_loggers():
+            logger.log_hyperparams({'n_epoch': self.n_epoch, 'learning_rate': self.optimizer.param_groups[0]['lr'], 'loss_type': self.loss_type})
+
         for epoch_i in range(self.n_epoch):
             print('epoch:', epoch_i)
             # Training phase
-            self.train_one_epoch(train_dataloader)
+            train_loss = self.train_one_epoch(train_dataloader)
+            history['train_loss'].append(train_loss)
+
+            # Collect metrics
+            logs = {'train/loss': train_loss, 'learning_rate': self.optimizer.param_groups[0]['lr']}
+
             if self.scheduler is not None:
                 if epoch_i % self.scheduler.step_size == 0:
                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
@@ -105,6 +130,10 @@ class SeqTrainer(object):
                 history['val_loss'].append(val_loss)
                 history['val_accuracy'].append(val_accuracy)
 
+                logs['val/loss'] = val_loss
+                logs['val/accuracy'] = val_accuracy
+                logs['auc'] = val_accuracy  # For compatibility with EarlyStopper
+
                 print(f"epoch: {epoch_i}, validation: loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}")
 
                 # Early stopping
@@ -113,9 +142,30 @@ class SeqTrainer(object):
                     self.model.load_state_dict(self.early_stopper.best_weights)
                     break
 
+            for logger in self._iter_loggers():
+                logger.log_metrics(logs, step=epoch_i)
+
         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best model
+
+        for logger in self._iter_loggers():
+            logger.finish()
+
         return history
 
+    def _iter_loggers(self):
+        """Return logger instances as a list.
+
+        Returns
+        -------
+        list
+            Active logger instances. Empty when ``model_logger`` is ``None``.
+        """
+        if self.model_logger is None:
+            return []
+        if isinstance(self.model_logger, (list, tuple)):
+            return list(self.model_logger)
+        return [self.model_logger]
+
     def train_one_epoch(self, data_loader, log_interval=10):
         """Train the model for a single epoch.
 
```
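Because `train_one_epoch()` now returns the average epoch loss, the `history` dict that `SeqTrainer.fit()` returns carries complete per-epoch curves alongside the logger calls. A usage sketch (values come from whatever run you execute):

```python
history = trainer.fit(train_dl, val_dl)

# history == {'train_loss': [...], 'val_loss': [...], 'val_accuracy': [...]},
# one entry per completed epoch.
best_epoch = max(range(len(history['val_accuracy'])),
                 key=history['val_accuracy'].__getitem__)
print(f"best epoch: {best_epoch}, "
      f"val accuracy: {history['val_accuracy'][best_epoch]:.4f}")
```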
```diff
@@ -128,6 +178,8 @@ class SeqTrainer(object):
         """
         self.model.train()
         total_loss = 0
+        epoch_loss = 0
+        batch_count = 0
         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
         for i, (seq_tokens, seq_positions, seq_time_diffs, targets) in enumerate(tk0):
             # Move tensors to the target device
@@ -152,10 +204,15 @@ class SeqTrainer(object):
             self.optimizer.step()
 
             total_loss += loss.item()
+            epoch_loss += loss.item()
+            batch_count += 1
             if (i + 1) % log_interval == 0:
                 tk0.set_postfix(loss=total_loss / log_interval)
                 total_loss = 0
 
+        # Return average epoch loss
+        return epoch_loss / batch_count if batch_count > 0 else 0
+
     def evaluate(self, data_loader):
         """Evaluate the model on a validation/test data loader.
 
```
```diff
@@ -291,3 +348,137 @@ class SeqTrainer(object):
         except Exception as e:
             warnings.warn(f"ONNX export failed: {str(e)}")
             raise RuntimeError(f"Failed to export ONNX model: {str(e)}") from e
+
+    def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, show_shapes=True, expand_nested=True, save_path=None, graph_name="model", device=None, dpi=300, **kwargs):
+        """Visualize the model's computation graph.
+
+        This method generates a visual representation of the sequence model
+        architecture, showing layer connections, tensor shapes, and nested
+        module structures.
+
+        Parameters
+        ----------
+        seq_length : int, default=50
+            Sequence length for dummy input.
+        vocab_size : int, optional
+            Vocabulary size for generating dummy tokens.
+            If None, will try to get from model.vocab_size or model.item_num.
+        batch_size : int, default=2
+            Batch size for dummy input.
+        depth : int, default=3
+            Visualization depth, higher values show more detail.
+            Set to -1 to show all layers.
+        show_shapes : bool, default=True
+            Whether to display tensor shapes.
+        expand_nested : bool, default=True
+            Whether to expand nested modules.
+        save_path : str, optional
+            Path to save the graph image (.pdf, .svg, .png).
+            If None, displays in Jupyter or opens system viewer.
+        graph_name : str, default="model"
+            Name for the graph.
+        device : str, optional
+            Device for model execution. If None, defaults to 'cpu'.
+        dpi : int, default=300
+            Resolution in dots per inch for output image.
+            Higher values produce sharper images suitable for papers.
+        **kwargs : dict
+            Additional arguments passed to torchview.draw_graph().
+
+        Returns
+        -------
+        ComputationGraph
+            A torchview ComputationGraph object.
+
+        Raises
+        ------
+        ImportError
+            If torchview or graphviz is not installed.
+        ValueError
+            If vocab_size is not provided and cannot be inferred from model.
+
+        Notes
+        -----
+        Default Display Behavior:
+        When `save_path` is None (default):
+        - In Jupyter/IPython: automatically displays the graph inline
+        - In Python script: opens the graph with system default viewer
+
+        Examples
+        --------
+        >>> trainer = SeqTrainer(hstu_model, ...)
+        >>> trainer.fit(train_dl, val_dl)
+        >>>
+        >>> # Auto-display in Jupyter (no save_path needed)
+        >>> trainer.visualization(depth=4, vocab_size=10000)
+        >>>
+        >>> # Save to high-DPI PNG for papers
+        >>> trainer.visualization(save_path="model.png", dpi=300)
+        """
+        try:
+            from torchview import draw_graph
+            TORCHVIEW_AVAILABLE = True
+        except ImportError:
+            TORCHVIEW_AVAILABLE = False
+
+        if not TORCHVIEW_AVAILABLE:
+            raise ImportError(
+                "Visualization requires torchview. "
+                "Install with: pip install torch-rechub[visualization]\n"
+                "Also ensure graphviz is installed on your system:\n"
+                "  - Ubuntu/Debian: sudo apt-get install graphviz\n"
+                "  - macOS: brew install graphviz\n"
+                "  - Windows: choco install graphviz"
+            )
+
+        from ..utils.visualization import _is_jupyter_environment, display_graph
+
+        # Handle DataParallel wrapped model
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+
+        # Use provided device or default to 'cpu'
+        viz_device = device if device is not None else 'cpu'
+
+        # Get vocab_size from model if not provided
+        if vocab_size is None:
+            if hasattr(model, 'vocab_size'):
+                vocab_size = model.vocab_size
+            elif hasattr(model, 'item_num'):
+                vocab_size = model.item_num
+            else:
+                raise ValueError("vocab_size must be provided or model must have "
+                                 "'vocab_size' or 'item_num' attribute")
+
+        # Generate dummy inputs for sequence model
+        dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)
+        dummy_seq_time_diffs = torch.zeros(batch_size, seq_length, dtype=torch.float32, device=viz_device)
+
+        # Move model to device
+        model = model.to(viz_device)
+        model.eval()
+
+        # Call torchview.draw_graph
+        graph = draw_graph(model, input_data=(dummy_seq_tokens, dummy_seq_time_diffs), graph_name=graph_name, depth=depth, device=viz_device, expand_nested=expand_nested, show_shapes=show_shapes, save_graph=False, **kwargs)
+
+        # Set DPI for high-quality output
+        graph.visual_graph.graph_attr['dpi'] = str(dpi)
+
+        # Handle save_path: manually save with DPI applied
+        if save_path:
+            import os
+            directory = os.path.dirname(save_path) or "."
+            filename = os.path.splitext(os.path.basename(save_path))[0]
+            ext = os.path.splitext(save_path)[1].lstrip('.')
+            output_format = ext if ext else 'pdf'
+            if directory != "." and not os.path.exists(directory):
+                os.makedirs(directory, exist_ok=True)
+            graph.visual_graph.render(filename=filename, directory=directory, format=output_format, cleanup=True)
+
+        # Handle default display behavior when save_path is None
+        if save_path is None:
+            if _is_jupyter_environment():
+                display_graph(graph)
+            else:
+                graph.visual_graph.view(cleanup=True)
+
+        return graph
```