kaiko-eva: kaiko_eva-0.2.2-py3-none-any.whl → kaiko_eva-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaiko-eva might be problematic. See the package's registry page for more details.

Files changed (90):
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/functional.py +40 -43
  17. eva/core/utils/factory.py +66 -0
  18. eva/core/utils/registry.py +42 -0
  19. eva/core/utils/requirements.py +26 -0
  20. eva/language/__init__.py +13 -0
  21. eva/language/data/__init__.py +5 -0
  22. eva/language/data/datasets/__init__.py +9 -0
  23. eva/language/data/datasets/classification/__init__.py +7 -0
  24. eva/language/data/datasets/classification/base.py +63 -0
  25. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  26. eva/language/data/datasets/language.py +13 -0
  27. eva/language/models/__init__.py +25 -0
  28. eva/language/models/modules/__init__.py +5 -0
  29. eva/language/models/modules/text.py +85 -0
  30. eva/language/models/modules/typings.py +16 -0
  31. eva/language/models/wrappers/__init__.py +11 -0
  32. eva/language/models/wrappers/huggingface.py +69 -0
  33. eva/language/models/wrappers/litellm.py +77 -0
  34. eva/language/models/wrappers/vllm.py +149 -0
  35. eva/language/utils/__init__.py +5 -0
  36. eva/language/utils/str_to_int_tensor.py +95 -0
  37. eva/vision/data/dataloaders/__init__.py +2 -1
  38. eva/vision/data/dataloaders/worker_init.py +35 -0
  39. eva/vision/data/datasets/__init__.py +5 -5
  40. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  41. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  42. eva/vision/data/datasets/segmentation/consep.py +5 -4
  43. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  44. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  45. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  46. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  47. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  48. eva/vision/data/transforms/__init__.py +11 -2
  49. eva/vision/data/transforms/base/__init__.py +5 -0
  50. eva/vision/data/transforms/base/monai.py +27 -0
  51. eva/vision/data/transforms/common/__init__.py +2 -1
  52. eva/vision/data/transforms/common/squeeze.py +24 -0
  53. eva/vision/data/transforms/croppad/__init__.py +4 -0
  54. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  56. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  57. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  58. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  59. eva/vision/models/modules/semantic_segmentation.py +18 -7
  60. eva/vision/models/networks/backbones/__init__.py +2 -3
  61. eva/vision/models/networks/backbones/_utils.py +1 -1
  62. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  63. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  64. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  65. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  66. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  67. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  68. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  72. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  73. eva/vision/models/networks/backbones/registry.py +2 -44
  74. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  75. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  76. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  77. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  78. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  80. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  81. eva/vision/models/wrappers/from_registry.py +14 -9
  82. eva/vision/models/wrappers/from_timm.py +6 -5
  83. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
  84. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
  85. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
  86. eva/vision/data/datasets/segmentation/lits.py +0 -199
  87. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  88. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  89. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
  90. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,4 +1,4 @@
1
- """"Neural Network Semantic Segmentation Module."""
1
+ """Neural Network Semantic Segmentation Module."""
2
2
 
3
3
  import functools
4
4
  from typing import Any, Callable, Dict, Iterable, List, Tuple
@@ -12,7 +12,7 @@ from torch.optim import lr_scheduler
12
12
  from typing_extensions import override
13
13
 
14
14
  from eva.core.metrics import structs as metrics_lib
15
- from eva.core.models.modules import module
15
+ from eva.core.models.modules import SchedulerConfiguration, module
16
16
  from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
17
17
  from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
18
18
  from eva.core.utils import parser
@@ -32,7 +32,7 @@ class SemanticSegmentationModule(module.ModelModule):
32
32
  lr_multiplier_encoder: float = 0.0,
33
33
  inferer: Inferer | None = None,
34
34
  optimizer: OptimizerCallable = optim.AdamW,
35
- lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
35
+ lr_scheduler: LRSchedulerCallable | SchedulerConfiguration = lr_scheduler.ConstantLR,
36
36
  metrics: metrics_lib.MetricsSchema | None = None,
37
37
  postprocess: batch_postprocess.BatchPostProcess | None = None,
38
38
  save_decoder_only: bool = True,
@@ -116,7 +116,7 @@ class SemanticSegmentationModule(module.ModelModule):
116
116
  def forward(
117
117
  self,
118
118
  tensor: torch.Tensor,
119
- to_size: Tuple[int, int],
119
+ to_size: Tuple[int, ...],
120
120
  *args: Any,
121
121
  **kwargs: Any,
122
122
  ) -> torch.Tensor:
@@ -174,7 +174,8 @@ class SemanticSegmentationModule(module.ModelModule):
174
174
  The batch step output.
175
175
  """
176
176
  data, targets, metadata = INPUT_TENSOR_BATCH(*batch)
177
- predictions = self(data, to_size=targets.shape[-self.spatial_dims :])
177
+ to_size = targets.shape[-self.spatial_dims :] if self.inferer is None else None
178
+ predictions = self(data, to_size=to_size)
178
179
  loss = self.criterion(predictions, targets)
179
180
  return {
180
181
  "loss": loss,
@@ -183,11 +184,21 @@ class SemanticSegmentationModule(module.ModelModule):
183
184
  "metadata": metadata,
184
185
  }
185
186
 
186
- def _forward_networks(self, tensor: torch.Tensor, to_size: Tuple[int, int]) -> torch.Tensor:
187
+ def _forward_networks(
188
+ self, tensor: torch.Tensor, to_size: Tuple[int, ...] | None = None
189
+ ) -> torch.Tensor:
187
190
  """Passes the input tensor through the encoder and decoder."""
188
- features = self.encoder(tensor) if self.encoder else tensor
191
+ if self.encoder:
192
+ to_size = to_size or tuple(tensor.shape[-self.spatial_dims :])
193
+ features = self.encoder(tensor)
194
+ else:
195
+ if to_size is None:
196
+ raise ValueError("`to_size` must be provided when no encoder is used.")
197
+ features = tensor
198
+
189
199
  if isinstance(self.decoder, segmentation.Decoder):
190
200
  if not isinstance(features, list):
191
201
  raise ValueError(f"Expected a list of feature map tensors, got {type(features)}.")
192
202
  return self.decoder(DecoderInputs(features, to_size, tensor))
203
+
193
204
  return self.decoder(features)
@@ -1,13 +1,12 @@
1
1
  """Vision Model Backbones API."""
2
2
 
3
3
  from eva.vision.models.networks.backbones import pathology, radiology, timm, universal
4
- from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
4
+ from eva.vision.models.networks.backbones.registry import backbone_registry
5
5
 
6
6
  __all__ = [
7
7
  "radiology",
8
8
  "pathology",
9
9
  "timm",
10
10
  "universal",
11
- "BackboneModelRegistry",
12
- "register_model",
11
+ "backbone_registry",
13
12
  ]
@@ -36,7 +36,7 @@ def load_hugingface_model(
36
36
 
37
37
  return models.HuggingFaceModel(
38
38
  model_name_or_path=model_name,
39
- tensor_transforms=tensor_transforms,
39
+ transforms=tensor_transforms,
40
40
  model_kwargs=model_kwargs,
41
41
  )
42
42
 
@@ -8,10 +8,10 @@ from torch import nn
8
8
  from eva.core.models import transforms
9
9
  from eva.vision.models import wrappers
10
10
  from eva.vision.models.networks.backbones import _utils
11
- from eva.vision.models.networks.backbones.registry import register_model
11
+ from eva.vision.models.networks.backbones.registry import backbone_registry
12
12
 
13
13
 
14
- @register_model("pathology/bioptimus_h_optimus_0")
14
+ @backbone_registry.register("pathology/bioptimus_h_optimus_0")
15
15
  def bioptimus_h_optimus_0(
16
16
  dynamic_img_size: bool = True,
17
17
  out_indices: int | Tuple[int, ...] | None = None,
@@ -39,7 +39,7 @@ def bioptimus_h_optimus_0(
39
39
  )
40
40
 
41
41
 
42
- @register_model("pathology/bioptimus_h0_mini")
42
+ @backbone_registry.register("pathology/bioptimus_h0_mini")
43
43
  def bioptimus_h0_mini(
44
44
  dynamic_img_size: bool = True,
45
45
  out_indices: int | Tuple[int, ...] | None = None,
@@ -72,7 +72,7 @@ def bioptimus_h0_mini(
72
72
  "mlp_layer": timm.layers.SwiGLUPacked,
73
73
  "act_layer": nn.SiLU,
74
74
  },
75
- tensor_transforms=(
75
+ transforms=(
76
76
  transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
77
77
  if out_indices is None
78
78
  else None
@@ -5,10 +5,10 @@ from typing import Tuple
5
5
  import timm
6
6
  from torch import nn
7
7
 
8
- from eva.vision.models.networks.backbones.registry import register_model
8
+ from eva.vision.models.networks.backbones.registry import backbone_registry
9
9
 
10
10
 
11
- @register_model("pathology/prov_gigapath")
11
+ @backbone_registry.register("pathology/prov_gigapath")
12
12
  def prov_gigapath(
13
13
  dynamic_img_size: bool = True,
14
14
  out_indices: int | Tuple[int, ...] | None = None,
@@ -5,10 +5,10 @@ from typing import Tuple
5
5
  from torch import nn
6
6
 
7
7
  from eva.vision.models.networks.backbones import _utils
8
- from eva.vision.models.networks.backbones.registry import register_model
8
+ from eva.vision.models.networks.backbones.registry import backbone_registry
9
9
 
10
10
 
11
- @register_model("pathology/histai_hibou_b")
11
+ @backbone_registry.register("pathology/histai_hibou_b")
12
12
  def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
13
13
  """Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
14
14
 
@@ -30,7 +30,7 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
30
30
  )
31
31
 
32
32
 
33
- @register_model("pathology/histai_hibou_l")
33
+ @backbone_registry.register("pathology/histai_hibou_l")
34
34
  def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
35
35
  """Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
36
36
 
@@ -7,10 +7,10 @@ import timm
7
7
  from torch import nn
8
8
 
9
9
  from eva.core.models.wrappers import _utils
10
- from eva.vision.models.networks.backbones.registry import register_model
10
+ from eva.vision.models.networks.backbones.registry import backbone_registry
11
11
 
12
12
 
13
- @register_model("pathology/hkust_gpfm")
13
+ @backbone_registry.register("pathology/hkust_gpfm")
14
14
  def hkust_gpfm(
15
15
  dynamic_img_size: bool = True,
16
16
  out_indices: int | Tuple[int, ...] | None = None,
@@ -6,10 +6,10 @@ import torch
6
6
  from torch import nn
7
7
 
8
8
  from eva.vision.models.networks.backbones import _utils
9
- from eva.vision.models.networks.backbones.registry import register_model
9
+ from eva.vision.models.networks.backbones.registry import backbone_registry
10
10
 
11
11
 
12
- @register_model("pathology/kaiko_midnight_12k")
12
+ @backbone_registry.register("pathology/kaiko_midnight_12k")
13
13
  def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
14
14
  """Initializes the Midnight-12k pathology FM by kaiko.ai.
15
15
 
@@ -26,7 +26,7 @@ def kaiko_midnight_12k(out_indices: int | Tuple[int, ...] | None = None) -> nn.M
26
26
  )
27
27
 
28
28
 
29
- @register_model("pathology/kaiko_vits16")
29
+ @backbone_registry.register("pathology/kaiko_vits16")
30
30
  def kaiko_vits16(
31
31
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
32
32
  ) -> nn.Module:
@@ -49,7 +49,7 @@ def kaiko_vits16(
49
49
  )
50
50
 
51
51
 
52
- @register_model("pathology/kaiko_vits8")
52
+ @backbone_registry.register("pathology/kaiko_vits8")
53
53
  def kaiko_vits8(
54
54
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
55
55
  ) -> nn.Module:
@@ -72,7 +72,7 @@ def kaiko_vits8(
72
72
  )
73
73
 
74
74
 
75
- @register_model("pathology/kaiko_vitb16")
75
+ @backbone_registry.register("pathology/kaiko_vitb16")
76
76
  def kaiko_vitb16(
77
77
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
78
78
  ) -> nn.Module:
@@ -95,7 +95,7 @@ def kaiko_vitb16(
95
95
  )
96
96
 
97
97
 
98
- @register_model("pathology/kaiko_vitb8")
98
+ @backbone_registry.register("pathology/kaiko_vitb8")
99
99
  def kaiko_vitb8(
100
100
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
101
101
  ) -> nn.Module:
@@ -118,7 +118,7 @@ def kaiko_vitb8(
118
118
  )
119
119
 
120
120
 
121
- @register_model("pathology/kaiko_vitl14")
121
+ @backbone_registry.register("pathology/kaiko_vitl14")
122
122
  def kaiko_vitl14(
123
123
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
124
124
  ) -> nn.Module:
@@ -13,14 +13,14 @@ from typing import Tuple
13
13
  from torch import nn
14
14
 
15
15
  from eva.vision.models import wrappers
16
- from eva.vision.models.networks.backbones.registry import register_model
16
+ from eva.vision.models.networks.backbones.registry import backbone_registry
17
17
 
18
18
  VITS_URL_PREFIX = (
19
19
  "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
20
20
  )
21
21
 
22
22
 
23
- @register_model("pathology/lunit_vits16")
23
+ @backbone_registry.register("pathology/lunit_vits16")
24
24
  def lunit_vits16(
25
25
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
26
26
  ) -> nn.Module:
@@ -44,7 +44,7 @@ def lunit_vits16(
44
44
  )
45
45
 
46
46
 
47
- @register_model("pathology/lunit_vits8")
47
+ @backbone_registry.register("pathology/lunit_vits8")
48
48
  def lunit_vits8(
49
49
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
50
50
  ) -> nn.Module:
@@ -8,10 +8,10 @@ from torch import nn
8
8
 
9
9
  from eva.vision.models import wrappers
10
10
  from eva.vision.models.networks.backbones import _utils
11
- from eva.vision.models.networks.backbones.registry import register_model
11
+ from eva.vision.models.networks.backbones.registry import backbone_registry
12
12
 
13
13
 
14
- @register_model("pathology/mahmood_uni")
14
+ @backbone_registry.register("pathology/mahmood_uni")
15
15
  def mahmood_uni(
16
16
  dynamic_img_size: bool = True,
17
17
  out_indices: int | Tuple[int, ...] | None = None,
@@ -41,7 +41,7 @@ def mahmood_uni(
41
41
  )
42
42
 
43
43
 
44
- @register_model("pathology/mahmood_uni2_h")
44
+ @backbone_registry.register("pathology/mahmood_uni2_h")
45
45
  def mahmood_uni2_h(
46
46
  dynamic_img_size: bool = True,
47
47
  out_indices: int | Tuple[int, ...] | None = None,
@@ -5,10 +5,10 @@ from typing import Tuple
5
5
  from torch import nn
6
6
 
7
7
  from eva.vision.models.networks.backbones import _utils
8
- from eva.vision.models.networks.backbones.registry import register_model
8
+ from eva.vision.models.networks.backbones.registry import backbone_registry
9
9
 
10
10
 
11
- @register_model("pathology/owkin_phikon")
11
+ @backbone_registry.register("pathology/owkin_phikon")
12
12
  def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
13
13
  """Initializes the phikon pathology FM by owkin (https://huggingface.co/owkin/phikon).
14
14
 
@@ -22,7 +22,7 @@ def owkin_phikon(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
22
22
  return _utils.load_hugingface_model(model_name="owkin/phikon", out_indices=out_indices)
23
23
 
24
24
 
25
- @register_model("pathology/owkin_phikon_v2")
25
+ @backbone_registry.register("pathology/owkin_phikon_v2")
26
26
  def owkin_phikon_v2(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
27
27
  """Initializes the phikon-v2 pathology FM by owkin (https://huggingface.co/owkin/phikon-v2).
28
28
 
@@ -11,10 +11,10 @@ import torch.nn as nn
11
11
  from eva.core.models import transforms
12
12
  from eva.vision.models import wrappers
13
13
  from eva.vision.models.networks.backbones import _utils
14
- from eva.vision.models.networks.backbones.registry import register_model
14
+ from eva.vision.models.networks.backbones.registry import backbone_registry
15
15
 
16
16
 
17
- @register_model("pathology/paige_virchow2")
17
+ @backbone_registry.register("pathology/paige_virchow2")
18
18
  def paige_virchow2(
19
19
  dynamic_img_size: bool = True,
20
20
  out_indices: int | Tuple[int, ...] | None = None,
@@ -43,7 +43,7 @@ def paige_virchow2(
43
43
  "mlp_layer": timm.layers.SwiGLUPacked,
44
44
  "act_layer": nn.SiLU,
45
45
  },
46
- tensor_transforms=(
46
+ transforms=(
47
47
  transforms.ExtractCLSFeatures(include_patch_tokens=include_patch_tokens)
48
48
  if out_indices is None
49
49
  else None
@@ -9,10 +9,10 @@ from monai.networks.nets import swin_unetr
9
9
  from monai.utils import misc
10
10
  from torch import nn
11
11
 
12
- from eva.vision.models.networks.backbones.registry import register_model
12
+ from eva.vision.models.networks.backbones.registry import backbone_registry
13
13
 
14
14
 
15
- @register_model("radiology/swin_unetr_encoder")
15
+ @backbone_registry.register("radiology/swin_unetr_encoder")
16
16
  class SwinUNETREncoder(nn.Module):
17
17
  """Swin transformer encoder based on UNETR [0].
18
18
 
@@ -4,10 +4,10 @@ from typing_extensions import override
4
4
 
5
5
  from eva.core.models.wrappers import _utils
6
6
  from eva.vision.models.networks.backbones.radiology import swin_unetr
7
- from eva.vision.models.networks.backbones.registry import register_model
7
+ from eva.vision.models.networks.backbones.registry import backbone_registry
8
8
 
9
9
 
10
- class _VoCo(swin_unetr.SwinUNETREncoder):
10
+ class _VoCo(swin_unetr.SwinUNETREncoder): # type: ignore
11
11
  """Base class for the VoCo self-supervised encoders."""
12
12
 
13
13
  _checkpoint: str
@@ -39,7 +39,7 @@ class _VoCo(swin_unetr.SwinUNETREncoder):
39
39
  self.load_state_dict(state_dict)
40
40
 
41
41
 
42
- @register_model("radiology/voco_b")
42
+ @backbone_registry.register("radiology/voco_b")
43
43
  class VoCoB(_VoCo):
44
44
  """VoCo Self-supervised pre-trained B model."""
45
45
 
@@ -51,7 +51,7 @@ class VoCoB(_VoCo):
51
51
  super().__init__(feature_size=48, out_indices=out_indices)
52
52
 
53
53
 
54
- @register_model("radiology/voco_l")
54
+ @backbone_registry.register("radiology/voco_l")
55
55
  class VoCoL(_VoCo):
56
56
  """VoCo Self-supervised pre-trained L model."""
57
57
 
@@ -63,7 +63,7 @@ class VoCoL(_VoCo):
63
63
  super().__init__(feature_size=96, out_indices=out_indices)
64
64
 
65
65
 
66
- @register_model("radiology/voco_h")
66
+ @backbone_registry.register("radiology/voco_h")
67
67
  class VoCoH(_VoCo):
68
68
  """VoCo Self-supervised pre-trained H model."""
69
69
 
@@ -1,47 +1,5 @@
1
1
  """Backbone Model Registry."""
2
2
 
3
- from typing import Any, Callable, Dict, List
3
+ from eva.core.utils.registry import Registry
4
4
 
5
- import torch.nn as nn
6
-
7
-
8
- class BackboneModelRegistry:
9
- """A model registry for accessing backbone models by name."""
10
-
11
- _registry: Dict[str, Callable[..., nn.Module]] = {}
12
-
13
- @classmethod
14
- def register(cls, name: str) -> Callable:
15
- """Decorator to register a new model."""
16
-
17
- def decorator(model_fn: Callable[..., nn.Module]) -> Callable[..., nn.Module]:
18
- if name in cls._registry:
19
- raise ValueError(f"Model {name} is already registered.")
20
- cls._registry[name] = model_fn
21
- return model_fn
22
-
23
- return decorator
24
-
25
- @classmethod
26
- def get(cls, model_name: str) -> Callable[..., nn.Module]:
27
- """Gets a model function from the registry."""
28
- if model_name not in cls._registry:
29
- raise ValueError(f"Model {model_name} not found in the registry.")
30
- return cls._registry[model_name]
31
-
32
- @classmethod
33
- def load_model(cls, model_name: str, model_kwargs: Dict[str, Any] | None = None) -> nn.Module:
34
- """Loads & initializes a model class from the registry."""
35
- model_fn = cls.get(model_name)
36
- return model_fn(**(model_kwargs or {}))
37
-
38
- @classmethod
39
- def list_models(cls) -> List[str]:
40
- """List all models in the registry."""
41
- register_models = [name for name in cls._registry.keys() if not name.startswith("timm")]
42
- return register_models + ["timm/<model_name>"]
43
-
44
-
45
- def register_model(name: str) -> Callable:
46
- """Simple decorator to register a model."""
47
- return BackboneModelRegistry.register(name)
5
+ backbone_registry = Registry()
@@ -8,7 +8,7 @@ from loguru import logger
8
8
  from torch import nn
9
9
 
10
10
  from eva.vision.models import wrappers
11
- from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
11
+ from eva.vision.models.networks.backbones.registry import backbone_registry
12
12
 
13
13
 
14
14
  def timm_model(
@@ -46,7 +46,7 @@ def timm_model(
46
46
  )
47
47
 
48
48
 
49
- BackboneModelRegistry._registry.update(
49
+ backbone_registry._registry.update(
50
50
  {
51
51
  f"timm/{model_name}": functools.partial(timm_model, model_name=model_name)
52
52
  for model_name in timm.list_models()
@@ -1,8 +1,15 @@
1
1
  """Universal Vision Model Backbones API."""
2
2
 
3
3
  from eva.vision.models.networks.backbones.universal.vit import (
4
+ vit_base_patch16_224_dino_1chan,
4
5
  vit_small_patch16_224_dino,
6
+ vit_small_patch16_224_dino_1chan,
5
7
  vit_small_patch16_224_random,
6
8
  )
7
9
 
8
- __all__ = ["vit_small_patch16_224_dino", "vit_small_patch16_224_random"]
10
+ __all__ = [
11
+ "vit_small_patch16_224_dino",
12
+ "vit_small_patch16_224_random",
13
+ "vit_small_patch16_224_dino_1chan",
14
+ "vit_base_patch16_224_dino_1chan",
15
+ ]
@@ -5,10 +5,10 @@ from typing import Tuple
5
5
  import timm
6
6
  from torch import nn
7
7
 
8
- from eva.vision.models.networks.backbones.registry import register_model
8
+ from eva.vision.models.networks.backbones.registry import backbone_registry
9
9
 
10
10
 
11
- @register_model("universal/vit_small_patch16_224_random")
11
+ @backbone_registry.register("universal/vit_small_patch16_224_random")
12
12
  def vit_small_patch16_224_random(
13
13
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
14
14
  ) -> nn.Module:
@@ -31,7 +31,7 @@ def vit_small_patch16_224_random(
31
31
  )
32
32
 
33
33
 
34
- @register_model("universal/vit_small_patch16_224_dino")
34
+ @backbone_registry.register("universal/vit_small_patch16_224_dino")
35
35
  def vit_small_patch16_224_dino(
36
36
  dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
37
37
  ) -> nn.Module:
@@ -52,3 +52,53 @@ def vit_small_patch16_224_dino(
52
52
  out_indices=out_indices,
53
53
  dynamic_img_size=dynamic_img_size,
54
54
  )
55
+
56
+
57
+ @backbone_registry.register("universal/vit_small_patch16_224_dino_1chan")
58
+ def vit_small_patch16_224_dino_1chan(
59
+ dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
60
+ ) -> nn.Module:
61
+ """Initializes a ViTS-16 baseline model pretrained w/ DINO for single-channel images.
62
+
63
+ Args:
64
+ dynamic_img_size: Support different input image sizes by allowing to change
65
+ the grid size (interpolate abs and/or ROPE pos) in the forward pass.
66
+ out_indices: Whether and which multi-level patch embeddings to return.
67
+
68
+ Returns:
69
+ The torch ViTS-16 based foundation model.
70
+ """
71
+ return timm.create_model(
72
+ model_name="vit_small_patch16_224.dino",
73
+ in_chans=1,
74
+ num_classes=0,
75
+ pretrained=True,
76
+ features_only=out_indices is not None,
77
+ out_indices=out_indices,
78
+ dynamic_img_size=dynamic_img_size,
79
+ )
80
+
81
+
82
+ @backbone_registry.register("universal/vit_base_patch16_224_dino_1chan")
83
+ def vit_base_patch16_224_dino_1chan(
84
+ dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None
85
+ ) -> nn.Module:
86
+ """Initializes a ViTB-16 baseline model pretrained w/ DINO for single-channel images.
87
+
88
+ Args:
89
+ dynamic_img_size: Support different input image sizes by allowing to change
90
+ the grid size (interpolate abs and/or ROPE pos) in the forward pass.
91
+ out_indices: Whether and which multi-level patch embeddings to return.
92
+
93
+ Returns:
94
+ The torch ViTB-16 based foundation model.
95
+ """
96
+ return timm.create_model(
97
+ model_name="vit_base_patch16_224.dino",
98
+ in_chans=1,
99
+ num_classes=0,
100
+ pretrained=True,
101
+ features_only=out_indices is not None,
102
+ out_indices=out_indices,
103
+ dynamic_img_size=dynamic_img_size,
104
+ )
@@ -115,4 +115,4 @@ class Decoder2D(base.Decoder):
115
115
  if self._combine_features:
116
116
  features = self._forward_features(features)
117
117
  logits = self._forward_head(features)
118
- return self._upscale(logits, image_size)
118
+ return self._upscale(logits, image_size) # type: ignore
@@ -117,4 +117,4 @@ class LinearDecoder(base.Decoder):
117
117
  """
118
118
  patch_embeddings = self._forward_features(decoder_inputs.features)
119
119
  logits = self._forward_head(patch_embeddings)
120
- return self._cls_seg(logits, decoder_inputs.image_size)
120
+ return self._cls_seg(logits, decoder_inputs.image_size) # type: ignore
@@ -1,7 +1,7 @@
1
1
  """Common semantic segmentation decoders.
2
2
 
3
- This module contains implementations of different types of decoder models
4
- used in semantic segmentation. These decoders convert the high-level features
3
+ This module contains implementations of different types of decoder models
4
+ used in semantic segmentation. These decoders convert the high-level features
5
5
  output by an encoder into pixel-wise predictions for segmentation tasks.
6
6
  """
7
7
 
@@ -11,7 +11,7 @@ class DecoderInputs(NamedTuple):
11
11
  features: List[torch.Tensor]
12
12
  """List of image features generated by the encoder from the original images."""
13
13
 
14
- image_size: Tuple[int, int]
14
+ image_size: Tuple[int, ...]
15
15
  """Size of the original input images to be used for upsampling."""
16
16
 
17
17
  images: torch.Tensor | None = None
@@ -2,18 +2,20 @@
2
2
 
3
3
  from typing import Any, Callable, Dict
4
4
 
5
+ import torch
5
6
  from typing_extensions import override
6
7
 
7
- from eva.core.models import wrappers
8
- from eva.vision.models.networks.backbones import BackboneModelRegistry
8
+ from eva.core.models.wrappers import base
9
+ from eva.core.utils import factory
10
+ from eva.vision.models.networks.backbones import backbone_registry
9
11
 
10
12
 
11
- class ModelFromRegistry(wrappers.BaseModel):
13
+ class ModelFromRegistry(base.BaseModel[torch.Tensor, torch.Tensor]):
12
14
  """Wrapper class for vision backbone models.
13
15
 
14
16
  This class can be used by load backbones available in eva's
15
17
  model registry by name. New backbones can be registered by using
16
- the `@register_model(model_name)` decorator.
18
+ the `@backbone_registry.register(model_name)` decorator.
17
19
  """
18
20
 
19
21
  def __init__(
@@ -21,7 +23,7 @@ class ModelFromRegistry(wrappers.BaseModel):
21
23
  model_name: str,
22
24
  model_kwargs: Dict[str, Any] | None = None,
23
25
  model_extra_kwargs: Dict[str, Any] | None = None,
24
- tensor_transforms: Callable | None = None,
26
+ transforms: Callable | None = None,
25
27
  ) -> None:
26
28
  """Initializes the model.
27
29
 
@@ -29,10 +31,10 @@ class ModelFromRegistry(wrappers.BaseModel):
29
31
  model_name: The name of the model to load.
30
32
  model_kwargs: The arguments used for instantiating the model.
31
33
  model_extra_kwargs: Extra arguments used for instantiating the model.
32
- tensor_transforms: The transforms to apply to the output tensor
34
+ transforms: The transforms to apply to the output tensor
33
35
  produced by the model.
34
36
  """
35
- super().__init__(tensor_transforms=tensor_transforms)
37
+ super().__init__(transforms=transforms)
36
38
 
37
39
  self._model_name = model_name
38
40
  self._model_kwargs = model_kwargs or {}
@@ -42,7 +44,10 @@ class ModelFromRegistry(wrappers.BaseModel):
42
44
 
43
45
  @override
44
46
  def load_model(self) -> None:
45
- self._model = BackboneModelRegistry.load_model(
46
- self._model_name, self._model_kwargs | self._model_extra_kwargs
47
+ self._model = factory.ModuleFactory(
48
+ registry=backbone_registry,
49
+ name=self._model_name,
50
+ init_args=self._model_kwargs | self._model_extra_kwargs,
47
51
  )
52
+
48
53
  ModelFromRegistry.__name__ = self._model_name