kaiko-eva 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Files changed (41)
  1. eva/.DS_Store +0 -0
  2. eva/core/callbacks/__init__.py +2 -1
  3. eva/core/callbacks/config.py +143 -0
  4. eva/core/data/datasets/__init__.py +10 -2
  5. eva/core/data/datasets/embeddings/__init__.py +13 -0
  6. eva/core/data/datasets/{classification/embeddings.py → embeddings/base.py} +41 -43
  7. eva/core/data/datasets/embeddings/classification/__init__.py +10 -0
  8. eva/core/data/datasets/embeddings/classification/embeddings.py +66 -0
  9. eva/core/data/datasets/embeddings/classification/multi_embeddings.py +106 -0
  10. eva/core/data/transforms/__init__.py +3 -1
  11. eva/core/data/transforms/padding/__init__.py +5 -0
  12. eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
  13. eva/core/data/transforms/sampling/__init__.py +5 -0
  14. eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
  15. eva/core/loggers/__init__.py +7 -0
  16. eva/core/loggers/dummy.py +38 -0
  17. eva/core/loggers/experimental_loggers.py +8 -0
  18. eva/core/loggers/log/__init__.py +5 -0
  19. eva/core/loggers/log/parameters.py +64 -0
  20. eva/core/loggers/log/utils.py +13 -0
  21. eva/core/models/modules/head.py +6 -11
  22. eva/core/models/modules/module.py +25 -1
  23. eva/core/trainers/_recorder.py +69 -7
  24. eva/core/trainers/functional.py +22 -5
  25. eva/core/trainers/trainer.py +20 -6
  26. eva/vision/data/datasets/__init__.py +1 -8
  27. eva/vision/data/datasets/_utils.py +3 -3
  28. eva/vision/data/datasets/classification/__init__.py +1 -8
  29. eva/vision/data/datasets/segmentation/base.py +20 -35
  30. eva/vision/data/datasets/segmentation/total_segmentator.py +88 -69
  31. eva/vision/models/.DS_Store +0 -0
  32. eva/vision/models/networks/.DS_Store +0 -0
  33. eva/vision/utils/convert.py +24 -0
  34. eva/vision/utils/io/nifti.py +10 -6
  35. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/METADATA +51 -25
  36. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/RECORD +39 -22
  37. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/WHEEL +1 -1
  38. eva/core/data/datasets/classification/__init__.py +0 -5
  39. eva/vision/data/datasets/classification/total_segmentator.py +0 -213
  40. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/entry_points.txt +0 -0
  41. {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.0.2.dist-info}/licenses/LICENSE +0 -0
eva/.DS_Store ADDED
Binary file
eva/core/callbacks/__init__.py CHANGED
@@ -1,5 +1,6 @@
  """Callbacks API."""

+ from eva.core.callbacks.config import ConfigurationLogger
  from eva.core.callbacks.writers import EmbeddingsWriter

- __all__ = ["EmbeddingsWriter"]
+ __all__ = ["ConfigurationLogger", "EmbeddingsWriter"]
eva/core/callbacks/config.py ADDED
@@ -0,0 +1,143 @@
+ """Configuration logger callback."""
+
+ import ast
+ import os
+ import sys
+ from types import BuiltinFunctionType
+ from typing import Any, Dict, List
+
+ import lightning.pytorch as pl
+ import yaml
+ from lightning_fabric.utilities import cloud_io
+ from loguru import logger as cli_logger
+ from omegaconf import OmegaConf
+ from typing_extensions import TypeGuard, override
+
+ from eva.core import loggers
+
+
+ class ConfigurationLogger(pl.Callback):
+     """Logs the submitted configuration to the experimental logger."""
+
+     _save_as: str = "config.yaml"
+
+     def __init__(self, verbose: bool = True) -> None:
+         """Initializes the callback.
+
+         Args:
+             verbose: Whether to print the configurations to print the
+                 configuration to the terminal.
+         """
+         super().__init__()
+
+         self._verbose = verbose
+
+     @override
+     def setup(
+         self,
+         trainer: pl.Trainer,
+         pl_module: pl.LightningModule,
+         stage: str | None = None,
+     ) -> None:
+         log_dir = trainer.log_dir
+         if not _logdir_exists(log_dir):
+             return
+
+         configuration = _load_submitted_config()
+
+         if self._verbose:
+             config_as_text = yaml.dump(configuration, sort_keys=False)
+             print(f"Configuration:\033[94m\n---\n{config_as_text}\033[0m")
+
+         save_as = os.path.join(log_dir, self._save_as)
+         fs = cloud_io.get_filesystem(log_dir)
+         with fs.open(save_as, "w") as output_file:
+             yaml.dump(configuration, output_file, sort_keys=False)
+
+         loggers.log_parameters(trainer.loggers, tag="configuration", parameters=configuration)
+
+
+ def _logdir_exists(logdir: str | None, verbose: bool = True) -> TypeGuard[str]:
+     """Checks if the trainer has a log directory.
+
+     Args:
+         logdir: Trainer's logdir.
+         name: The name to log with.
+         verbose: Whether to log if it does not exist.
+
+     Returns:
+         A bool indicating if the log directory exists or not.
+     """
+     exists = isinstance(logdir, str)
+     if not exists and verbose:
+         print("\n")
+         cli_logger.warning("Log directory is `None`. Configuration file will not be logged.\n")
+     return exists
+
+
+ def _load_submitted_config() -> Dict[str, Any]:
+     """Retrieves and loads the submitted configuration.
+
+     Returns:
+         The path to the configuration file.
+     """
+     config_paths = _fetch_submitted_config_path()
+     return _load_yaml_files(config_paths)
+
+
+ def _fetch_submitted_config_path() -> List[str]:
+     """Fetches the config path from command line arguments.
+
+     Returns:
+         The path to the configuration file.
+     """
+     return list(filter(lambda f: f.endswith(".yaml"), sys.argv))
+
+
+ def _load_yaml_files(paths: List[str]) -> Dict[str, Any]:
+     """Loads yaml files and merge them from multiple paths.
+
+     Args:
+         paths: The paths to the yaml files.
+
+     Returns:
+         The merged configurations as a dictionary.
+     """
+     merged_config = {}
+     for config_path in paths:
+         fs = cloud_io.get_filesystem(config_path)
+         with fs.open(config_path, "r") as file:
+             omegaconf_file = OmegaConf.load(file)  # type: ignore
+             config_dict = OmegaConf.to_object(omegaconf_file)  # type: ignore
+             parsed_config = _type_resolver(config_dict)  # type: ignore
+             merged_config.update(parsed_config)
+     return merged_config
+
+
+ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
+     """Parses the string values of a dictionary in-place.
+
+     Args:
+         mapping: A dictionary object.
+
+     Returns:
+         The mapping with the formatted values.
+     """
+     for key, value in mapping.items():
+         if isinstance(value, dict):
+             formatted_value = _type_resolver(value)
+         elif isinstance(value, list) and isinstance(value[0], dict):
+             formatted_value = [_type_resolver(subvalue) for subvalue in value]
+         else:
+             try:
+                 parsed_value = ast.literal_eval(value)  # type: ignore
+                 formatted_value = (
+                     value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
+                 )
+
+             except Exception:
+                 formatted_value = value
+
+         mapping[key] = formatted_value
+
+     return mapping
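For orientation, a minimal sketch of how the new callback might be attached to a Lightning trainer; the trainer arguments and paths below are illustrative and not part of this release:

import lightning.pytorch as pl

from eva.core.callbacks import ConfigurationLogger

# In setup(), the callback scans sys.argv for `.yaml` arguments, merges them via
# OmegaConf, optionally prints the result, and writes `config.yaml` to trainer.log_dir.
trainer = pl.Trainer(
    default_root_dir="logs/example-run",  # illustrative path
    callbacks=[ConfigurationLogger(verbose=True)],
)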
eva/core/data/datasets/__init__.py CHANGED
@@ -1,7 +1,15 @@
  """Datasets API."""

  from eva.core.data.datasets.base import Dataset
- from eva.core.data.datasets.classification import EmbeddingsClassificationDataset
  from eva.core.data.datasets.dataset import TorchDataset
+ from eva.core.data.datasets.embeddings import (
+     EmbeddingsClassificationDataset,
+     MultiEmbeddingsClassificationDataset,
+ )

- __all__ = ["Dataset", "EmbeddingsClassificationDataset", "TorchDataset"]
+ __all__ = [
+     "Dataset",
+     "EmbeddingsClassificationDataset",
+     "MultiEmbeddingsClassificationDataset",
+     "TorchDataset",
+ ]
eva/core/data/datasets/embeddings/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """Datasets API."""
+
+ from eva.core.data.datasets.embeddings.base import EmbeddingsDataset
+ from eva.core.data.datasets.embeddings.classification import (
+     EmbeddingsClassificationDataset,
+     MultiEmbeddingsClassificationDataset,
+ )
+
+ __all__ = [
+     "EmbeddingsDataset",
+     "EmbeddingsClassificationDataset",
+     "MultiEmbeddingsClassificationDataset",
+ ]
eva/core/data/datasets/{classification/embeddings.py → embeddings/base.py} RENAMED
@@ -1,7 +1,8 @@
- """Embeddings classification dataset."""
+ """Base dataset class for Embeddings."""

+ import abc
  import os
- from typing import Callable, Dict, Tuple
+ from typing import Callable, Dict, Literal, Tuple

  import numpy as np
  import pandas as pd
@@ -11,22 +12,23 @@ from typing_extensions import override
  from eva.core.data.datasets import base
  from eva.core.utils import io

+ default_column_mapping: Dict[str, str] = {
+     "path": "embeddings",
+     "target": "target",
+     "split": "split",
+     "multi_id": "slide_id",
+ }
+ """The default column mapping of the variables to the manifest columns."""

- class EmbeddingsClassificationDataset(base.Dataset):
-     """Embeddings classification dataset."""

-     default_column_mapping: Dict[str, str] = {
-         "data": "embeddings",
-         "target": "target",
-         "split": "split",
-     }
-     """The default column mapping of the variables to the manifest columns."""
+ class EmbeddingsDataset(base.Dataset):
+     """Abstract base class for embedding datasets."""

      def __init__(
          self,
          root: str,
          manifest_file: str,
-         split: str | None = None,
+         split: Literal["train", "val", "test"] | None = None,
          column_mapping: Dict[str, str] = default_column_mapping,
          embeddings_transforms: Callable | None = None,
          target_transforms: Callable | None = None,
@@ -54,12 +56,38 @@ class EmbeddingsClassificationDataset(base.Dataset):
          self._root = root
          self._manifest_file = manifest_file
          self._split = split
-         self._column_mapping = self.default_column_mapping | column_mapping
+         self._column_mapping = default_column_mapping | column_mapping
          self._embeddings_transforms = embeddings_transforms
          self._target_transforms = target_transforms

          self._data: pd.DataFrame

+     @abc.abstractmethod
+     def _load_embeddings(self, index: int) -> torch.Tensor:
+         """Returns the `index`'th embedding sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The embedding sample as a tensor.
+         """
+
+     @abc.abstractmethod
+     def _load_target(self, index: int) -> np.ndarray:
+         """Returns the `index`'th target sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The sample target as an array.
+         """
+
+     @abc.abstractmethod
+     def __len__(self) -> int:
+         """Returns the total length of the data."""
+
      def filename(self, index: int) -> str:
          """Returns the filename of the `index`'th data sample.

@@ -71,7 +99,7 @@ class EmbeddingsClassificationDataset(base.Dataset):
          Returns:
              The filename of the `index`'th data sample.
          """
-         return self._data.at[index, self._column_mapping["data"]]
+         return self._data.at[index, self._column_mapping["path"]]

      @override
      def setup(self):
@@ -90,36 +118,6 @@ class EmbeddingsClassificationDataset(base.Dataset):
          target = self._load_target(index)
          return self._apply_transforms(embeddings, target)

-     def __len__(self) -> int:
-         """Returns the total length of the data."""
-         return len(self._data)
-
-     def _load_embeddings(self, index: int) -> torch.Tensor:
-         """Returns the `index`'th embedding sample.
-
-         Args:
-             index: The index of the data sample to load.
-
-         Returns:
-             The sample embedding as an array.
-         """
-         filename = self.filename(index)
-         embeddings_path = os.path.join(self._root, filename)
-         tensor = torch.load(embeddings_path, map_location="cpu")
-         return tensor.squeeze(0)
-
-     def _load_target(self, index: int) -> np.ndarray:
-         """Returns the `index`'th target sample.
-
-         Args:
-             index: The index of the data sample to load.
-
-         Returns:
-             The sample target as an array.
-         """
-         target = self._data.at[index, self._column_mapping["target"]]
-         return np.asarray(target, dtype=np.int64)
-
      def _load_manifest(self) -> pd.DataFrame:
          """Loads manifest file and filters the data based on the split column.

eva/core/data/datasets/embeddings/classification/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """Embedding cllassification datasets API."""
+
+ from eva.core.data.datasets.embeddings.classification.embeddings import (
+     EmbeddingsClassificationDataset,
+ )
+ from eva.core.data.datasets.embeddings.classification.multi_embeddings import (
+     MultiEmbeddingsClassificationDataset,
+ )
+
+ __all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]
eva/core/data/datasets/embeddings/classification/embeddings.py ADDED
@@ -0,0 +1,66 @@
+ """Embeddings classification dataset."""
+
+ import os
+ from typing import Callable, Dict, Literal
+
+ import numpy as np
+ import torch
+ from typing_extensions import override
+
+ from eva.core.data.datasets.embeddings import base
+
+
+ class EmbeddingsClassificationDataset(base.EmbeddingsDataset):
+     """Embeddings dataset class for classification tasks."""
+
+     def __init__(
+         self,
+         root: str,
+         manifest_file: str,
+         split: Literal["train", "val", "test"] | None = None,
+         column_mapping: Dict[str, str] = base.default_column_mapping,
+         embeddings_transforms: Callable | None = None,
+         target_transforms: Callable | None = None,
+     ) -> None:
+         """Initialize dataset.
+
+         Expects a manifest file listing the paths of .pt files that contain
+         tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+         Args:
+             root: Root directory of the dataset.
+             manifest_file: The path to the manifest file, which is relative to
+                 the `root` argument.
+             split: The dataset split to use. The `split` column of the manifest
+                 file will be splitted based on this value.
+             column_mapping: Defines the map between the variables and the manifest
+                 columns. It will overwrite the `default_column_mapping` with
+                 the provided values, so that `column_mapping` can contain only the
+                 values which are altered or missing.
+             embeddings_transforms: A function/transform that transforms the embedding.
+             target_transforms: A function/transform that transforms the target.
+         """
+         super().__init__(
+             root=root,
+             manifest_file=manifest_file,
+             split=split,
+             column_mapping=column_mapping,
+             embeddings_transforms=embeddings_transforms,
+             target_transforms=target_transforms,
+         )
+
+     @override
+     def _load_embeddings(self, index: int) -> torch.Tensor:
+         filename = self.filename(index)
+         embeddings_path = os.path.join(self._root, filename)
+         tensor = torch.load(embeddings_path, map_location="cpu")
+         return tensor.squeeze(0)
+
+     @override
+     def _load_target(self, index: int) -> np.ndarray:
+         target = self._data.at[index, self._column_mapping["target"]]
+         return np.asarray(target, dtype=np.int64)
+
+     @override
+     def __len__(self) -> int:
+         return len(self._data)
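A hypothetical instantiation of the new dataset; the paths are placeholders, and `setup()` is normally driven by the data module but is called directly here for illustration:

from eva.core.data.datasets import EmbeddingsClassificationDataset

dataset = EmbeddingsClassificationDataset(
    root="/data/embeddings",       # placeholder path
    manifest_file="manifest.csv",
    split="train",
)
dataset.setup()                 # loads the manifest and filters it by split
embedding, target = dataset[0]  # e.g. a [embedding_dim] tensor and an int64 array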
eva/core/data/datasets/embeddings/classification/multi_embeddings.py ADDED
@@ -0,0 +1,106 @@
+ """Dataset class for where a sample corresponds to multiple embeddings."""
+
+ import os
+ from typing import Callable, Dict, List, Literal
+
+ import numpy as np
+ import torch
+ from typing_extensions import override
+
+ from eva.core.data.datasets.embeddings import base
+
+
+ class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
+     """Dataset class for where a sample corresponds to multiple embeddings.
+
+     Example use case: Slide level dataset where each slide has multiple patch embeddings.
+     """
+
+     def __init__(
+         self,
+         root: str,
+         manifest_file: str,
+         split: Literal["train", "val", "test"],
+         column_mapping: Dict[str, str] = base.default_column_mapping,
+         embeddings_transforms: Callable | None = None,
+         target_transforms: Callable | None = None,
+     ):
+         """Initialize dataset.
+
+         Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
+
+         The manifest must have a `column_mapping["multi_id"]` column that contains the
+         unique identifier group of embeddings. For oncology datasets, this would be usually
+         the slide id. Each row in the manifest file points to a .pt file that can contain
+         one or multiple embeddings. There can also be multiple rows for the same `multi_id`,
+         in which case the embeddings from the different .pt files corresponding to that same
+         `multi_id` will be stacked along the first dimension.
+
+         Args:
+             root: Root directory of the dataset.
+             manifest_file: The path to the manifest file, which is relative to
+                 the `root` argument.
+             split: The dataset split to use. The `split` column of the manifest
+                 file will be splitted based on this value.
+             column_mapping: Defines the map between the variables and the manifest
+                 columns. It will overwrite the `default_column_mapping` with
+                 the provided values, so that `column_mapping` can contain only the
+                 values which are altered or missing.
+             embeddings_transforms: A function/transform that transforms the embedding.
+             target_transforms: A function/transform that transforms the target.
+         """
+         super().__init__(
+             manifest_file=manifest_file,
+             root=root,
+             split=split,
+             column_mapping=column_mapping,
+             embeddings_transforms=embeddings_transforms,
+             target_transforms=target_transforms,
+         )
+
+         self._multi_ids: List[int]
+
+     @override
+     def setup(self):
+         super().setup()
+         self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
+
+     @override
+     def _load_embeddings(self, index: int) -> torch.Tensor:
+         """Loads and stacks all embedding corresponding to the `index`'th multi_id."""
+         # Get all embeddings for the given index (multi_id)
+         multi_id = self._multi_ids[index]
+         embedding_paths = self._data.loc[
+             self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
+         ].to_list()
+         embedding_paths = [os.path.join(self._root, path) for path in embedding_paths]
+
+         # Load embeddings and stack them accross the first dimension
+         embeddings = [torch.load(path, map_location="cpu") for path in embedding_paths]
+         embeddings = torch.cat(embeddings, dim=0)
+
+         if not embeddings.ndim == 2:
+             raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
+
+         return embeddings
+
+     @override
+     def _load_target(self, index: int) -> np.ndarray:
+         """Returns the target corresponding to the `index`'th multi_id.
+
+         This method assumes that all the embeddings corresponding to the same `multi_id`
+         have the same target. If this is not the case, it will raise an error.
+         """
+         multi_id = self._multi_ids[index]
+         targets = self._data.loc[
+             self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
+         ]
+
+         if not targets.nunique() == 1:
+             raise ValueError(f"Multiple targets found for {multi_id}.")
+
+         return np.asarray(targets.iloc[0], dtype=np.int64)
+
+     @override
+     def __len__(self) -> int:
+         return len(self._data)
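And a hypothetical slide-level counterpart, where every manifest row sharing a `slide_id` contributes patch embeddings that are concatenated into one 2D tensor per sample; the paths are placeholders:

from eva.core.data.datasets import MultiEmbeddingsClassificationDataset

dataset = MultiEmbeddingsClassificationDataset(
    root="/data/slide_embeddings",  # placeholder path
    manifest_file="manifest.csv",
    split="train",
)
dataset.setup()
patch_embeddings, slide_target = dataset[0]  # e.g. shape [n_patches, embedding_dim]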
eva/core/data/transforms/__init__.py CHANGED
@@ -1,5 +1,7 @@
  """Core data transforms."""

  from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
+ from eva.core.data.transforms.padding import Pad2DTensor
+ from eva.core.data.transforms.sampling import SampleFromAxis

- __all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
+ __all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"]
eva/core/data/transforms/padding/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Padding related transformations."""
+
+ from eva.core.data.transforms.padding.pad_2d_tensor import Pad2DTensor
+
+ __all__ = ["Pad2DTensor"]
eva/core/data/transforms/padding/pad_2d_tensor.py ADDED
@@ -0,0 +1,38 @@
+ """Padding transformation for 2D tensors."""
+
+ import torch
+ import torch.nn.functional
+
+
+ class Pad2DTensor:
+     """Pads a 2D tensor to a fixed dimension accross the first dimension."""
+
+     def __init__(self, pad_size: int, pad_value: int | float = float("-inf")):
+         """Initialize the transformation.
+
+         Args:
+             pad_size: The size to pad the tensor to. If the tensor is larger than this size,
+                 no padding will be applied.
+             pad_value: The value to use for padding.
+         """
+         self._pad_size = pad_size
+         self._pad_value = pad_value
+
+     def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             tensor: The input tensor of shape [n, embedding_dim].
+
+         Returns:
+             A tensor of shape [max(n, pad_dim), embedding_dim].
+         """
+         n_pad_values = self._pad_size - tensor.size(0)
+         if n_pad_values > 0:
+             tensor = torch.nn.functional.pad(
+                 tensor,
+                 pad=(0, 0, 0, n_pad_values),
+                 mode="constant",
+                 value=self._pad_value,
+             )
+         return tensor
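A quick, self-contained illustration of the new transform; the pad size and shapes are chosen arbitrarily:

import torch

from eva.core.data.transforms import Pad2DTensor

pad = Pad2DTensor(pad_size=5)
padded = pad(torch.randn(3, 8))
print(padded.shape)  # torch.Size([5, 8]); the two appended rows hold the pad value (-inf)
# Inputs with more than `pad_size` rows are returned unchanged.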
eva/core/data/transforms/sampling/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Sampling related transformations."""
+
+ from eva.core.data.transforms.sampling.sample_from_axis import SampleFromAxis
+
+ __all__ = ["SampleFromAxis"]
eva/core/data/transforms/sampling/sample_from_axis.py ADDED
@@ -0,0 +1,40 @@
+ """Sampling transformations."""
+
+ import torch
+
+
+ class SampleFromAxis:
+     """Samples n_samples entries from a tensor along a given axis."""
+
+     def __init__(self, n_samples: int, seed: int = 42, axis: int = 0):
+         """Initialize the transformation.
+
+         Args:
+             n_samples: The number of samples to draw.
+             seed: The seed to use for sampling.
+             axis: The axis along which to sample.
+         """
+         self._seed = seed
+         self._n_samples = n_samples
+         self._axis = axis
+         self._generator = self._get_generator()
+
+     def _get_generator(self):
+         """Return a torch random generator with fixed seed."""
+         generator = torch.Generator()
+         generator.manual_seed(self._seed)
+         return generator
+
+     def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             tensor: The input tensor of shape [n, embedding_dim].
+
+         Returns:
+             A tensor of shape [n_samples, embedding_dim].
+         """
+         indices = torch.randperm(tensor.size(self._axis), generator=self._generator)[
+             : self._n_samples
+         ]
+         return tensor.index_select(self._axis, indices)
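Similarly, a small illustration of `SampleFromAxis`, which draws a seeded random subset of rows (axis 0 by default) via `torch.randperm`:

import torch

from eva.core.data.transforms import SampleFromAxis

sample = SampleFromAxis(n_samples=100, seed=42)
subset = sample(torch.randn(1000, 384))
print(subset.shape)  # torch.Size([100, 384])

Chaining sampling and padding to obtain fixed-size slide-level inputs seems to be the intended combination, though the diff itself does not state this.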
eva/core/loggers/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Experimental loggers API."""
+
+ from eva.core.loggers.dummy import DummyLogger
+ from eva.core.loggers.experimental_loggers import ExperimentalLoggers
+ from eva.core.loggers.log import log_parameters
+
+ __all__ = ["DummyLogger", "ExperimentalLoggers", "log_parameters"]
eva/core/loggers/dummy.py ADDED
@@ -0,0 +1,38 @@
+ """Dummy logger class."""
+
+ import lightning.pytorch.loggers.logger
+
+
+ class DummyLogger(lightning.pytorch.loggers.logger.DummyLogger):
+     """Dummy logger class.
+
+     This logger is currently used as a placeholder when saving results
+     to remote storage, as common lightning loggers do not work
+     with azure blob storage:
+
+     <https://github.com/Lightning-AI/pytorch-lightning/issues/18861>
+     <https://github.com/Lightning-AI/pytorch-lightning/issues/19736>
+
+     Simply disabling the loggers when pointing to remote storage doesn't work
+     because callbacks such as LearningRateMonitor or ModelCheckpoint require a
+     logger to be present.
+     """
+
+     def __init__(self, save_dir: str) -> None:
+         """Initializes the logger.
+
+         Args:
+             save_dir: The save directory (this logger does not save anything,
+                 but callbacks might use this path to save their outputs).
+         """
+         super().__init__()
+         self._save_dir = save_dir
+
+     @property
+     def save_dir(self) -> str:
+         """Returns the save directory."""
+         return self._save_dir
+
+     def __deepcopy__(self, memo=None):
+         """Override of the deepcopy method."""
+         return self
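Given the docstring above, a sketch of the intended workaround; the remote path is made up and any fsspec-style URL would play the same role:

import lightning.pytorch as pl

from eva.core.loggers import DummyLogger

# Callbacks such as ModelCheckpoint expect a logger exposing save_dir, so a DummyLogger
# pointing at the remote location is passed instead of disabling logging altogether.
trainer = pl.Trainer(logger=DummyLogger(save_dir="az://container/experiment-1"))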
eva/core/loggers/experimental_loggers.py ADDED
@@ -0,0 +1,8 @@
+ """Experiment loggers."""
+
+ from typing import Union
+
+ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
+
+ """Supported loggers."""
+ ExperimentalLoggers = Union[CSVLogger, TensorBoardLogger]
eva/core/loggers/log/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Experiment loggers actions."""
+
+ from eva.core.loggers.log.parameters import log_parameters
+
+ __all__ = ["log_parameters"]