eoml 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/torch/trainer.py ADDED
@@ -0,0 +1,782 @@
+ import math
+ import os
+ from datetime import datetime
+
+ import numpy as np
+ import torch
+ from sklearn.metrics import f1_score
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+
+
+ def clip_grad_norm(model, clip_grad_val=1):
+     torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_val)
+
+
+ class GradNormClipper:
+
+     def __init__(self, clip_val):
+         self.clip_val = clip_val
+
+     def __call__(self, model):
+         torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_val)
+
+
+ def f1(output, labels):
+     pred_labels = torch.argmax(output, dim=1)
+
+     return f1_score(labels, pred_labels, labels=None, average='weighted', sample_weight=None,
+                     zero_division='warn')
+
+
+ class Trainer:
+     """TODO: add an aggressive version."""
+
+     def __init__(self, optimizer, model, loss_fn, grad_f=None, score_function=f1, score_name="f1",
+                  score_direction=1, scheduler=None):
+         self.optimizer = optimizer
+         self.model = model
+         self.loss_fn = loss_fn
+         self.grad_f = grad_f
+
+         self.score_direction = score_direction
+
+         self.writer = None
+
+         self.score_function = score_function
+         self.score_name = score_name
+
+         self.scheduler = scheduler
+
+     def _epoch(self, loader, epoch_index, report_frequency, device="cpu"):
+         """
+         Run a single training epoch.
+
+         :param loader: training data loader
+         :param epoch_index: zero-based index of the current epoch
+         :param report_frequency: number of batches between progress/TensorBoard reports
+         :param device: device to move tensors to; None to leave them where they are
+         :return: average loss over the last reporting window
+         """
+
+         # Make sure gradient tracking is on, and do a pass over the data
+         self.model.train(True)
+
+         running_loss = 0.
+         last_loss = 0.
+         with tqdm(total=len(loader), desc="Batch") as pbar:
+             for i, data in enumerate(loader):
+                 # Every data instance is an input + label pair
+                 inputs, labels = data
+                 if device is not None:
+                     if isinstance(inputs, (list, tuple)):
+                         inputs = tuple(map(lambda x: x.to(device, non_blocking=True), inputs))
+                     else:
+                         # wrap single tensors in a tuple so model(*inputs) works in both cases
+                         inputs = (inputs.to(device, non_blocking=True),)
+
+                     labels = labels.to(device, non_blocking=True)
+
+                 # Zero your gradients for every batch
+                 self.optimizer.zero_grad()
+
+                 # Make predictions for this batch
+                 outputs = self.model(*inputs)
+
+                 # Compute the loss and its gradients
+                 loss = self.loss_fn(outputs, labels)
+                 loss.backward()
+
+                 # clip the gradient
+                 if self.grad_f is not None:
+                     self.grad_f(self.model)
+
+                 # Adjust learning weights
+                 self.optimizer.step()
+
+                 if self.scheduler is not None:
+                     self.scheduler.step()
+
+                 # Gather data and report
+                 running_loss += loss.item()
+                 if i % report_frequency == report_frequency - 1:
+                     last_loss = running_loss / report_frequency  # average loss over the reporting window
+                     pbar.set_postfix({'Batch': i + 1,
+                                       'Last loss': last_loss,
+                                       }, refresh=False)
+                     pbar.update(report_frequency)
+
+                     tb_x = epoch_index * len(loader) + i + 1
+                     self.writer.add_scalar('Loss/train', last_loss, tb_x)
+                     running_loss = 0.
+
+         return last_loss
+
+     def _validate(self, validation_loader, device, sample_out=False):
+         """Evaluate the model on validation_loader and return the average loss and score."""
+         self.model.train(False)
+
+         running_vloss = 0.0
+         running_score = 0.0
+
+         pred_validation = []
+         label_validation = []
+
+         for i, vdata in enumerate(validation_loader):
+             vinputs, vlabels = vdata
+
+             if device is not None:
+                 if isinstance(vinputs, (list, tuple)):
+                     vinputs = tuple(map(lambda x: x.to(device, non_blocking=True), vinputs))
+                 else:
+                     # wrap single tensors in a tuple so model(*vinputs) works in both cases
+                     vinputs = (vinputs.to(device, non_blocking=True),)
+
+                 vlabels = vlabels.to(device, non_blocking=True)
+
+             voutputs = self.model(*vinputs)
+             vloss = self.loss_fn(voutputs, vlabels)
+             running_vloss += vloss.item()
+
+             pred_validation.append(voutputs.cpu().detach().numpy())
+             label_validation.append(vlabels.cpu().detach().numpy())
+
+             vscore = self.score_function(voutputs.cpu(), vlabels.cpu())
+             running_score += vscore
+
+         avg_vloss = running_vloss / (i + 1)
+         avg_score = running_score / (i + 1)
+
+         if not sample_out:
+             return avg_vloss, avg_score
+         else:
+             return avg_vloss, avg_score, np.concatenate(label_validation, axis=0), np.concatenate(pred_validation, axis=0)
+
+     def train(self, epochs, training_loader, test_loader, validation_loader=None, report_per_epoch=10,
+               writer_base_path="runs", model_base_path=".", model_tag="model", device="cpu", validation_path=None):
+         # Timestamp the run so repeated runs get their own TensorBoard logs and model directory
+         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+         self.writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")
+
+         model_name = f"{model_tag}_{timestamp}"
+         base_dir = f"{model_base_path}/{model_name}"
+         os.mkdir(base_dir)
+
+         n_batch = len(training_loader)
+
+         report_frequency = math.ceil(n_batch / report_per_epoch)
+
+         best_score_epoch = 0
+         if self.score_direction == -1:
+             best_score = 1_000_000
+         else:
+             best_score = 0
+
+         best_vloss = 1_000_000.
+         best_vloss_epoch = 0
+
+         best_vloss_val = 1_000_000.
+         best_vloss_val_epoch = 0
+
+         best_epoch = 0
+         best_metric = "loss"
+
+         model_path = None
+
+         with tqdm(total=epochs, desc='Epoch') as pbar:
+             for epoch in range(epochs):
+
+                 avg_loss = self._epoch(training_loader, epoch, report_frequency, device)
+
+                 # Gradients are not needed for reporting
+                 avg_vloss, avg_score = self._validate(test_loader, device)
+
+                 if validation_loader is not None:
+                     avg_vloss_val, avg_score, sample_label, sample_output = self._validate(validation_loader, device, sample_out=True)
+
+                 # Log the running loss averaged per batch
+                 # for both training and validation
+                 loss_scalars = {'Training': avg_loss, 'test': avg_vloss}
+                 if validation_loader is not None:
+                     loss_scalars['validation'] = avg_vloss_val
+                 self.writer.add_scalars('Training vs. Validation Loss', loss_scalars, epoch + 1)
+
+                 self.writer.add_scalars(f'Weighted avg {self.score_name}',
+                                         {f'Weighted avg {self.score_name}': avg_score},
+                                         epoch + 1)
+
+                 # todo: compute the score over all validation batches at once
+
+                 self.writer.flush()
+
+                 # Track the best performance, and save the model's state
+                 if avg_vloss < best_vloss:
+                     best_vloss = avg_vloss
+                     best_vloss_epoch = epoch + 1
+                     best_epoch = best_vloss_epoch
+                     best_metric = "loss"
+                     model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}'
+                     torch.save(self.model.state_dict(), model_path)
+
+                 if validation_loader is not None and avg_vloss_val < best_vloss_val:
+                     best_vloss_val = avg_vloss_val
+                     best_vloss_val_epoch = epoch + 1
+
+                     if validation_path is not None:
+                         np.savetxt(f"{validation_path}_label_path.csv", sample_label, delimiter=",")
+                         np.savetxt(f"{validation_path}_output_path.csv", sample_output, delimiter=",")
+
+                 if self.score_direction * avg_score > self.score_direction * best_score:
+                     best_score = avg_score
+                     best_score_epoch = epoch + 1
+                     best_epoch = best_score_epoch
+                     best_metric = self.score_name
+                     model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_score_epoch}'
+                     torch.save(self.model.state_dict(), model_path)
+
+                 pbar.set_postfix({f'best {self.score_name} epoch': best_score_epoch,
+                                   f'best {self.score_name}': best_score,
+                                   f'current {self.score_name}': avg_score,
+                                   'best avg loss epoch': best_vloss_epoch,
+                                   'best avg loss': best_vloss,
+                                   'current avg loss': avg_vloss,
+                                   'best val loss': best_vloss_val,
+                                   'best val loss epoch': best_vloss_val_epoch}, refresh=False)
+                 pbar.update(1)
+
+         # load the best model
+         self.model.load_state_dict(torch.load(model_path))
+         # switch off training
+         self.model.train(False)
+         # jit the model for inference
+
+         vinputs, _ = next(iter(test_loader))
+
+         if device is not None:
+             if isinstance(vinputs, (list, tuple)):
+                 vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of example inputs
+             else:
+                 vinputs = vinputs.to(device)
+
+         # switch off gradient
+         # todo: update pytorch
+         # torch.jit.enable_onednn_fusion(True)
+         with torch.inference_mode():
+             # model_scripted = torch.jit.script(self.model)  # TorchScript is the recommended format for scaled inference and deployment
+             model_scripted = torch.jit.trace(self.model, example_inputs=vinputs)
+             model_scripted = torch.jit.freeze(model_scripted)
+
+             model_path = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
+             model_scripted.save(model_path)
+
+         return base_dir, model_path, model_name
+
+
+ def train_one_epoch(loader, optimizer, scheduler, model, loss_fn, epoch_index, tb_writer, report_frequency, grad_f=None,
+                     device="cpu"):
+     """
+     Run a single training epoch. TODO: add an aggressive version.
+
+     :param loader: training data loader
+     :param optimizer: optimizer to step after each batch
+     :param scheduler: optional learning-rate scheduler, stepped after each batch
+     :param model: model to train
+     :param loss_fn: loss function
+     :param epoch_index: zero-based index of the current epoch
+     :param tb_writer: TensorBoard SummaryWriter
+     :param report_frequency: number of batches between progress/TensorBoard reports
+     :param grad_f: optional callable applied to the model after backward (e.g. gradient clipping)
+     :param device: device to move tensors to; None to leave them where they are
+     :return: average loss over the last reporting window
+     """
+
+     running_loss = 0.
+     last_loss = 0.
+     with tqdm(total=len(loader), desc="Batch") as pbar:
+         for i, data in enumerate(loader):
+             # Every data instance is an input + label pair
+             inputs, labels = data
+
+             if device is not None:
+                 if isinstance(inputs, (list, tuple)):
+                     inputs = tuple(map(lambda x: x.to(device, non_blocking=True), inputs))
+                 else:
+                     inputs = (inputs.to(device, non_blocking=True),)  # create a tuple to match the list case
+
+                 labels = labels.to(device, non_blocking=True)
+             # Zero your gradients for every batch!
+             optimizer.zero_grad()
+
+             # Make predictions for this batch
+             outputs = model(*inputs)
+
+             del inputs
+
+             # Compute the loss and its gradients
+             loss = loss_fn(outputs, labels)
+             loss.backward()
+
+             # clip the gradient
+             if grad_f is not None:
+                 grad_f(model)
+
+             # Adjust learning weights
+             optimizer.step()
+
+             if scheduler is not None:
+                 scheduler.step()
+             # Gather data and report
+             running_loss += loss.item()
+             if i % report_frequency == report_frequency - 1:
+                 last_loss = running_loss / report_frequency  # average loss over the reporting window
+                 pbar.set_postfix({'Batch': i + 1,
+                                   'Last loss': last_loss,
+                                   }, refresh=False)
+                 pbar.update(report_frequency)
+
+                 tb_x = epoch_index * len(loader) + i + 1
+                 tb_writer.add_scalar('Loss/train', last_loss, tb_x)
+                 running_loss = 0.
+
+     return last_loss
+
+
+ def train(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
+           writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
+     # Timestamp the run so repeated runs get their own TensorBoard logs and model directory
+     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+     writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")
+
+     model_name = f"{model_tag}_{timestamp}"
+     base_dir = f"{model_base_path}/{model_name}"
+     os.mkdir(base_dir)
+
+     n_batch = len(training_loader)
+
+     report_frequency = math.ceil(n_batch / report_per_epoch)
+
+     best_vloss = 1_000_000.
+     best_metric = "loss"
+     model_path = None
+
+     for epoch in range(epochs):
+
+         # Make sure gradient tracking is on, and do a pass over the data
+         model.train(True)
+         avg_loss = train_one_epoch(training_loader, optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)
+
+         # Gradients are not needed for reporting
+         model.train(False)
+
+         running_vloss = 0.0
+         i = 0
+         for i, vdata in enumerate(validation_loader):
+             vinputs, vlabels = vdata
+
+             if device is not None:
+                 if isinstance(vinputs, (list, tuple)):
+                     vinputs = tuple(map(lambda x: x.to(device, non_blocking=True), vinputs))
+                 else:
+                     vinputs = (vinputs.to(device, non_blocking=True),)
+
+                 vlabels = vlabels.to(device, non_blocking=True)
+
+             voutputs = model(*vinputs)
+             vloss = loss_fn(voutputs, vlabels)
+             running_vloss += vloss.item()
+
+         avg_vloss = running_vloss / (i + 1)
+
+         # Log the running loss averaged per batch
+         # for both training and validation
+         writer.add_scalars('Training vs. Validation Loss',
+                            {'Training': avg_loss, 'Validation': avg_vloss},
+                            epoch + 1)
+         writer.flush()
+
+         # Track the best performance, and save the model's state
+         if avg_vloss < best_vloss:
+             best_vloss = avg_vloss
+             best_metric = "loss"
+             model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{epoch}'
+             torch.save(model.state_dict(), model_path)
+
+     epoch += 1
+
+     model.load_state_dict(torch.load(model_path))
+     model.train(False)
+     # jit the model for inference
+     model_scripted = torch.jit.script(
+         model)  # TorchScript is the recommended model format for scaled inference and deployment
+     model_path = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{epoch}.pt'
+     model_scripted.save(model_path)
+
+     return base_dir, model_path, model_name
+
+
+ def train_labeling(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
+                    writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
+     """Train a labeling (classification) model, tracking both validation loss and weighted F1."""
+     # Timestamp the run so repeated runs get their own TensorBoard logs and model directory
+     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+     writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")
+
+     model_name = f"{model_tag}_{timestamp}"
+     base_dir = f"{model_base_path}/{model_name}"
+     os.makedirs(base_dir, exist_ok=True)
+
+     n_batch = len(training_loader)
+
+     report_frequency = math.ceil(n_batch / report_per_epoch)
+
+     best_f1_epoch = 0
+     best_f1 = 0.
+
+     best_vloss = 1_000_000.
+     best_vloss_epoch = 0
+
+     best_epoch = 0
+     best_metric = "loss"
+
+     model_path = None
+
+     with tqdm(total=epochs, desc='Epoch') as pbar:
+         for epoch in range(epochs):
+
+             # Make sure gradient tracking is on, and do a pass over the data
+             model.train(True)
+             avg_loss = train_one_epoch(training_loader, optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)
+
+             # Gradients are not needed for reporting
+             model.train(False)
+
+             running_vloss = 0.0
+             running_vf1 = 0.0
+             for i, vdata in enumerate(validation_loader):
+                 vinputs, vlabels = vdata
+
+                 if device is not None:
+                     if isinstance(vinputs, (list, tuple)):
+                         vinputs = tuple(map(lambda x: x.to(device, non_blocking=True), vinputs))
+                     else:
+                         vinputs = (vinputs.to(device, non_blocking=True),)  # create a tuple to match the list case
+
+                     vlabels = vlabels.to(device, non_blocking=True)
+
+                 voutputs = model(*vinputs)
+
+                 del vinputs
+
+                 vloss = loss_fn(voutputs, vlabels)
+                 running_vloss += vloss.item()
+
+                 vf1 = f1(voutputs.cpu(), vlabels.cpu())
+                 running_vf1 += vf1
+
+             avg_vloss = running_vloss / (i + 1)
+             avg_f1 = running_vf1 / (i + 1)
+
+             # Log the running loss averaged per batch
+             # for both training and validation
+             writer.add_scalars('Training vs. Validation Loss',
+                                {'Training': avg_loss, 'Validation': avg_vloss},
+                                epoch + 1)
+             writer.add_scalars('Weighted avg f1',
+                                {'Weighted avg f1': avg_f1},
+                                epoch + 1)
+
+             # todo: compute f1 over all validation batches at once
+
+             writer.flush()
+
+             # Track the best performance, and save the model's state
+             if avg_vloss < best_vloss:
+                 best_vloss = avg_vloss
+                 best_vloss_epoch = epoch + 1
+                 best_epoch = best_vloss_epoch
+                 best_metric = "loss"
+                 model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}_{int(1000*best_vloss)}'
+                 torch.save(model.state_dict(), model_path)
+
+             if avg_f1 > best_f1:
+                 best_f1 = avg_f1
+                 best_f1_epoch = epoch + 1
+                 best_epoch = best_f1_epoch
+                 best_metric = "f1"
+                 model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_f1_epoch}_{int(100*best_f1)}'
+                 torch.save(model.state_dict(), model_path)
+
+             pbar.set_postfix({'best f1 epoch': best_f1_epoch,
+                               'best f1': best_f1,
+                               'current f1': avg_f1,
+                               'best avg loss epoch': best_vloss_epoch,
+                               'best avg loss': best_vloss,
+                               'current avg loss': avg_vloss}, refresh=False)
+             pbar.update(1)
+
+     # load the best model
+     model.load_state_dict(torch.load(model_path))
+     # switch off training
+     model.train(False)
+     # jit the model for inference
+
+     vinputs, _ = next(iter(validation_loader))
+
+     if device is not None:
+         if isinstance(vinputs, (list, tuple)):
+             vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of example inputs
+         else:
+             vinputs = (vinputs.to(device),)  # create a tuple to match the list case
+
+     # switch off gradient
+     # todo: update pytorch
+     # torch.jit.enable_onednn_fusion(True)
+     with torch.inference_mode():
+         # model_scripted = torch.jit.script(model)  # TorchScript is the recommended format for scaled inference and deployment
+         model_scripted = torch.jit.trace(model, example_inputs=vinputs)
+         model_scripted = torch.jit.freeze(model_scripted)
+
+         model_path_jitted = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
+         model_scripted.save(model_path_jitted)
+
+     return base_dir, model_path, model_path_jitted, model_name
+
+
+ def agressive_train_labeling(epochs, model, optimizer, loss_fn, scheduler, training_loader, validation_loader, report_per_epoch=10,
+                              writer_base_path="runs", model_base_path=".", model_tag="model", grad_f=None, device="cpu"):
+     """
+     Same as train_labeling, but restarts the asynchronous data loaders eagerly for better performance.
+     """
+
+     # Timestamp the run so repeated runs get their own TensorBoard logs and model directory
+     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+     writer = SummaryWriter(f"{writer_base_path}/{model_tag}_{timestamp}")
+
+     model_name = f"{model_tag}_{timestamp}"
+     base_dir = f"{model_base_path}/{model_name}"
+     os.makedirs(base_dir, exist_ok=True)
+
+     n_batch = len(training_loader)
+
+     report_frequency = math.ceil(n_batch / report_per_epoch)
+
+     best_f1_epoch = 0
+     best_f1 = 0.
+
+     best_vloss = 1_000_000.
+     best_vloss_epoch = 0
+
+     best_epoch = 0
+     best_metric = "loss"
+
+     model_path = None
+
+     # start the asynchronous loaders
+     if epochs > 0:
+         train_iter = iter(training_loader)
+         valid_iter = iter(validation_loader)
+
+     with tqdm(total=epochs, desc='Epoch') as pbar:
+         for epoch in range(epochs):
+
+             # Make sure gradient tracking is on, and do a pass over the data
+             model.train(True)
+             avg_loss = agressive_train_one_epoch(train_iter, len(training_loader), optimizer, scheduler, model, loss_fn, epoch, writer, report_frequency, grad_f, device)
+
+             if epoch < epochs - 1:
+                 train_iter = iter(training_loader)
+
+             # Gradients are not needed for reporting
+             model.train(False)
+
+             running_vloss = 0.0
+             running_vf1 = 0.0
+             for i, vdata in enumerate(valid_iter):
+                 vinputs, vlabels = vdata
+
+                 if device is not None:
+                     if isinstance(vinputs, (list, tuple)):
+                         vinputs = tuple(map(lambda x: x.to(device, non_blocking=True), vinputs))
+                     else:
+                         vinputs = (vinputs.to(device, non_blocking=True),)  # create a tuple to match the list case
+                     vlabels = vlabels.to(device, non_blocking=True)
+
+                 voutputs = model(*vinputs)
+                 vloss = loss_fn(voutputs, vlabels)
+                 running_vloss += vloss.item()
+
+                 del vinputs
+
+                 vf1 = f1(voutputs.cpu(), vlabels.cpu())
+                 running_vf1 += vf1
+
+             if epoch < epochs - 1:
+                 valid_iter = iter(validation_loader)
+
+             avg_vloss = running_vloss / (i + 1)
+             avg_f1 = running_vf1 / (i + 1)
+
+             # Log the running loss averaged per batch
+             # for both training and validation
+             writer.add_scalars('Training vs. Validation Loss',
+                                {'Training': avg_loss, 'Validation': avg_vloss},
+                                epoch + 1)
+             writer.add_scalars('Weighted avg f1',
+                                {'Weighted avg f1': avg_f1},
+                                epoch + 1)
+
+             # todo: compute f1 over all validation batches at once
+
+             writer.flush()
+
+             # Track the best performance, and save the model's state
+             if avg_vloss < best_vloss:
+                 best_vloss = avg_vloss
+                 best_vloss_epoch = epoch + 1
+                 best_epoch = best_vloss_epoch
+                 best_metric = "loss"
+                 model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_vloss_epoch}_{int(1000*best_vloss)}'
+                 torch.save(model.state_dict(), model_path)
+
+             if avg_f1 > best_f1:
+                 best_f1 = avg_f1
+                 best_f1_epoch = epoch + 1
+                 best_epoch = best_f1_epoch
+                 best_metric = "f1"
+                 model_path = f'{base_dir}/{best_metric}_{model_tag}_{timestamp}_{best_f1_epoch}_{int(100*best_f1)}'
+                 torch.save(model.state_dict(), model_path)
+
+             pbar.set_postfix({'best f1 epoch': best_f1_epoch,
+                               'best f1': best_f1,
+                               'current f1': avg_f1,
+                               'best avg loss epoch': best_vloss_epoch,
+                               'best avg loss': best_vloss,
+                               'current avg loss': avg_vloss}, refresh=False)
+             pbar.update(1)
+
+     # load the best model
+     model.load_state_dict(torch.load(model_path))
+     # switch off training
+     model.train(False)
+     # jit the model for inference
+
+     vinputs, _ = next(iter(validation_loader))
+
+     if device is not None:
+         if isinstance(vinputs, (list, tuple)):
+             vinputs = tuple(map(lambda x: x.to(device), vinputs))  # trace needs a tuple of example inputs
+         else:
+             vinputs = vinputs.to(device)
+
+     # switch off gradient
+     # todo: update pytorch
+     # torch.jit.enable_onednn_fusion(True)
+     with torch.inference_mode():
+         # model_scripted = torch.jit.script(model)  # TorchScript is the recommended format for scaled inference and deployment
+         model_scripted = torch.jit.trace(model, example_inputs=vinputs)
+         model_scripted = torch.jit.freeze(model_scripted)
+
+         model_path_jitted = f'{base_dir}/jited_{best_metric}_{model_tag}_{timestamp}_{best_epoch}.pt'
+         model_scripted.save(model_path_jitted)
+
+     return base_dir, model_path, model_path_jitted, model_name
+
+
+ def agressive_train_one_epoch(loader_iter, loader_lenght, optimizer, scheduler, model, loss_fn, epoch_index, tb_writer, report_frequency, grad_f=None,
+                               device="cpu"):
+     """
+     Run a single training epoch over an already-started loader iterator.
+
+     :param loader_iter: iterator over the training data loader
+     :param loader_lenght: number of batches in the underlying loader
+     :param optimizer: optimizer to step after each batch
+     :param scheduler: optional learning-rate scheduler, stepped after each batch
+     :param model: model to train
+     :param loss_fn: loss function
+     :param epoch_index: zero-based index of the current epoch
+     :param tb_writer: TensorBoard SummaryWriter
+     :param report_frequency: number of batches between progress/TensorBoard reports
+     :param grad_f: optional callable applied to the model after backward (e.g. gradient clipping)
+     :param device: device to move tensors to; None to leave them where they are
+     :return: average loss over the last reporting window
+     """
+
+     running_loss = 0.
+     last_loss = 0.
+     with tqdm(total=loader_lenght, desc="Batch") as pbar:
+         for i, data in enumerate(loader_iter):
+             # Every data instance is an input + label pair
+             inputs, labels = data
+
+             if device is not None:
+                 if isinstance(inputs, (list, tuple)):
+                     inputs = tuple(map(lambda x: x.to(device, non_blocking=True), inputs))
+                 else:
+                     inputs = (inputs.to(device, non_blocking=True),)  # create a tuple to match the list case
+
+                 labels = labels.to(device, non_blocking=True)
+             # Zero your gradients for every batch!
+             optimizer.zero_grad()
+
+             # Make predictions for this batch
+             outputs = model(*inputs)
+
+             # Compute the loss and its gradients
+             loss = loss_fn(outputs, labels)
+             loss.backward()
+
+             # clip the gradient
+             if grad_f is not None:
+                 grad_f(model)
+
+             # Adjust learning weights
+             optimizer.step()
+             if scheduler is not None:
+                 scheduler.step()
+
+             # Gather data and report
+             running_loss += loss.item()
+
+             if i % report_frequency == report_frequency - 1:
+                 last_loss = running_loss / report_frequency  # average loss over the reporting window
+                 pbar.set_postfix({'Batch': i + 1,
+                                   'Last loss': last_loss,
+                                   }, refresh=False)
+                 pbar.update(report_frequency)
+
+                 tb_x = epoch_index * loader_lenght + i + 1
+                 tb_writer.add_scalar('Loss/train', last_loss, tb_x)
+                 running_loss = 0.
+
+     return last_loss
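
For orientation, here is a minimal, hypothetical usage sketch of the Trainer class added in this file. The toy dataset, model, and hyper-parameters below are illustrative assumptions and are not part of the eoml package; only the Trainer, GradNormClipper, and Trainer.train signatures come from the code shown in the diff above.

import torch
from torch.utils.data import DataLoader, TensorDataset

from eoml.torch.trainer import Trainer, GradNormClipper

# Illustrative placeholder data and model (not part of the package)
x = torch.randn(256, 16)
y = torch.randint(0, 3, (256,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(x, y), batch_size=32)

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Gradient-norm clipping is optional; pass grad_f=None to skip it
trainer = Trainer(optimizer, model, torch.nn.CrossEntropyLoss(), grad_f=GradNormClipper(1.0))
base_dir, model_path, model_name = trainer.train(epochs=5, training_loader=train_loader,
                                                 test_loader=test_loader, device="cpu")

As in the return statement of Trainer.train above, base_dir is the timestamped run directory, model_path points to the jitted best checkpoint, and model_name is the run name used for the TensorBoard logs.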