dynamic-network-architectures 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dynamic_network_architectures/__init__.py +0 -0
- dynamic_network_architectures/architectures/__init__.py +0 -0
- dynamic_network_architectures/architectures/abstract_arch.py +40 -0
- dynamic_network_architectures/architectures/dinov2_eva.py +782 -0
- dynamic_network_architectures/architectures/primus.py +1053 -0
- dynamic_network_architectures/architectures/resnet.py +554 -0
- dynamic_network_architectures/architectures/unet.py +365 -0
- dynamic_network_architectures/architectures/vgg.py +119 -0
- dynamic_network_architectures/building_blocks/__init__.py +0 -0
- dynamic_network_architectures/building_blocks/eva.py +213 -0
- dynamic_network_architectures/building_blocks/helper.py +242 -0
- dynamic_network_architectures/building_blocks/patch_encode_decode.py +399 -0
- dynamic_network_architectures/building_blocks/plain_conv_encoder.py +105 -0
- dynamic_network_architectures/building_blocks/regularization.py +86 -0
- dynamic_network_architectures/building_blocks/residual.py +371 -0
- dynamic_network_architectures/building_blocks/residual_encoders.py +172 -0
- dynamic_network_architectures/building_blocks/simple_conv_blocks.py +167 -0
- dynamic_network_architectures/building_blocks/unet_decoder.py +154 -0
- dynamic_network_architectures/building_blocks/unet_residual_decoder.py +155 -0
- dynamic_network_architectures/initialization/__init__.py +0 -0
- dynamic_network_architectures/initialization/weight_init.py +34 -0
- dynamic_network_architectures-0.4.4.dist-info/METADATA +214 -0
- dynamic_network_architectures-0.4.4.dist-info/RECORD +26 -0
- dynamic_network_architectures-0.4.4.dist-info/WHEEL +5 -0
- dynamic_network_architectures-0.4.4.dist-info/licenses/LICENSE +201 -0
- dynamic_network_architectures-0.4.4.dist-info/top_level.txt +1 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from typing import Literal, Sequence
|
|
3
|
+
from torch import nn
|
|
4
|
+
import torch
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AbstractDynamicNetworkArchitectures(nn.Module):
    """Base ``nn.Module`` recording where key sub-parts of a network live.

    Subclasses describe their layout through string keys:

    * ``key_to_encoder`` -- position holding all of the encoder weights.
    * ``key_to_stem`` -- the full stem; it may sit inside or outside the
      encoder.
    * ``keys_to_in_proj`` -- key(s) of the weights that depend on the number
      of input channels. Several keys may be listed, e.g. when a network
      spreads this projection over more than one module.
    * ``key_to_lpe`` -- key of the learnable positional embedding (LPE);
      defaults to ``None``.

    ``test_submodules_loadable`` in this module resolves the first three keys
    with ``nn.Module.get_submodule``.

    Only ``key_to_lpe`` receives a value here. The other three are declared
    by annotation alone, which does not create an attribute, so concrete
    subclasses must assign them.
    """

    def __init__(self):
        super().__init__()
        # Annotation-only declarations: no attribute exists until a subclass
        # assigns one.
        self.key_to_encoder: str
        self.key_to_stem: str
        self.keys_to_in_proj: Sequence[str]
        # LPE == Learnable Positional Embedding. None by default (presumably
        # meaning "no LPE" -- confirm against subclasses).
        self.key_to_lpe: str | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_submodules_loadable(model: AbstractDynamicNetworkArchitectures):
    """Check that the submodule keys declared on ``model`` resolve.

    ``model.key_to_encoder`` and ``model.key_to_stem`` are resolved with
    ``nn.Module.get_submodule``, which raises for an unknown path. Then every
    key in ``model.keys_to_in_proj`` is checked:

    * it must name a submodule whose ``weight`` is not ``None``;
    * all of those weights must have the same shape.

    Args:
        model: Network exposing ``key_to_encoder``, ``key_to_stem`` and
            ``keys_to_in_proj``. A bare string in ``keys_to_in_proj`` is
            treated as a single key.

    Raises:
        AttributeError: If a key does not resolve to a submodule, or a
            resolved in-proj submodule has no ``weight`` attribute.
        AssertionError: If an in-proj ``weight`` is ``None`` or the in-proj
            weights disagree in shape.
    """
    encoder_key = model.key_to_encoder
    stem_key = model.key_to_stem
    in_proj_keys = model.keys_to_in_proj
    # A plain str is also a Sequence[str]; iterating it would yield single
    # characters instead of one key, so wrap it.
    if isinstance(in_proj_keys, str):
        in_proj_keys = (in_proj_keys,)

    # get_submodule raises if the path does not exist.
    model.get_submodule(encoder_key)
    model.get_submodule(stem_key)

    expected_shape = None
    for swk in in_proj_keys:
        stem_weights = model.get_submodule(swk).weight
        # Check for None *before* reading .shape, otherwise a missing weight
        # surfaces as an unrelated AttributeError. Explicit raises (instead of
        # `assert`) keep these checks active under `python -O`.
        if stem_weights is None:
            raise AssertionError(f"Stem weights submodule {swk} is not loadable")
        if expected_shape is None:
            expected_shape = stem_weights.shape
        elif stem_weights.shape != expected_shape:
            raise AssertionError(f"Stem weights submodule {swk} has different shape")
|