eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/torch/trainer.py
ADDED
@@ -0,0 +1,782 @@
import math
import os
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import f1_score
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def clip_grad_norm(model, clip_grad_val=1):
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_val)


class GradNormClipper:

    def __init__(self, clip_val):
        self.clip_val = clip_val

    def __call__(self, model):
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_val)


def f1(output, labels):
    pred_labels = torch.argmax(output, dim=1)

    return f1_score(labels, pred_labels, labels=None, average='weighted', sample_weight=None,
                    zero_division='warn')

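`clip_grad_norm` and `GradNormClipper` are meant to be passed to the trainers below as the `grad_f` gradient hook, which is called with the model between `loss.backward()` and `optimizer.step()`, while `f1` is the default `score_function` of the `Trainer` class. A minimal sketch of the hook contract, assuming a hypothetical placeholder model (not part of the packaged file):

import torch

model = torch.nn.Linear(8, 2)            # hypothetical placeholder model
clipper = GradNormClipper(clip_val=1.0)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
clipper(model)                           # clips the gradient norm in place
# equivalently: clip_grad_norm(model, clip_grad_val=1.0)
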
class Trainer:
    """TODO: add aggressive version."""
    def __init__(self, optimizer, model, loss_fn, grad_f=None, score_function=f1, score_name="f1", score_direction=1, scheduler=None):
        self.optimizer = optimizer
        self.model = model
        self.loss_fn = loss_fn
        self.grad_f = grad_f

        self.score_direction = score_direction

        self.writer = None

        self.score_function = score_function
        self.score_name = score_name

        self.scheduler = scheduler

    def _epoch(self, loader, epoch_index, report_frequency, device="cpu"):
        """
        :param loader:
        :param epoch_index:
        :param report_frequency:
        :param device: device to move tensors to; None to leave them where they are
        :return:
        """
        # Make sure gradient tracking is on, and do a pass over the data
        self.model.train(True)

        running_loss = 0.
        last_loss = 0.
        with tqdm(total=len(loader), desc="Batch") as pbar:
            for i, data in enumerate(loader):
                # Every data instance is an input + label pair
                inputs, labels = data
                if device is not None:
                    if isinstance(inputs, (list, tuple)):
                        inputs = tuple(x.to(device, non_blocking=True) for x in inputs)
                    else:
                        # wrap the single tensor in a tuple so the *inputs call below works
                        inputs = (inputs.to(device, non_blocking=True),)

                    labels = labels.to(device, non_blocking=True)

                # Zero your gradients for every batch
                self.optimizer.zero_grad()

                # Make predictions for this batch
                outputs = self.model(*inputs)

                # Compute the loss and its gradients
                loss = self.loss_fn(outputs, labels)
                loss.backward()

                # clip the gradient
                if self.grad_f is not None:
                    self.grad_f(self.model)

                # Adjust learning weights
                self.optimizer.step()

                if self.scheduler is not None:
                    self.scheduler.step()

                # Gather data and report
                running_loss += loss.item()
                if i % report_frequency == report_frequency - 1:
                    pbar.set_postfix({'Batch': i + 1,
                                      'Last loss': last_loss,
                                      }, refresh=False)
                    pbar.update(report_frequency)

                    last_loss = running_loss / report_frequency  # average loss over the report window
                    tb_x = epoch_index * len(loader) + i + 1
                    self.writer.add_scalar('Loss/train', last_loss, tb_x)
                    running_loss = 0.

        return last_loss

    def _validate(self, validation_loader, device, sample_out=False):

        self.model.train(False)

        running_vloss = 0.0
        running_score = 0.0

        pred_validation = []
        label_validation = []

        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata

            if device is not None:
                if isinstance(vinputs, (list, tuple)):
                    vinputs = tuple(x.to(device, non_blocking=True) for x in vinputs)
                else:
                    vinputs = (vinputs.to(device, non_blocking=True),)

                vlabels = vlabels.to(device, non_blocking=True)

            voutputs = self.model(*vinputs)
            vloss = self.loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

            pred_validation.append(voutputs.cpu().detach().numpy())
            label_validation.append(vlabels.cpu().detach().numpy())

            vscore = self.score_function(voutputs.cpu(), vlabels.cpu())
            running_score += vscore

        avg_vloss = running_vloss / (i + 1)
        avg_score = running_score / (i + 1)
        if not sample_out:
            return avg_vloss, avg_score
        else:
            return avg_vloss, avg_score, np.concatenate(label_validation, axis=0), np.concatenate(pred_validation, axis=0)

    def train(self, epochs, training_loader, test_loader, validation_loader=None, report_per_epoch=10,
              writer_base_path="runs", model_base_path=".", model_tag="model", device="cpu", validation_path=None):
        # Initializing in a separate cell so we can easily add more epochs to the same run
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")

        model_name = f"{model_tag}_{timestamp}"
        base_dir = f"{model_base_path}/{model_name}"
        os.mkdir(base_dir)

        n_batch = len(training_loader)

        report_frequency = math.ceil(n_batch / report_per_epoch)

        best_score_epoch = 0
        if self.score_direction == -1:
            best_score = 1_000_000
        else:
            best_score = 0

        best_vloss = 1_000_000.
        best_vloss_epoch = 0

        best_vloss_val = 1_000_000.
        best_vloss_val_epoch = 0

        best_epoch = 0

        model_path = None

        with tqdm(total=epochs, desc='Epoch') as pbar:
            for epoch in range(epochs):

                avg_loss = self._epoch(training_loader, epoch, report_frequency, device)

                # We don't need gradients on to do reporting
                avg_vloss, avg_score = self._validate(test_loader, device)

                if validation_loader is not None:
                    avg_vloss_val, avg_score, sample_label, sample_output = self._validate(validation_loader, device, sample_out=True)

                # Log the running loss averaged per batch
                # for both training and validation
                loss_scalars = {'Training': avg_loss, 'test': avg_vloss}
                if validation_loader is not None:
                    loss_scalars["validation"] = avg_vloss_val
                self.writer.add_scalars('Training vs. Validation Loss', loss_scalars, epoch + 1)

                self.writer.add_scalars(f'Weighted avg {self.score_name}',
                                        {f'Weighted avg {self.score_name}': avg_score},
                                        epoch + 1)

                # todo: score for all validation batches at once

                self.writer.flush()

                # Track the best performance, and save the model's state
                if avg_vloss < best_vloss:
                    best_vloss = avg_vloss
                    best_vloss_epoch = epoch + 1
                    best_epoch = best_vloss_epoch
                    best_metric = "loss"
                    model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}'
                    torch.save(self.model.state_dict(), model_path)

                if validation_loader is not None and avg_vloss_val < best_vloss_val:
                    best_vloss_val = avg_vloss_val
                    best_vloss_val_epoch = epoch + 1

                    if validation_path is not None:
                        np.savetxt(f"{validation_path}_label_path.csv", sample_label, delimiter=",")
                        np.savetxt(f"{validation_path}_output_path.csv", sample_output, delimiter=",")

                if self.score_direction * avg_score > self.score_direction * best_score:
                    best_score = avg_score
                    best_score_epoch = epoch + 1
                    best_epoch = best_score_epoch
                    best_metric = self.score_name
                    model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_score_epoch}'
                    torch.save(self.model.state_dict(), model_path)

                pbar.set_postfix({f'best {self.score_name} epoch': best_score_epoch,
                                  f'best {self.score_name}': best_score,
                                  f'current {self.score_name}': avg_score,
                                  'best avg loss epoch': best_vloss_epoch,
                                  'best avg loss': best_vloss,
                                  'current avg loss': avg_vloss,
                                  'best val loss': best_vloss_val,
                                  'best val loss epoch': best_vloss_val_epoch}, refresh=False)
                pbar.update(1)

        # load best model
        self.model.load_state_dict(torch.load(model_path))
        # switch off training
        self.model.train(False)
        # jit the model for inference

        vinputs, _ = next(iter(test_loader))

        if device is not None:
            if isinstance(vinputs, (list, tuple)):
                vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of inputs
            else:
                vinputs = vinputs.to(device)

        # switch off gradient
        # todo: update pytorch
        # torch.jit.enable_onednn_fusion(True)
        with torch.inference_mode():
            # TorchScript is the recommended model format for scaled inference and deployment
            model_scripted = torch.jit.trace(self.model, example_inputs=vinputs)
            model_scripted = torch.jit.freeze(model_scripted)

            model_path = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
            model_scripted.save(model_path)

        return base_dir, model_path, model_name

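A minimal sketch of driving `Trainer` end to end. The toy dataset, model, optimizer, and loaders here are hypothetical placeholders; the only assumptions taken from the code above are that the loaders yield `(inputs, labels)` pairs and that `train` returns `(base_dir, jitted_model_path, model_name)`:

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy data: 64 samples, 8 features, 2 classes (hypothetical)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset, batch_size=16)

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(optimizer, model, torch.nn.CrossEntropyLoss(),
                  grad_f=GradNormClipper(1.0))   # weighted f1 is the default score
base_dir, jit_path, name = trainer.train(epochs=2,
                                         training_loader=train_loader,
                                         test_loader=test_loader,
                                         device="cpu")
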
def train_one_epoch(loader, optimizer, scheduler, model, loss_fn, epoch_index, tb_writer, report_frequency, grad_f=None,
                    device="cpu"):
    """
    TODO: add aggressive version
    :param loader:
    :param optimizer:
    :param model:
    :param loss_fn:
    :param epoch_index:
    :param tb_writer:
    :param report_frequency:
    :param grad_f:
    :param device: device to move tensors to; None to leave them where they are
    :return:
    """

    running_loss = 0.
    last_loss = 0.
    with tqdm(total=len(loader), desc="Batch") as pbar:
        for i, data in enumerate(loader):
            # Every data instance is an input + label pair
            inputs, labels = data

            if device is not None:
                if isinstance(inputs, (list, tuple)):
                    inputs = tuple(x.to(device, non_blocking=True) for x in inputs)
                else:
                    inputs = (inputs.to(device, non_blocking=True),)  # wrap in a tuple to match the list case

                labels = labels.to(device, non_blocking=True)
            # Zero your gradients for every batch!
            optimizer.zero_grad()

            # Make predictions for this batch
            outputs = model(*inputs)

            del inputs

            # Compute the loss and its gradients
            loss = loss_fn(outputs, labels)
            loss.backward()

            # clip the gradient
            if grad_f is not None:
                grad_f(model)

            # Adjust learning weights
            optimizer.step()

            if scheduler is not None:
                scheduler.step()
            # Gather data and report
            running_loss += loss.item()
            if i % report_frequency == report_frequency - 1:
                pbar.set_postfix({'Batch': i + 1,
                                  'Last loss': last_loss,
                                  }, refresh=False)
                pbar.update(report_frequency)

                last_loss = running_loss / report_frequency  # average loss over the report window
                tb_x = epoch_index * len(loader) + i + 1
                tb_writer.add_scalar('Loss/train', last_loss, tb_x)
                running_loss = 0.

    return last_loss

# PyTorch TensorBoard support

def train(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
          writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
    # Initializing in a separate cell so we can easily add more epochs to the same run
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")

    model_name = f"{model_tag}_{timestamp}"
    base_dir = f"{model_base_path}/{model_name}"
    os.mkdir(base_dir)

    n_batch = len(training_loader)

    report_frequency = math.ceil(n_batch / report_per_epoch)

    best_vloss = 1_000_000.
    model_path = None

    for epoch in range(epochs):

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        avg_loss = train_one_epoch(training_loader, optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)

        # We don't need gradients on to do reporting
        model.train(False)

        running_vloss = 0.0
        i = 0
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata

            if device is not None:
                if isinstance(vinputs, (list, tuple)):
                    vinputs = tuple(x.to(device, non_blocking=True) for x in vinputs)
                else:
                    vinputs = (vinputs.to(device, non_blocking=True),)  # wrap in a tuple to match the list case

                vlabels = vlabels.to(device, non_blocking=True)

            voutputs = model(*vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

        avg_vloss = running_vloss / (i + 1)

        # Log the running loss averaged per batch
        # for both training and validation
        writer.add_scalars('Training vs. Validation Loss',
                           {'Training': avg_loss, 'Validation': avg_vloss},
                           epoch + 1)
        writer.flush()

        # Track the best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            best_metric = "loss"
            model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{epoch}'
            torch.save(model.state_dict(), model_path)

    epoch += 1

    # load best model and export a TorchScript version for inference
    model.load_state_dict(torch.load(model_path))
    model.train(False)
    # TorchScript is the recommended model format for scaled inference and deployment
    model_scripted = torch.jit.script(model)
    model_path = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{epoch}.pt'
    model_scripted.save(model_path)

    return base_dir, model_path, model_name

def train_labeling(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
                   writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
    # Initializing in a separate cell so we can easily add more epochs to the same run
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")

    model_name = f"{model_tag}_{timestamp}"
    base_dir = f"{model_base_path}/{model_name}"
    os.makedirs(base_dir, exist_ok=True)

    n_batch = len(training_loader)

    report_frequency = math.ceil(n_batch / report_per_epoch)

    best_f1_epoch = 0
    best_f1 = 0.

    best_vloss = 1_000_000.
    best_vloss_epoch = 0

    best_epoch = 0

    model_path = None

    with tqdm(total=epochs, desc='Epoch') as pbar:
        for epoch in range(epochs):

            # Make sure gradient tracking is on, and do a pass over the data
            model.train(True)
            avg_loss = train_one_epoch(training_loader, optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)

            # We don't need gradients on to do reporting
            model.train(False)

            running_vloss = 0.0
            running_vf1 = 0.0
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata

                if device is not None:
                    if isinstance(vinputs, (list, tuple)):
                        vinputs = tuple(x.to(device, non_blocking=True) for x in vinputs)
                    else:
                        vinputs = (vinputs.to(device, non_blocking=True),)  # wrap in a tuple to match the list case

                    vlabels = vlabels.to(device, non_blocking=True)

                voutputs = model(*vinputs)

                del vinputs

                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss.item()

                vf1 = f1(voutputs.cpu(), vlabels.cpu())
                running_vf1 += vf1

            avg_vloss = running_vloss / (i + 1)
            avg_f1 = running_vf1 / (i + 1)

            # Log the running loss averaged per batch
            # for both training and validation
            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               epoch + 1)
            writer.add_scalars('Weighted avg f1',
                               {'Weighted avg f1': avg_f1},
                               epoch + 1)

            # todo: f1 for all validation batches at once

            writer.flush()

            # Track the best performance, and save the model's state
            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                best_vloss_epoch = epoch + 1
                best_epoch = best_vloss_epoch
                best_metric = "loss"
                model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}_{int(1000*best_vloss)}'
                torch.save(model.state_dict(), model_path)

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_f1_epoch = epoch + 1
                best_epoch = best_f1_epoch
                best_metric = "f1"
                model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_f1_epoch}_{int(100*best_f1)}'
                torch.save(model.state_dict(), model_path)

            pbar.set_postfix({'best f1 epoch': best_f1_epoch,
                              'best f1': best_f1,
                              'current f1': avg_f1,
                              'best avg loss epoch': best_vloss_epoch,
                              'best avg loss': best_vloss,
                              'current avg loss': avg_vloss}, refresh=False)
            pbar.update(1)

    # load best model
    model.load_state_dict(torch.load(model_path))
    # switch off training
    model.train(False)
    # jit the model for inference

    vinputs, _ = next(iter(validation_loader))

    if device is not None:
        if isinstance(vinputs, (list, tuple)):
            vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of inputs
        else:
            vinputs = vinputs.to(device)

    # switch off gradient
    # todo: update pytorch
    # torch.jit.enable_onednn_fusion(True)
    with torch.inference_mode():
        # TorchScript is the recommended model format for scaled inference and deployment
        model_scripted = torch.jit.trace(model, example_inputs=vinputs)
        model_scripted = torch.jit.freeze(model_scripted)

        model_path_jitted = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
        model_scripted.save(model_path_jitted)

    return base_dir, model_path, model_path_jitted, model_name

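A minimal sketch of calling the function-style API above; the model, optimizer, loss, and loaders are hypothetical placeholders, and `scheduler` may be `None` because it is only stepped (once per batch) when provided:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))   # hypothetical toy data
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset, batch_size=16)

model = torch.nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

base_dir, ckpt_path, jit_path, name = train_labeling(
    epochs=2, model=model, optimizer=optimizer,
    loss_fn=torch.nn.CrossEntropyLoss(), scheduler=None,
    training_loader=train_loader, validation_loader=valid_loader,
    grad_f=clip_grad_norm,   # plain-function form of the gradient clipper
    device="cpu")
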
def agressive_train_labeling(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
                             writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
    """
    Same as train_labeling, but restarts the asynchronous loaders more aggressively for better performance.
    """

    # Initializing in a separate cell so we can easily add more epochs to the same run
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")

    model_name = f"{model_tag}_{timestamp}"
    base_dir = f"{model_base_path}/{model_name}"
    os.makedirs(base_dir, exist_ok=True)

    n_batch = len(training_loader)

    report_frequency = math.ceil(n_batch / report_per_epoch)

    best_f1_epoch = 0
    best_f1 = 0.

    best_vloss = 1_000_000.
    best_vloss_epoch = 0

    best_epoch = 0

    model_path = None

    # start the asynchronous loaders
    if epochs > 0:
        train_iter = iter(training_loader)
        valid_iter = iter(validation_loader)

    with tqdm(total=epochs, desc='Epoch') as pbar:
        for epoch in range(epochs):

            # Make sure gradient tracking is on, and do a pass over the data
            model.train(True)
            avg_loss = agressive_train_one_epoch(train_iter, len(training_loader), optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)

            if epoch < epochs - 1:
                train_iter = iter(training_loader)

            # We don't need gradients on to do reporting
            model.train(False)

            running_vloss = 0.0
            running_vf1 = 0.0
            for i, vdata in enumerate(valid_iter):
                vinputs, vlabels = vdata

                if device is not None:
                    if isinstance(vinputs, (list, tuple)):
                        vinputs = tuple(x.to(device, non_blocking=True) for x in vinputs)
                    else:
                        vinputs = (vinputs.to(device, non_blocking=True),)  # wrap in a tuple to match the list case
                    vlabels = vlabels.to(device, non_blocking=True)

                voutputs = model(*vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss.item()

                del vinputs

                vf1 = f1(voutputs.cpu(), vlabels.cpu())
                running_vf1 += vf1

            if epoch < epochs - 1:
                valid_iter = iter(validation_loader)

            avg_vloss = running_vloss / (i + 1)
            avg_f1 = running_vf1 / (i + 1)

            # Log the running loss averaged per batch
            # for both training and validation
            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               epoch + 1)
            writer.add_scalars('Weighted avg f1',
                               {'Weighted avg f1': avg_f1},
                               epoch + 1)

            # todo: f1 for all validation batches at once

            writer.flush()

            # Track the best performance, and save the model's state
            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                best_vloss_epoch = epoch + 1
                best_epoch = best_vloss_epoch
                best_metric = "loss"
                model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}_{int(1000*best_vloss)}'
                torch.save(model.state_dict(), model_path)

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_f1_epoch = epoch + 1
                best_epoch = best_f1_epoch
                best_metric = "f1"
                model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_f1_epoch}_{int(100*best_f1)}'
                torch.save(model.state_dict(), model_path)

            pbar.set_postfix({'best f1 epoch': best_f1_epoch,
                              'best f1': best_f1,
                              'current f1': avg_f1,
                              'best avg loss epoch': best_vloss_epoch,
                              'best avg loss': best_vloss,
                              'current avg loss': avg_vloss}, refresh=False)
            pbar.update(1)

    # load best model
    model.load_state_dict(torch.load(model_path))
    # switch off training
    model.train(False)
    # jit the model for inference

    vinputs, _ = next(iter(validation_loader))

    if device is not None:
        if isinstance(vinputs, (list, tuple)):
            vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of inputs
        else:
            vinputs = vinputs.to(device)

    # switch off gradient
    # todo: update pytorch
    # torch.jit.enable_onednn_fusion(True)
    with torch.inference_mode():
        # TorchScript is the recommended model format for scaled inference and deployment
        model_scripted = torch.jit.trace(model, example_inputs=vinputs)
        model_scripted = torch.jit.freeze(model_scripted)

        model_path_jitted = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
        model_scripted.save(model_path_jitted)

    return base_dir, model_path, model_path_jitted, model_name

def agressive_train_one_epoch(loader_iter, loader_lenght, optimizer, scheduler, model, loss_fn, epoch_index, tb_writer, report_frequency, grad_f=None,
                              device="cpu"):
    """
    :param loader_iter:
    :param loader_lenght:
    :param optimizer:
    :param model:
    :param loss_fn:
    :param epoch_index:
    :param tb_writer:
    :param report_frequency:
    :param grad_f:
    :param device: device to move tensors to; None to leave them where they are
    :return:
    """

    running_loss = 0.
    last_loss = 0.
    with tqdm(total=loader_lenght, desc="Batch") as pbar:
        for i, data in enumerate(loader_iter):
            # Every data instance is an input + label pair
            inputs, labels = data

            if device is not None:
                if isinstance(inputs, (list, tuple)):
                    inputs = tuple(x.to(device, non_blocking=True) for x in inputs)
                else:
                    inputs = (inputs.to(device, non_blocking=True),)  # wrap in a tuple to match the list case

                labels = labels.to(device, non_blocking=True)
            # Zero your gradients for every batch!
            optimizer.zero_grad()

            # Make predictions for this batch
            outputs = model(*inputs)

            # Compute the loss and its gradients
            loss = loss_fn(outputs, labels)
            loss.backward()

            # clip the gradient
            if grad_f is not None:
                grad_f(model)

            # Adjust learning weights
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # Gather data and report
            running_loss += loss.item()

            if i % report_frequency == report_frequency - 1:
                pbar.set_postfix({'Batch': i + 1,
                                  'Last loss': last_loss,
                                  }, refresh=False)
                pbar.update(report_frequency)

                last_loss = running_loss / report_frequency  # average loss over the report window
                tb_x = epoch_index * loader_lenght + i + 1
                tb_writer.add_scalar('Loss/train', last_loss, tb_x)
                running_loss = 0.

    return last_loss
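
`agressive_train_labeling` takes the same arguments as `train_labeling`; the difference is that it drives the epoch loop through explicit loader iterators and rearms them one epoch early, so that a multi-worker `DataLoader` can begin prefetching the next pass while the current one is still being scored. A sketch of that iterator-recycling idea in isolation, with a hypothetical loader:

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.zeros(32, dtype=torch.long)),
                    batch_size=8)   # hypothetical loader

epochs = 3
train_it = iter(loader)
for epoch in range(epochs):
    for inputs, labels in train_it:   # training pass for this epoch
        pass                          # ... forward/backward would go here ...
    if epoch < epochs - 1:
        train_it = iter(loader)       # rearm now: workers can prefetch the next epoch
    # ... validation pass runs here, overlapping with the prefetch ...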