kaiko-eva 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/_recorder.py +4 -1
  17. eva/core/trainers/functional.py +40 -43
  18. eva/core/utils/factory.py +66 -0
  19. eva/core/utils/registry.py +42 -0
  20. eva/core/utils/requirements.py +26 -0
  21. eva/language/__init__.py +13 -0
  22. eva/language/data/__init__.py +5 -0
  23. eva/language/data/datasets/__init__.py +9 -0
  24. eva/language/data/datasets/classification/__init__.py +7 -0
  25. eva/language/data/datasets/classification/base.py +63 -0
  26. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  27. eva/language/data/datasets/language.py +13 -0
  28. eva/language/models/__init__.py +25 -0
  29. eva/language/models/modules/__init__.py +5 -0
  30. eva/language/models/modules/text.py +85 -0
  31. eva/language/models/modules/typings.py +16 -0
  32. eva/language/models/wrappers/__init__.py +11 -0
  33. eva/language/models/wrappers/huggingface.py +69 -0
  34. eva/language/models/wrappers/litellm.py +77 -0
  35. eva/language/models/wrappers/vllm.py +149 -0
  36. eva/language/utils/__init__.py +5 -0
  37. eva/language/utils/str_to_int_tensor.py +95 -0
  38. eva/vision/data/dataloaders/__init__.py +2 -1
  39. eva/vision/data/dataloaders/worker_init.py +35 -0
  40. eva/vision/data/datasets/__init__.py +5 -5
  41. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  42. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  43. eva/vision/data/datasets/segmentation/consep.py +5 -4
  44. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  45. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  46. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  47. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  48. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  49. eva/vision/data/transforms/__init__.py +11 -2
  50. eva/vision/data/transforms/base/__init__.py +5 -0
  51. eva/vision/data/transforms/base/monai.py +27 -0
  52. eva/vision/data/transforms/common/__init__.py +2 -1
  53. eva/vision/data/transforms/common/squeeze.py +24 -0
  54. eva/vision/data/transforms/croppad/__init__.py +4 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  56. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  57. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  58. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  59. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  60. eva/vision/models/modules/semantic_segmentation.py +27 -11
  61. eva/vision/models/networks/backbones/__init__.py +2 -3
  62. eva/vision/models/networks/backbones/_utils.py +1 -1
  63. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  64. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  65. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  66. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  67. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  68. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  71. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  72. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  73. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  74. eva/vision/models/networks/backbones/registry.py +2 -44
  75. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  76. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  77. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  78. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  80. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  81. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  82. eva/vision/models/wrappers/from_registry.py +14 -9
  83. eva/vision/models/wrappers/from_timm.py +6 -5
  84. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/METADATA +22 -12
  85. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/RECORD +89 -58
  86. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/WHEEL +1 -1
  87. eva/vision/data/datasets/segmentation/lits.py +0 -199
  88. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  89. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  90. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/entry_points.txt +0 -0
  91. {kaiko_eva-0.2.1.dist-info → kaiko_eva-0.3.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/models/modules/semantic_segmentation.py
@@ -1,6 +1,7 @@
-""""Neural Network Semantic Segmentation Module."""
+"""Neural Network Semantic Segmentation Module."""
 
-from typing import Any, Callable, Dict, Iterable, List
+import functools
+from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 import torch
 from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
@@ -11,7 +12,7 @@ from torch.optim import lr_scheduler
 from typing_extensions import override
 
 from eva.core.metrics import structs as metrics_lib
-from eva.core.models.modules import module
+from eva.core.models.modules import SchedulerConfiguration, module
 from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
 from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
 from eva.core.utils import parser
@@ -31,10 +32,11 @@ class SemanticSegmentationModule(module.ModelModule):
         lr_multiplier_encoder: float = 0.0,
         inferer: Inferer | None = None,
         optimizer: OptimizerCallable = optim.AdamW,
-        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
+        lr_scheduler: LRSchedulerCallable | SchedulerConfiguration = lr_scheduler.ConstantLR,
         metrics: metrics_lib.MetricsSchema | None = None,
         postprocess: batch_postprocess.BatchPostProcess | None = None,
         save_decoder_only: bool = True,
+        spatial_dims: int = 2,
     ) -> None:
         """Initializes the neural net head module.
 
@@ -57,6 +59,8 @@ class SemanticSegmentationModule(module.ModelModule):
                 predictions and targets.
             save_decoder_only: Whether to save only the decoder during checkpointing. If False,
                 will also save the encoder (not recommended when frozen).
+            spatial_dims: The number of spatial dimensions, 2 for 2D
+                and 3 for 3D segmentation.
         """
         super().__init__(metrics=metrics, postprocess=postprocess)
 
@@ -68,6 +72,7 @@ class SemanticSegmentationModule(module.ModelModule):
         self.lr_scheduler = lr_scheduler
         self.save_decoder_only = save_decoder_only
         self.inferer = inferer
+        self.spatial_dims = spatial_dims
 
     @override
     def configure_model(self) -> None:
@@ -111,13 +116,14 @@ class SemanticSegmentationModule(module.ModelModule):
     def forward(
         self,
         tensor: torch.Tensor,
+        to_size: Tuple[int, ...],
         *args: Any,
         **kwargs: Any,
     ) -> torch.Tensor:
         return (
-            self.inferer(tensor, network=self._forward_networks)
+            self.inferer(tensor, network=functools.partial(self._forward_networks, to_size=to_size))
             if self.inferer is not None and not self.training
-            else self._forward_networks(tensor)
+            else self._forward_networks(tensor, to_size=to_size)
         )
 
     @override
@@ -168,7 +174,8 @@ class SemanticSegmentationModule(module.ModelModule):
             The batch step output.
         """
         data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
-        predictions = self(data)
+        to_size = targets.shape[-self.spatial_dims :] if self.inferer is None else None
+        predictions = self(data, to_size=to_size)
         loss = self.criterion(predictions, targets)
         return {
             "loss": loss,
@@ -177,12 +184,21 @@ class SemanticSegmentationModule(module.ModelModule):
             "metadata": metadata,
         }
 
-    def _forward_networks(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _forward_networks(
+        self, tensor: torch.Tensor, to_size: Tuple[int, ...] | None = None
+    ) -> torch.Tensor:
         """Passes the input tensor through the encoder and decoder."""
-        features = self.encoder(tensor) if self.encoder else tensor
+        if self.encoder:
+            to_size = to_size or tuple(tensor.shape[-self.spatial_dims :])
+            features = self.encoder(tensor)
+        else:
+            if to_size is None:
+                raise ValueError("`to_size` must be provided when no encoder is used.")
+            features = tensor
+
         if isinstance(self.decoder, segmentation.Decoder):
             if not isinstance(features, list):
                 raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
-            image_size = (tensor.shape[-2], tensor.shape[-1])
-            return self.decoder(DecoderInputs(features, image_size, tensor))
+            return self.decoder(DecoderInputs(features, to_size, tensor))
+
         return self.decoder(features)
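Note: the `spatial_dims`/`to_size` plumbing above generalizes the module from 2D to 3D: the decoder now upsamples to the target's trailing spatial shape instead of the hard-coded `(H, W)` of the input tensor. A minimal sketch of how `to_size` is derived in the batch step, with illustrative volume shapes (not taken from eva):

    import torch

    spatial_dims = 3  # as in SemanticSegmentationModule(..., spatial_dims=3)

    # Illustrative 3D batch: (B, C, D, H, W) inputs, (B, D, H, W) targets.
    data = torch.randn(2, 1, 96, 96, 96)
    targets = torch.randint(0, 3, (2, 96, 96, 96))

    # Mirrors the batch step: without an inferer, predictions are resized
    # to the target's trailing `spatial_dims` dimensions.
    to_size = targets.shape[-spatial_dims:]
    print(tuple(to_size))  # (96, 96, 96)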
eva/vision/models/networks/backbones/__init__.py
@@ -1,13 +1,12 @@
 """Vision Model Backbones API."""
 
 from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
-from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 __all__ = [
     "radiology",
     "pathology",
     "timm",
     "universal",
-    "BackboneModelRegistry",
-    "register_model",
+    "backbone_registry",
 ]
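Note: the hunks that follow are the mechanical fallout of this rename; every `@register_model(...)` decorator becomes `@backbone_registry.register(...)`. For downstream code registering custom backbones, the migration is one line, as in this sketch (the backbone itself is hypothetical):

    from torch import nn

    from eva.vision.models.networks.backbones import backbone_registry


    # 0.2.1: @register_model("custom/my_backbone")
    @backbone_registry.register("custom/my_backbone")
    def my_backbone() -> nn.Module:
        """Hypothetical downstream backbone, registered by name."""
        return nn.Identity()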
eva/vision/models/networks/backbones/_utils.py
@@ -36,7 +36,7 @@ def load_hugingface_model(
 
     return models.HuggingFaceModel(
         model_name_or_path=model_name,
-        tensor_transforms=tensor_transforms,
+        transforms=tensor_transforms,
         model_kwargs=model_kwargs,
     )
 
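Note: the `tensor_transforms` → `transforms` keyword rename recurs in the bioptimus and paige hunks below and tracks the wrapper-signature changes listed above (`eva/core/models/wrappers/*`). Callers constructing wrappers directly need the same one-keyword change; a hedged sketch, assuming `HuggingFaceModel` is exposed under `eva.core.models.wrappers` as the file list suggests:

    from eva.core.models import wrappers

    # 0.2.1 keyword: tensor_transforms=...; 0.3.0 keyword: transforms=...
    model = wrappers.HuggingFaceModel(
        model_name_or_path="owkin/phikon",
        transforms=None,  # optional transforms applied to the model outputs
    )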
eva/vision/models/networks/backbones/pathology/bioptimus.py
@@ -8,10 +8,10 @@ from torch import nn
 from eva.core.models import transforms
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/bioptimus_h_optimus_0")
+@backbone_registry.register("pathology/bioptimus_h_optimus_0")
 def bioptimus_h_optimus_0(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -39,7 +39,7 @@ def bioptimus_h_optimus_0(
     )
 
 
-@register_model("pathology/bioptimus_h0_mini")
+@backbone_registry.register("pathology/bioptimus_h0_mini")
 def bioptimus_h0_mini(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -72,7 +72,7 @@ def bioptimus_h0_mini(
             "mlp_layer": timm.layers.SwiGLUPacked,
             "act_layer": nn.SiLU,
         },
-        tensor_transforms=(
+        transforms=(
            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
            if out_indices is None
            else None
eva/vision/models/networks/backbones/pathology/gigapath.py
@@ -5,10 +5,10 @@ from typing import Tuple
 import timm
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/prov_gigapath")
+@backbone_registry.register("pathology/prov_gigapath")
 def prov_gigapath(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/histai.py
@@ -5,10 +5,10 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/histai_hibou_b")
+@backbone_registry.register("pathology/histai_hibou_b")
 def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
 
@@ -30,7 +30,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
     )
 
 
-@register_model("pathology/histai_hibou_l")
+@backbone_registry.register("pathology/histai_hibou_l")
 def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
 
eva/vision/models/networks/backbones/pathology/hkust.py
@@ -7,10 +7,10 @@ import timm
 from torch import nn
 
 from eva.core.models.wrappers import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/hkust_gpfm")
+@backbone_registry.register("pathology/hkust_gpfm")
 def hkust_gpfm(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/kaiko.py
@@ -6,10 +6,10 @@ import torch
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/kaiko_midnight_12k")
+@backbone_registry.register("pathology/kaiko_midnight_12k")
 def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the Midnight-12k pathology FM by kaiko.ai.
 
@@ -26,7 +26,7 @@ def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.M
     )
 
 
-@register_model("pathology/kaiko_vits16")
+@backbone_registry.register("pathology/kaiko_vits16")
 def kaiko_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -49,7 +49,7 @@ def kaiko_vits16(
     )
 
 
-@register_model("pathology/kaiko_vits8")
+@backbone_registry.register("pathology/kaiko_vits8")
 def kaiko_vits8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -72,7 +72,7 @@ def kaiko_vits8(
     )
 
 
-@register_model("pathology/kaiko_vitb16")
+@backbone_registry.register("pathology/kaiko_vitb16")
 def kaiko_vitb16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -95,7 +95,7 @@ def kaiko_vitb16(
     )
 
 
-@register_model("pathology/kaiko_vitb8")
+@backbone_registry.register("pathology/kaiko_vitb8")
 def kaiko_vitb8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -118,7 +118,7 @@ def kaiko_vitb8(
     )
 
 
-@register_model("pathology/kaiko_vitl14")
+@backbone_registry.register("pathology/kaiko_vitl14")
 def kaiko_vitl14(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
eva/vision/models/networks/backbones/pathology/lunit.py
@@ -13,14 +13,14 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models import wrappers
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 VITS_URL_PREFIX = (
     "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
 )
 
 
-@register_model("pathology/lunit_vits16")
+@backbone_registry.register("pathology/lunit_vits16")
 def lunit_vits16(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -44,7 +44,7 @@ def lunit_vits16(
     )
 
 
-@register_model("pathology/lunit_vits8")
+@backbone_registry.register("pathology/lunit_vits8")
 def lunit_vits8(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
eva/vision/models/networks/backbones/pathology/mahmood.py
@@ -8,10 +8,10 @@ from torch import nn
 
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/mahmood_uni")
+@backbone_registry.register("pathology/mahmood_uni")
 def mahmood_uni(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -41,7 +41,7 @@ def mahmood_uni(
     )
 
 
-@register_model("pathology/mahmood_uni2_h")
+@backbone_registry.register("pathology/mahmood_uni2_h")
 def mahmood_uni2_h(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
eva/vision/models/networks/backbones/pathology/owkin.py
@@ -5,10 +5,10 @@ from typing import Tuple
 from torch import nn
 
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/owkin_phikon")
+@backbone_registry.register("pathology/owkin_phikon")
 def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the phikon pathology FM by owkin (https://huggingface.co/owkin/phikon).
 
@@ -22,7 +22,7 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
 
 
-@register_model("pathology/owkin_phikon_v2")
+@backbone_registry.register("pathology/owkin_phikon_v2")
 def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
     """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
 
eva/vision/models/networks/backbones/pathology/paige.py
@@ -11,10 +11,10 @@ import torch.nn as nn
 from eva.core.models import transforms
 from eva.vision.models import wrappers
 from eva.vision.models.networks.backbones import _utils
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("pathology/paige_virchow2")
+@backbone_registry.register("pathology/paige_virchow2")
 def paige_virchow2(
     dynamic_img_size: bool = True,
     out_indices: int | Tuple[int, ...] | None = None,
@@ -43,7 +43,7 @@ def paige_virchow2(
             "mlp_layer": timm.layers.SwiGLUPacked,
             "act_layer": nn.SiLU,
         },
-        tensor_transforms=(
+        transforms=(
            transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
            if out_indices is None
            else None
eva/vision/models/networks/backbones/radiology/swin_unetr.py
@@ -9,10 +9,10 @@ from monai.networks.nets import swin_unetr
 from monai.utils import misc
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("radiology/swin_unetr_encoder")
+@backbone_registry.register("radiology/swin_unetr_encoder")
 class SwinUNETREncoder(nn.Module):
     """Swin transformer encoder based on UNETR [0].
 
eva/vision/models/networks/backbones/radiology/voco.py
@@ -4,10 +4,10 @@ from typing_extensions import override
 
 from eva.core.models.wrappers import _utils
 from eva.vision.models.networks.backbones.radiology import swin_unetr
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-class _VoCo(swin_unetr.SwinUNETREncoder):
+class _VoCo(swin_unetr.SwinUNETREncoder):  # type: ignore
     """Base class for the VoCo self-supervised encoders."""
 
     _checkpoint: str
@@ -39,7 +39,7 @@ class _VoCo(swin_unetr.SwinUNETREncoder):
         self.load_state_dict(state_dict)
 
 
-@register_model("radiology/voco_b")
+@backbone_registry.register("radiology/voco_b")
 class VoCoB(_VoCo):
     """VoCo Self-supervised pre-trained B model."""
 
@@ -51,7 +51,7 @@ class VoCoB(_VoCo):
         super().__init__(feature_size=48, out_indices=out_indices)
 
 
-@register_model("radiology/voco_l")
+@backbone_registry.register("radiology/voco_l")
 class VoCoL(_VoCo):
     """VoCo Self-supervised pre-trained L model."""
 
@@ -63,7 +63,7 @@ class VoCoL(_VoCo):
         super().__init__(feature_size=96, out_indices=out_indices)
 
 
-@register_model("radiology/voco_h")
+@backbone_registry.register("radiology/voco_h")
 class VoCoH(_VoCo):
     """VoCo Self-supervised pre-trained H model."""
 
eva/vision/models/networks/backbones/registry.py
@@ -1,47 +1,5 @@
 """Backbone Model Registry."""
 
-from typing import Any, Callable, Dict, List
+from eva.core.utils.registry import Registry
 
-import torch.nn as nn
-
-
-class BackboneModelRegistry:
-    """A model registry for accessing backbone models by name."""
-
-    _registry: Dict[str, Callable[..., nn.Module]] = {}
-
-    @classmethod
-    def register(cls, name: str) -> Callable:
-        """Decorator to register a new model."""
-
-        def decorator(model_fn: Callable[..., nn.Module]) -> Callable[..., nn.Module]:
-            if name in cls._registry:
-                raise ValueError(f"Model {name} is already registered.")
-            cls._registry[name] = model_fn
-            return model_fn
-
-        return decorator
-
-    @classmethod
-    def get(cls, model_name: str) -> Callable[..., nn.Module]:
-        """Gets a model function from the registry."""
-        if model_name not in cls._registry:
-            raise ValueError(f"Model {model_name} not found in the registry.")
-        return cls._registry[model_name]
-
-    @classmethod
-    def load_model(cls, model_name: str, model_kwargs: Dict[str, Any] | None = None) -> nn.Module:
-        """Loads & initializes a model class from the registry."""
-        model_fn = cls.get(model_name)
-        return model_fn(**(model_kwargs or {}))
-
-    @classmethod
-    def list_models(cls) -> List[str]:
-        """List all models in the registry."""
-        register_models = [name for name in cls._registry.keys() if not name.startswith("timm")]
-        return register_models + ["timm/<model_name>"]
-
-
-def register_model(name: str) -> Callable:
-    """Simple decorator to register a model."""
-    return BackboneModelRegistry.register(name)
+backbone_registry = Registry()
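Note: the shared `Registry` class now lives in the new `eva/core/utils/registry.py` (+42 lines, not shown in this diff). A minimal instance-based sketch consistent with the usage visible in these hunks, `backbone_registry.register(name)` as a decorator plus a `._registry` dict updated in bulk by the timm module, might look as follows; anything beyond `register` and `_registry` is an assumption modeled on the removed `BackboneModelRegistry`:

    from typing import Any, Callable, Dict


    class Registry:
        """Hypothetical sketch of the generic name-to-factory registry."""

        def __init__(self) -> None:
            self._registry: Dict[str, Callable[..., Any]] = {}

        def register(self, name: str) -> Callable:
            """Decorator that registers a factory under `name`."""

            def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
                if name in self._registry:
                    raise ValueError(f"'{name}' is already registered.")
                self._registry[name] = fn
                return fn

            return decorator

        def get(self, name: str) -> Callable[..., Any]:
            """Returns the factory registered under `name` (assumed method)."""
            if name not in self._registry:
                raise ValueError(f"'{name}' not found in the registry.")
            return self._registry[name]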
eva/vision/models/networks/backbones/timm/backbones.py
@@ -8,7 +8,7 @@ from loguru import logger
 from torch import nn
 
 from eva.vision.models import wrappers
-from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
 def timm_model(
@@ -46,7 +46,7 @@ def timm_model(
     )
 
 
-BackboneModelRegistry._registry.update(
+backbone_registry._registry.update(
     {
         f"timm/{model_name}": functools.partial(timm_model, model_name=model_name)
         for model_name in timm.list_models()
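Note: this bulk update keeps every timm architecture addressable as `timm/<model_name>` after the registry swap. A hedged usage sketch; the `get` accessor is assumed (mirroring the removed `BackboneModelRegistry.get`), since only `register` and `_registry` appear in this diff:

    from eva.vision.models.networks.backbones import backbone_registry

    # Resolves to functools.partial(timm_model, model_name="vit_small_patch16_224").
    model_fn = backbone_registry.get("timm/vit_small_patch16_224")
    model = model_fn()  # kwargs, if any, are forwarded to timm_model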
eva/vision/models/networks/backbones/universal/__init__.py
@@ -1,8 +1,15 @@
 """Universal Vision Model Backbones API."""
 
 from eva.vision.models.networks.backbones.universal.vit import (
+    vit_base_patch16_224_dino_1chan,
     vit_small_patch16_224_dino,
+    vit_small_patch16_224_dino_1chan,
     vit_small_patch16_224_random,
 )
 
-__all__ = ["vit_small_patch16_224_dino", "vit_small_patch16_224_random"]
+__all__ = [
+    "vit_small_patch16_224_dino",
+    "vit_small_patch16_224_random",
+    "vit_small_patch16_224_dino_1chan",
+    "vit_base_patch16_224_dino_1chan",
+]
eva/vision/models/networks/backbones/universal/vit.py
@@ -5,10 +5,10 @@ from typing import Tuple
 import timm
 from torch import nn
 
-from eva.vision.models.networks.backbones.registry import register_model
+from eva.vision.models.networks.backbones.registry import backbone_registry
 
 
-@register_model("universal/vit_small_patch16_224_random")
+@backbone_registry.register("universal/vit_small_patch16_224_random")
 def vit_small_patch16_224_random(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -31,7 +31,7 @@ def vit_small_patch16_224_random(
     )
 
 
-@register_model("universal/vit_small_patch16_224_dino")
+@backbone_registry.register("universal/vit_small_patch16_224_dino")
 def vit_small_patch16_224_dino(
     dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
 ) -> nn.Module:
@@ -52,3 +52,53 @@ def vit_small_patch16_224_dino(
         out_indices=out_indices,
         dynamic_img_size=dynamic_img_size,
     )
+
+
+@backbone_registry.register("universal/vit_small_patch16_224_dino_1chan")
+def vit_small_patch16_224_dino_1chan(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTS-16 baseline model pretrained w/ DINO for single-channel images.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTS-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_small_patch16_224.dino",
+        in_chans=1,
+        num_classes=0,
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
+
+
+@backbone_registry.register("universal/vit_base_patch16_224_dino_1chan")
+def vit_base_patch16_224_dino_1chan(
+    dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
+) -> nn.Module:
+    """Initializes a ViTB-16 baseline model pretrained w/ DINO for single-channel images.
+
+    Args:
+        dynamic_img_size: Support different input image sizes by allowing to change
+            the grid size (interpolate abs and/or ROPE pos) in the forward pass.
+        out_indices: Whether and which multi-level patch embeddings to return.
+
+    Returns:
+        The torch ViTB-16 based foundation model.
+    """
+    return timm.create_model(
+        model_name="vit_base_patch16_224.dino",
+        in_chans=1,
+        num_classes=0,
+        pretrained=True,
+        features_only=out_indices is not None,
+        out_indices=out_indices,
+        dynamic_img_size=dynamic_img_size,
+    )
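Note: the new `_1chan` variants pass `in_chans=1` to timm, which adapts the pretrained RGB patch-embedding weights to single-channel input, useful for grayscale modalities such as radiology slices. A usage sketch that builds the same model directly via timm (`pretrained=False` here only so the sketch runs without a weight download):

    import timm
    import torch

    model = timm.create_model(
        "vit_small_patch16_224.dino",
        in_chans=1,
        num_classes=0,
        pretrained=False,
        dynamic_img_size=True,
    )
    x = torch.randn(2, 1, 224, 224)  # a batch of single-channel images
    print(model(x).shape)  # torch.Size([2, 384]), the ViT-S embedding size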
eva/vision/models/networks/decoders/segmentation/decoder2d.py
@@ -115,4 +115,4 @@ class Decoder2D(base.Decoder):
         if self._combine_features:
             features = self._forward_features(features)
         logits = self._forward_head(features)
-        return self._upscale(logits, image_size)
+        return self._upscale(logits, image_size)  # type: ignore
eva/vision/models/networks/decoders/segmentation/linear.py
@@ -117,4 +117,4 @@ class LinearDecoder(base.Decoder):
         """
         patch_embeddings = self._forward_features(decoder_inputs.features)
         logits = self._forward_head(patch_embeddings)
-        return self._cls_seg(logits, decoder_inputs.image_size)
+        return self._cls_seg(logits, decoder_inputs.image_size)  # type: ignore
eva/vision/models/networks/decoders/segmentation/semantic/common.py
@@ -1,7 +1,7 @@
 """Common semantic segmentation decoders.
 
-This module contains implementations of different types of decoder models
-used in semantic segmentation. These decoders convert the high-level features
+This module contains implementations of different types of decoder models
+used in semantic segmentation. These decoders convert the high-level features
 output by an encoder into pixel-wise predictions for segmentation tasks.
 """
 
eva/vision/models/networks/decoders/segmentation/typings.py
@@ -11,7 +11,7 @@ class DecoderInputs(NamedTuple):
     features: List[torch.Tensor]
     """List of image features generated by the encoder from the original images."""
 
-    image_size: Tuple[int, int]
+    image_size: Tuple[int, ...]
     """Size of the original input images to be used for upsampling."""
 
     images: torch.Tensor | None = None
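Note: widening `image_size` to `Tuple[int, ...]` is what lets `DecoderInputs` carry the 3D `to_size` produced by the segmentation module above. A small sketch showing that both 2D and 3D sizes now fit the annotation (feature shapes are illustrative):

    from typing import List, NamedTuple, Tuple

    import torch


    class DecoderInputs(NamedTuple):
        """Mirrors the 0.3.0 definition in the hunk above."""

        features: List[torch.Tensor]
        image_size: Tuple[int, ...]
        images: torch.Tensor | None = None


    # 2D, as before: (H, W).
    inputs_2d = DecoderInputs([torch.randn(1, 384, 14, 14)], (224, 224))
    # 3D, now also valid: (D, H, W).
    inputs_3d = DecoderInputs([torch.randn(1, 48, 24, 24, 24)], (96, 96, 96))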