kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/classification/base.py
@@ -1,96 +0,0 @@
- """Base for image classification datasets."""
-
- import abc
- from typing import Any, Callable, Dict, List, Tuple
-
- import torch
- from torchvision import tv_tensors
- from typing_extensions import override
-
- from eva.vision.data.datasets import vision
-
-
- class ImageClassification(vision.VisionDataset[Tuple[tv_tensors.Image, torch.Tensor]], abc.ABC):
-     """Image classification abstract dataset."""
-
-     def __init__(
-         self,
-         transforms: Callable | None = None,
-     ) -> None:
-         """Initializes the image classification dataset.
-
-         Args:
-             transforms: A function/transform which returns a transformed
-                 version of the raw data samples.
-         """
-         super().__init__()
-
-         self._transforms = transforms
-
-     @property
-     def classes(self) -> List[str] | None:
-         """Returns the list with names of the dataset names."""
-
-     @property
-     def class_to_idx(self) -> Dict[str, int] | None:
-         """Returns a mapping of the class name to its target index."""
-
-     def load_metadata(self, index: int) -> Dict[str, Any] | None:
-         """Returns the dataset metadata.
-
-         Args:
-             index: The index of the data sample to return the metadata of.
-
-         Returns:
-             The sample metadata.
-         """
-
-     @abc.abstractmethod
-     def load_image(self, index: int) -> tv_tensors.Image:
-         """Returns the `index`'th image sample.
-
-         Args:
-             index: The index of the data sample to load.
-
-         Returns:
-             The image as a numpy array.
-         """
-
-     @abc.abstractmethod
-     def load_target(self, index: int) -> torch.Tensor:
-         """Returns the `index`'th target sample.
-
-         Args:
-             index: The index of the data sample to load.
-
-         Returns:
-             The sample target as an array.
-         """
-
-     @abc.abstractmethod
-     @override
-     def __len__(self) -> int:
-         raise NotImplementedError
-
-     @override
-     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
-         image = self.load_image(index)
-         target = self.load_target(index)
-         image, target = self._apply_transforms(image, target)
-         return image, target, self.load_metadata(index) or {}
-
-     def _apply_transforms(
-         self, image: tv_tensors.Image, target: torch.Tensor
-     ) -> Tuple[tv_tensors.Image, torch.Tensor]:
-         """Applies the transforms to the provided data and returns them.
-
-         Args:
-             image: The desired image.
-             target: The target of the image.
-
-         Returns:
-             A tuple with the image and the target transformed.
-         """
-         if self._transforms is not None:
-             image, target = self._transforms(image, target)
-         return image, target
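For readers tracking this API break: concrete 0.1.8 datasets fulfilled the contract above by implementing `load_image`, `load_target` and `__len__` (in 0.2.1 they build on the reworked `eva/vision/data/datasets/vision.py` instead). A minimal sketch of a subclass against the removed 0.1.8 base; the dataset itself is invented for illustration:

import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.vision.data.datasets.classification import base  # removed in 0.2.1


class ToyClassificationDataset(base.ImageClassification):
    """Hypothetical two-sample dataset implementing the abstract methods."""

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        # A constant black RGB image; real datasets read from disk here.
        return tv_tensors.Image(torch.zeros(3, 224, 224))

    @override
    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(index % 2)

    @override
    def __len__(self) -> int:
        return 2


# __getitem__ applies the (optional) transforms and appends metadata:
image, target, metadata = ToyClassificationDataset()[0]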
eva/vision/data/datasets/segmentation/base.py
@@ -1,96 +0,0 @@
- """Base for image segmentation datasets."""
-
- import abc
- from typing import Any, Callable, Dict, List, Tuple
-
- from torchvision import tv_tensors
- from typing_extensions import override
-
- from eva.vision.data.datasets import vision
-
-
- class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
-     """Image segmentation abstract dataset."""
-
-     def __init__(self, transforms: Callable | None = None) -> None:
-         """Initializes the image segmentation base class.
-
-         Args:
-             transforms: A function/transforms that takes in an
-                 image and a label and returns the transformed versions of both.
-         """
-         super().__init__()
-
-         self._transforms = transforms
-
-     @property
-     def classes(self) -> List[str] | None:
-         """Returns the list with names of the dataset names."""
-
-     @property
-     def class_to_idx(self) -> Dict[str, int] | None:
-         """Returns a mapping of the class name to its target index."""
-
-     @abc.abstractmethod
-     def load_image(self, index: int) -> tv_tensors.Image:
-         """Loads and returns the `index`'th image sample.
-
-         Args:
-             index: The index of the data sample to load.
-
-         Returns:
-             An image torchvision tensor (channels, height, width).
-         """
-
-     @abc.abstractmethod
-     def load_mask(self, index: int) -> tv_tensors.Mask:
-         """Returns the `index`'th target masks sample.
-
-         Args:
-             index: The index of the data sample target masks to load.
-
-         Returns:
-             The semantic mask as a (H x W) shaped tensor with integer
-             values which represent the pixel class id.
-         """
-
-     def load_metadata(self, index: int) -> Dict[str, Any] | None:
-         """Returns the dataset metadata.
-
-         Args:
-             index: The index of the data sample to return the metadata of.
-                 If `None`, it will return the metadata of the current dataset.
-
-         Returns:
-             The sample metadata.
-         """
-
-     @abc.abstractmethod
-     @override
-     def __len__(self) -> int:
-         raise NotImplementedError
-
-     @override
-     def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
-         image = self.load_image(index)
-         mask = self.load_mask(index)
-         metadata = self.load_metadata(index) or {}
-         image_tensor, mask_tensor = self._apply_transforms(image, mask)
-         return image_tensor, mask_tensor, metadata
-
-     def _apply_transforms(
-         self, image: tv_tensors.Image, mask: tv_tensors.Mask
-     ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
-         """Applies the transforms to the provided data and returns them.
-
-         Args:
-             image: The desired image.
-             mask: The target segmentation mask.
-
-         Returns:
-             A tuple with the image and the masks transformed.
-         """
-         if self._transforms is not None:
-             image, mask = self._transforms(image, mask)
-
-         return image, mask
eva/vision/data/transforms/common/resize_and_clamp.py
@@ -1,51 +0,0 @@
- """Specialized transforms for resizing, clamping and range normalizing."""
-
- from typing import Callable, Sequence, Tuple
-
- from torchvision.transforms import v2
-
- from eva.vision.data.transforms import normalization
-
-
- class ResizeAndClamp(v2.Compose):
-     """Resizes, crops, clamps and normalizes an input image."""
-
-     def __init__(
-         self,
-         size: int | Sequence[int] = 224,
-         clamp_range: Tuple[int, int] = (-1024, 1024),
-         mean: Sequence[float] = (0.0, 0.0, 0.0),
-         std: Sequence[float] = (1.0, 1.0, 1.0),
-     ) -> None:
-         """Initializes the transform object.
-
-         Args:
-             size: Desired output size of the crop. If size is an `int` instead
-                 of sequence like (h, w), a square crop (size, size) is made.
-             clamp_range: The lower and upper bound to clamp the pixel values.
-             mean: Sequence of means for each image channel.
-             std: Sequence of standard deviations for each image channel.
-         """
-         self._size = size
-         self._clamp_range = clamp_range
-         self._mean = mean
-         self._std = std
-
-         super().__init__(transforms=self._build_transforms())
-
-     def _build_transforms(self) -> Sequence[Callable]:
-         """Builds and returns the list of transforms."""
-         transforms = [
-             v2.Resize(size=self._size),
-             v2.CenterCrop(size=self._size),
-             normalization.Clamp(out_range=self._clamp_range),
-             normalization.RescaleIntensity(
-                 in_range=self._clamp_range,
-                 out_range=(0.0, 1.0),
-             ),
-             v2.Normalize(
-                 mean=self._mean,
-                 std=self._std,
-             ),
-         ]
-         return transforms
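The removed class is just a `v2.Compose` preset: resize, center-crop, clamp, rescale to `[0, 1]`, then normalize. A 0.1.8-era construction sketch (assuming the `common` package re-exported the class, which the `+1 -2` change to `common/__init__.py` above suggests):

from eva.vision.data.transforms.common import ResizeAndClamp

# CT-style preset: crop to 224x224, clamp raw values to [-1024, 1024],
# rescale that range to [0, 1], then channel-wise normalize (identity here).
transform = ResizeAndClamp(size=224, clamp_range=(-1024, 1024))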
eva/vision/data/transforms/normalization/__init__.py
@@ -1,6 +0,0 @@
- """Normalization related transformations."""
-
- from eva.vision.data.transforms.normalization.clamp import Clamp
- from eva.vision.data.transforms.normalization.rescale_intensity import RescaleIntensity
-
- __all__ = ["Clamp", "RescaleIntensity"]
eva/vision/data/transforms/normalization/clamp.py
@@ -1,43 +0,0 @@
- """Image clamp transform."""
-
- import functools
- from typing import Any, Dict, Tuple
-
- import torch
- import torchvision.transforms.v2 as torch_transforms
- from torchvision import tv_tensors
- from typing_extensions import override
-
-
- class Clamp(torch_transforms.Transform):
-     """Clamps all elements in input into a specific range."""
-
-     def __init__(self, out_range: Tuple[int, int]) -> None:
-         """Initializes the transform.
-
-         Args:
-             out_range: The lower and upper bound of the range to
-                 be clamped to.
-         """
-         super().__init__()
-
-         self._out_range = out_range
-
-     @functools.singledispatchmethod
-     @override
-     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-         return inpt
-
-     @_transform.register(torch.Tensor)
-     def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
-         return torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
-
-     @_transform.register(tv_tensors.Image)
-     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-         inpt_clamp = torch.clamp(inpt, min=self._out_range[0], max=self._out_range[1])
-         return tv_tensors.wrap(inpt_clamp, like=inpt)
-
-     @_transform.register(tv_tensors.BoundingBoxes)
-     @_transform.register(tv_tensors.Mask)
-     def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
-         return inpt
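Note the dispatch pattern shared by both removed transforms: `functools.singledispatchmethod` on `_transform` clamps plain tensors and `tv_tensors.Image` inputs, while `Mask` and `BoundingBoxes` are registered as pass-throughs so labels survive the pipeline untouched. A small sanity sketch against 0.1.8:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms.normalization import Clamp  # removed in 0.2.1

clamp = Clamp(out_range=(-1024, 1024))
image = tv_tensors.Image(torch.tensor([[[-2000.0, 3000.0]]]))
mask = tv_tensors.Mask(torch.tensor([[[0, 5]]]))

out_image, out_mask = clamp(image, mask)
# out_image values are clipped to [-1024, 1024]; out_mask is unchanged.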
eva/vision/data/transforms/normalization/functional/__init__.py
@@ -1,5 +0,0 @@
- """Functional normalization related transformations API."""
-
- from eva.vision.data.transforms.normalization.functional.rescale_intensity import rescale_intensity
-
- __all__ = ["rescale_intensity"]
eva/vision/data/transforms/normalization/functional/rescale_intensity.py
@@ -1,28 +0,0 @@
- """Intensity level functions."""
-
- import sys
- from typing import Tuple
-
- import torch
-
-
- def rescale_intensity(
-     image: torch.Tensor,
-     in_range: Tuple[float, float] | None = None,
-     out_range: Tuple[float, float] = (0.0, 1.0),
- ) -> torch.Tensor:
-     """Stretches or shrinks the image intensity levels.
-
-     Args:
-         image: The image tensor as float-type.
-         in_range: The input data range. If `None`, it will
-             fetch the min and max of the input image.
-         out_range: The desired intensity range of the output.
-
-     Returns:
-         The image tensor after stretching or shrinking its intensity levels.
-     """
-     imin, imax = in_range or (image.min(), image.max())
-     omin, omax = out_range
-     image_scaled = (image - imin) / (imax - imin + sys.float_info.epsilon)
-     return image_scaled * (omax - omin) + omin
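The deleted helper is a plain linear rescaling, x -> (x - imin) / (imax - imin + eps) * (omax - omin) + omin. A quick worked check of the formula (not part of the package):

import torch

from eva.vision.data.transforms.normalization.functional import rescale_intensity

# Values spanning the CT clamp range [-1024, 1024], mapped onto [0, 1]:
image = torch.tensor([-1024.0, 0.0, 1024.0])
print(rescale_intensity(image, in_range=(-1024, 1024), out_range=(0.0, 1.0)))
# ~tensor([0.0000, 0.5000, 1.0000]), up to the epsilon in the denominator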
eva/vision/data/transforms/normalization/rescale_intensity.py
@@ -1,53 +0,0 @@
- """Intensity level scaling transform."""
-
- import functools
- from typing import Any, Dict, Tuple
-
- import torch
- import torchvision.transforms.v2 as torch_transforms
- from torchvision import tv_tensors
- from typing_extensions import override
-
- from eva.vision.data.transforms.normalization import functional
-
-
- class RescaleIntensity(torch_transforms.Transform):
-     """Stretches or shrinks the image intensity levels."""
-
-     def __init__(
-         self,
-         in_range: Tuple[float, float] | None = None,
-         out_range: Tuple[float, float] = (0.0, 1.0),
-     ) -> None:
-         """Initializes the transform.
-
-         Args:
-             in_range: The input data range. If `None`, it will
-                 fetch the min and max of the input image.
-             out_range: The desired intensity range of the output.
-         """
-         super().__init__()
-
-         self._in_range = in_range
-         self._out_range = out_range
-
-     @functools.singledispatchmethod
-     @override
-     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-         return inpt
-
-     @_transform.register(torch.Tensor)
-     def _(self, inpt: torch.Tensor, params: Dict[str, Any]) -> Any:
-         return functional.rescale_intensity(
-             inpt, in_range=self._in_range, out_range=self._out_range
-         )
-
-     @_transform.register(tv_tensors.Image)
-     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-         scaled_inpt = functional.rescale_intensity(inpt, out_range=self._out_range)
-         return tv_tensors.wrap(scaled_inpt, like=inpt)
-
-     @_transform.register(tv_tensors.BoundingBoxes)
-     @_transform.register(tv_tensors.Mask)
-     def _(self, inpt: tv_tensors.BoundingBoxes | tv_tensors.Mask, params: Dict[str, Any]) -> Any:
-         return inpt
eva/vision/metrics/segmentation/BUILD
@@ -1 +0,0 @@
- python_sources()
eva/vision/models/networks/backbones/torchhub/__init__.py
@@ -1,5 +0,0 @@
- """torch.hub backbones API."""
-
- from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model
-
- __all__ = ["torch_hub_model"]
eva/vision/models/networks/backbones/torchhub/backbones.py
@@ -1,61 +0,0 @@
- """torch.hub backbones."""
-
- import functools
- from typing import Tuple
-
- import torch
- from loguru import logger
- from torch import nn
-
- from eva.core.models import wrappers
- from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
-
- HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"]
- """List of torch.hub repositories for which to add the models to the registry."""
-
-
- def torch_hub_model(
-     model_name: str,
-     repo_or_dir: str,
-     checkpoint_path: str | None = None,
-     pretrained: bool = False,
-     out_indices: int | Tuple[int, ...] | None = None,
-     **kwargs,
- ) -> nn.Module:
-     """Initializes any ViT model from torch.hub with weights from a specified checkpoint.
-
-     Args:
-         model_name: The name of the model to load.
-         repo_or_dir: The torch.hub repository or local directory to load the model from.
-         checkpoint_path: The path to the checkpoint file.
-         pretrained: If set to `True`, load pretrained model weights if available.
-         out_indices: Whether and which multi-level patch embeddings to return.
-         **kwargs: Additional arguments to pass to the model
-
-     Returns:
-         The VIT model instance.
-     """
-     logger.info(
-         f"Loading torch.hub model {model_name} from {repo_or_dir}"
-         + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "")
-     )
-
-     return wrappers.TorchHubModel(
-         model_name=model_name,
-         repo_or_dir=repo_or_dir,
-         pretrained=pretrained,
-         checkpoint_path=checkpoint_path or "",
-         out_indices=out_indices,
-         model_kwargs=kwargs,
-     )
-
-
- BackboneModelRegistry._registry.update(
-     {
-         f"torchhub/{repo}:{model_name}": functools.partial(
-             torch_hub_model, model_name=model_name, repo_or_dir=repo
-         )
-         for repo in HUB_REPOS
-         for model_name in torch.hub.list(repo, verbose=False)
-     }
- )
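Before its removal, this module both exposed `torch_hub_model` directly and pre-registered every model of the two `HUB_REPOS` under a `torchhub/{repo}:{model_name}` key at import time (note that the comprehension ran `torch.hub.list` on import, so importing it required network access). A 0.1.8-era usage sketch; `dinov2_vits14` is one model name published by `facebookresearch/dinov2`, used here for illustration:

from eva.vision.models.networks.backbones.torchhub import torch_hub_model

# Wraps the hub model in wrappers.TorchHubModel; extra kwargs go to the model.
backbone = torch_hub_model(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2:main",
    pretrained=True,
)

# Equivalent registry key built by the comprehension above:
# "torchhub/facebookresearch/dinov2:main:dinov2_vits14"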