hypatorch 0.2.2__tar.gz → 0.2.3__tar.gz

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hypatorch
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: HypaTorch: A library for abstract and visual model configuration
5
5
  Home-page: https://github.com/Altavo/hypatorch/
6
6
  Author: Altavo GmbH
@@ -5,4 +5,4 @@ from hypatorch.losses import MMAE_Loss
5
5
  from hypatorch.losses import MSE_Loss
6
6
  from hypatorch.losses import MMSE_Loss
7
7
 
8
- __version__ = '0.2.2'
8
+ __version__ = '0.2.3'
@@ -350,7 +350,21 @@ class Model( L.LightningModule ):
350
350
  )
351
351
 
352
352
  sm_out_dict = { key: value for key, value in zip( expected_outputs, submodule_out ) }
353
- x = { key_map: sm_out_dict[ key ] for key, key_map in output_key_map.items() }
353
+ #x = { key_map: sm_out_dict[ key ] for key, key_map in output_key_map.items() }
354
+ x = {}
355
+ for key, key_map in output_key_map.items():
356
+ if isinstance( key_map, str ):
357
+ x[ key_map ] = sm_out_dict[ key ]
358
+ elif isinstance( key_map, List ) or isinstance( key_map, ListConfig ):
359
+ for idx, km in enumerate( key_map ):
360
+ x[ km ] = sm_out_dict[ key ][ idx ]
361
+ else:
362
+ raise ValueError(
363
+ f"""
364
+ Error with output key mapping of {submodule_name}.
365
+ Expected a string or a list, but got {type(key_map)}.
366
+ """
367
+ )
354
368
 
355
369
  return x
356
370
 
@@ -408,6 +422,7 @@ class Model( L.LightningModule ):
408
422
  mode = 'train'
409
423
 
410
424
  input_dict = batch
425
+ output_dict = {}
411
426
 
412
427
  opts = self.optimizers()
413
428
 
@@ -419,15 +434,24 @@ class Model( L.LightningModule ):
419
434
  operation_name = list( self.operations.keys() )[ operation_idx ]
420
435
 
421
436
  # Forward Pass
422
- output_dict, loss = self._forward_pass(
423
- input_dict = input_dict,
437
+ operation_out, loss = self._forward_pass(
438
+ input_dict = shared_dict(
439
+ input_dict,
440
+ output_dict,
441
+ ),
424
442
  operation_name = operation_name,
425
443
  mode = mode,
426
444
  )
445
+
446
+ output_dict = self._handle_operation_output(
447
+ x = operation_out,
448
+ output_dict = output_dict,
449
+ operation_name = operation_name,
450
+ )
427
451
 
428
452
  # Backward Pass if self.losses is not empty list
429
453
  opt = opts[ operation_idx ]
430
- if self.losses:
454
+ if self.losses[ operation_name ]:
431
455
  #opt = opts[ operation_idx ]
432
456
  self._backward_pass(
433
457
  opt = opt,
@@ -454,16 +478,26 @@ class Model( L.LightningModule ):
454
478
  mode = 'val'
455
479
 
456
480
  input_dict = batch
481
+ output_dict = {}
457
482
 
458
483
  for operation_idx, _ in enumerate( self.operations ):
459
484
  operation_name = list( self.operations.keys() )[ operation_idx ]
460
485
  # Forward Pass
461
486
  with torch.no_grad():
462
- output_dict, loss = self._forward_pass(
463
- input_dict = input_dict,
487
+ operation_out, loss = self._forward_pass(
488
+ input_dict = shared_dict(
489
+ input_dict,
490
+ output_dict,
491
+ ),
464
492
  operation_name = operation_name,
465
493
  mode = mode,
466
494
  )
495
+
496
+ output_dict = self._handle_operation_output(
497
+ x = operation_out,
498
+ output_dict = output_dict,
499
+ operation_name = operation_name,
500
+ )
467
501
 
468
502
  # handle metrics
469
503
  self._handle_assessments(
@@ -586,22 +620,52 @@ class Model( L.LightningModule ):
586
620
  else:
587
621
  assessments_dict = None
588
622
  return assessments_dict
623
+
624
+ def _handle_operation_output(
625
+ self,
626
+ x,
627
+ output_dict,
628
+ operation_name,
629
+ ):
630
+ if not any( x in output_dict.keys() for x in x.keys() ):
631
+ output_dict.update( x )
632
+ else:
633
+ raise ValueError(
634
+ f"""
635
+ Error with output_dict of {operation_name}.
636
+ Operations are not allowed to overwrite existing keys.
637
+ However, {operation_name} has the following keys: {x.keys()}
638
+ and the output_dict has the following keys: {output_dict.keys()}.
639
+ """
640
+ )
641
+ return output_dict
589
642
 
590
643
  def predict_step(self, batch, batch_idx):
591
644
 
592
645
  mode = 'test'
593
646
 
594
647
  input_dict = batch
648
+ output_dict = {}
595
649
 
596
650
  for operation_idx, _ in enumerate( self.operations ):
597
651
  operation_name = list( self.operations.keys() )[ operation_idx ]
598
652
  # Forward Pass
599
653
  with torch.no_grad():
600
- output_dict, loss = self._forward_pass(
601
- input_dict = input_dict,
654
+ operation_out, loss = self._forward_pass(
655
+ input_dict = shared_dict(
656
+ input_dict,
657
+ output_dict,
658
+ ),
602
659
  operation_name = operation_name,
603
660
  mode = mode,
604
661
  )
662
+
663
+ output_dict = self._handle_operation_output(
664
+ x = operation_out,
665
+ output_dict = output_dict,
666
+ operation_name = operation_name,
667
+ )
668
+
605
669
 
606
670
  data_dict = shared_dict(
607
671
  input_dict,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: hypatorch
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: HypaTorch: A library for abstract and visual model configuration
5
5
  Home-page: https://github.com/Altavo/hypatorch/
6
6
  Author: Altavo GmbH
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes