hypatorch 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hypatorch-0.2.2 → hypatorch-0.2.3}/PKG-INFO +1 -1
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch/__init__.py +1 -1
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch/core.py +72 -8
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch.egg-info/PKG-INFO +1 -1
- {hypatorch-0.2.2 → hypatorch-0.2.3}/LICENSE +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/README.md +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch/assessments.py +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch/losses.py +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch/utils.py +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch.egg-info/SOURCES.txt +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch.egg-info/dependency_links.txt +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch.egg-info/requires.txt +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/hypatorch.egg-info/top_level.txt +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/setup.cfg +0 -0
- {hypatorch-0.2.2 → hypatorch-0.2.3}/setup.py +0 -0
|
@@ -350,7 +350,21 @@ class Model( L.LightningModule ):
|
|
|
350
350
|
)
|
|
351
351
|
|
|
352
352
|
sm_out_dict = { key: value for key, value in zip( expected_outputs, submodule_out ) }
|
|
353
|
-
x = { key_map: sm_out_dict[ key ] for key, key_map in output_key_map.items() }
|
|
353
|
+
#x = { key_map: sm_out_dict[ key ] for key, key_map in output_key_map.items() }
|
|
354
|
+
x = {}
|
|
355
|
+
for key, key_map in output_key_map.items():
|
|
356
|
+
if isinstance( key_map, str ):
|
|
357
|
+
x[ key_map ] = sm_out_dict[ key ]
|
|
358
|
+
elif isinstance( key_map, List ) or isinstance( key_map, ListConfig ):
|
|
359
|
+
for idx, km in enumerate( key_map ):
|
|
360
|
+
x[ km ] = sm_out_dict[ key ][ idx ]
|
|
361
|
+
else:
|
|
362
|
+
raise ValueError(
|
|
363
|
+
f"""
|
|
364
|
+
Error with output key mapping of {submodule_name}.
|
|
365
|
+
Expected a string or a list, but got {type(key_map)}.
|
|
366
|
+
"""
|
|
367
|
+
)
|
|
354
368
|
|
|
355
369
|
return x
|
|
356
370
|
|
|
@@ -408,6 +422,7 @@ class Model( L.LightningModule ):
|
|
|
408
422
|
mode = 'train'
|
|
409
423
|
|
|
410
424
|
input_dict = batch
|
|
425
|
+
output_dict = {}
|
|
411
426
|
|
|
412
427
|
opts = self.optimizers()
|
|
413
428
|
|
|
@@ -419,15 +434,24 @@ class Model( L.LightningModule ):
|
|
|
419
434
|
operation_name = list( self.operations.keys() )[ operation_idx ]
|
|
420
435
|
|
|
421
436
|
# Forward Pass
|
|
422
|
-
|
|
423
|
-
input_dict =
|
|
437
|
+
operation_out, loss = self._forward_pass(
|
|
438
|
+
input_dict = shared_dict(
|
|
439
|
+
input_dict,
|
|
440
|
+
output_dict,
|
|
441
|
+
),
|
|
424
442
|
operation_name = operation_name,
|
|
425
443
|
mode = mode,
|
|
426
444
|
)
|
|
445
|
+
|
|
446
|
+
output_dict = self._handle_operation_output(
|
|
447
|
+
x = operation_out,
|
|
448
|
+
output_dict = output_dict,
|
|
449
|
+
operation_name = operation_name,
|
|
450
|
+
)
|
|
427
451
|
|
|
428
452
|
# Backward Pass if self.losses is not empty list
|
|
429
453
|
opt = opts[ operation_idx ]
|
|
430
|
-
if self.losses:
|
|
454
|
+
if self.losses[ operation_name ]:
|
|
431
455
|
#opt = opts[ operation_idx ]
|
|
432
456
|
self._backward_pass(
|
|
433
457
|
opt = opt,
|
|
@@ -454,16 +478,26 @@ class Model( L.LightningModule ):
|
|
|
454
478
|
mode = 'val'
|
|
455
479
|
|
|
456
480
|
input_dict = batch
|
|
481
|
+
output_dict = {}
|
|
457
482
|
|
|
458
483
|
for operation_idx, _ in enumerate( self.operations ):
|
|
459
484
|
operation_name = list( self.operations.keys() )[ operation_idx ]
|
|
460
485
|
# Forward Pass
|
|
461
486
|
with torch.no_grad():
|
|
462
|
-
|
|
463
|
-
input_dict =
|
|
487
|
+
operation_out, loss = self._forward_pass(
|
|
488
|
+
input_dict = shared_dict(
|
|
489
|
+
input_dict,
|
|
490
|
+
output_dict,
|
|
491
|
+
),
|
|
464
492
|
operation_name = operation_name,
|
|
465
493
|
mode = mode,
|
|
466
494
|
)
|
|
495
|
+
|
|
496
|
+
output_dict = self._handle_operation_output(
|
|
497
|
+
x = operation_out,
|
|
498
|
+
output_dict = output_dict,
|
|
499
|
+
operation_name = operation_name,
|
|
500
|
+
)
|
|
467
501
|
|
|
468
502
|
# handle metrics
|
|
469
503
|
self._handle_assessments(
|
|
@@ -586,22 +620,52 @@ class Model( L.LightningModule ):
|
|
|
586
620
|
else:
|
|
587
621
|
assessments_dict = None
|
|
588
622
|
return assessments_dict
|
|
623
|
+
|
|
624
|
+
def _handle_operation_output(
    self,
    x,
    output_dict,
    operation_name,
):
    """Merge the outputs of one operation into the accumulated output dict.

    Operations may only add new keys; an operation whose output reuses a
    key that is already present in ``output_dict`` is rejected.

    Args:
        x: Mapping with the outputs produced by the operation.
        output_dict: Mapping with the outputs collected so far. It is
            updated in place when no key collides.
        operation_name: Name of the operation, used in the error message.

    Returns:
        The same ``output_dict`` object, now also containing the items
        of ``x``.

    Raises:
        ValueError: If at least one key of ``x`` already exists in
            ``output_dict``. ``output_dict`` is left untouched in that case.
    """
    # Collect the colliding keys under their own name instead of reusing
    # `x` as the loop variable, which shadowed the parameter.
    clashing_keys = [ key for key in x.keys() if key in output_dict.keys() ]

    # Check before merging so a rejected operation never leaves a partial update.
    if clashing_keys:
        raise ValueError(
            f"""
            Error with output_dict of {operation_name}.
            Operations are not allowed to overwrite existing keys.
            However, {operation_name} has the following keys: {list( x.keys() )}
            and the output_dict has the following keys: {list( output_dict.keys() )}.
            Conflicting keys: {clashing_keys}.
            """
        )

    output_dict.update( x )
    return output_dict
|
|
589
642
|
|
|
590
643
|
def predict_step(self, batch, batch_idx):
|
|
591
644
|
|
|
592
645
|
mode = 'test'
|
|
593
646
|
|
|
594
647
|
input_dict = batch
|
|
648
|
+
output_dict = {}
|
|
595
649
|
|
|
596
650
|
for operation_idx, _ in enumerate( self.operations ):
|
|
597
651
|
operation_name = list( self.operations.keys() )[ operation_idx ]
|
|
598
652
|
# Forward Pass
|
|
599
653
|
with torch.no_grad():
|
|
600
|
-
|
|
601
|
-
input_dict =
|
|
654
|
+
operation_out, loss = self._forward_pass(
|
|
655
|
+
input_dict = shared_dict(
|
|
656
|
+
input_dict,
|
|
657
|
+
output_dict,
|
|
658
|
+
),
|
|
602
659
|
operation_name = operation_name,
|
|
603
660
|
mode = mode,
|
|
604
661
|
)
|
|
662
|
+
|
|
663
|
+
output_dict = self._handle_operation_output(
|
|
664
|
+
x = operation_out,
|
|
665
|
+
output_dict = output_dict,
|
|
666
|
+
operation_name = operation_name,
|
|
667
|
+
)
|
|
668
|
+
|
|
605
669
|
|
|
606
670
|
data_dict = shared_dict(
|
|
607
671
|
input_dict,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|