kaiko-eva 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they were published in their respective public registries.

Potentially problematic release.


This version of kaiko-eva has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (131):
  1. eva/core/callbacks/config.py +15 -6
  2. eva/core/callbacks/writers/embeddings/base.py +44 -10
  3. eva/core/cli/setup.py +1 -1
  4. eva/core/data/dataloaders/__init__.py +1 -2
  5. eva/core/data/samplers/classification/balanced.py +24 -12
  6. eva/core/data/samplers/random.py +17 -10
  7. eva/core/interface/interface.py +21 -0
  8. eva/core/loggers/utils/wandb.py +4 -1
  9. eva/core/models/modules/module.py +2 -2
  10. eva/core/models/wrappers/base.py +2 -2
  11. eva/core/models/wrappers/from_function.py +3 -3
  12. eva/core/models/wrappers/from_torchhub.py +9 -7
  13. eva/core/models/wrappers/huggingface.py +4 -5
  14. eva/core/models/wrappers/onnx.py +5 -5
  15. eva/core/trainers/trainer.py +13 -1
  16. eva/core/utils/__init__.py +2 -1
  17. eva/core/utils/distributed.py +12 -0
  18. eva/core/utils/paths.py +14 -0
  19. eva/core/utils/requirements.py +52 -6
  20. eva/language/__init__.py +2 -1
  21. eva/language/callbacks/__init__.py +5 -0
  22. eva/language/callbacks/writers/__init__.py +5 -0
  23. eva/language/callbacks/writers/prediction.py +201 -0
  24. eva/language/data/dataloaders/__init__.py +5 -0
  25. eva/language/data/dataloaders/collate_fn/__init__.py +5 -0
  26. eva/language/data/dataloaders/collate_fn/text.py +57 -0
  27. eva/language/data/datasets/__init__.py +3 -1
  28. eva/language/data/datasets/{language.py → base.py} +1 -1
  29. eva/language/data/datasets/classification/base.py +3 -43
  30. eva/language/data/datasets/classification/pubmedqa.py +36 -4
  31. eva/language/data/datasets/prediction.py +151 -0
  32. eva/language/data/datasets/schemas.py +18 -0
  33. eva/language/data/datasets/text.py +92 -0
  34. eva/language/data/datasets/typings.py +39 -0
  35. eva/language/data/messages.py +60 -0
  36. eva/language/models/__init__.py +15 -11
  37. eva/language/models/modules/__init__.py +2 -2
  38. eva/language/models/modules/language.py +94 -0
  39. eva/language/models/networks/__init__.py +12 -0
  40. eva/language/models/networks/alibaba.py +26 -0
  41. eva/language/models/networks/api/__init__.py +11 -0
  42. eva/language/models/networks/api/anthropic.py +34 -0
  43. eva/language/models/networks/registry.py +5 -0
  44. eva/language/models/typings.py +56 -0
  45. eva/language/models/wrappers/__init__.py +13 -5
  46. eva/language/models/wrappers/base.py +47 -0
  47. eva/language/models/wrappers/from_registry.py +54 -0
  48. eva/language/models/wrappers/huggingface.py +57 -11
  49. eva/language/models/wrappers/litellm.py +91 -46
  50. eva/language/models/wrappers/vllm.py +37 -13
  51. eva/language/utils/__init__.py +2 -1
  52. eva/language/utils/str_to_int_tensor.py +20 -12
  53. eva/language/utils/text/__init__.py +5 -0
  54. eva/language/utils/text/messages.py +113 -0
  55. eva/multimodal/__init__.py +6 -0
  56. eva/multimodal/callbacks/__init__.py +5 -0
  57. eva/multimodal/callbacks/writers/__init__.py +5 -0
  58. eva/multimodal/callbacks/writers/prediction.py +39 -0
  59. eva/multimodal/data/__init__.py +5 -0
  60. eva/multimodal/data/dataloaders/__init__.py +5 -0
  61. eva/multimodal/data/dataloaders/collate_fn/__init__.py +5 -0
  62. eva/multimodal/data/dataloaders/collate_fn/text_image.py +28 -0
  63. eva/multimodal/data/datasets/__init__.py +6 -0
  64. eva/multimodal/data/datasets/base.py +13 -0
  65. eva/multimodal/data/datasets/multiple_choice/__init__.py +5 -0
  66. eva/multimodal/data/datasets/multiple_choice/patch_camelyon.py +80 -0
  67. eva/multimodal/data/datasets/schemas.py +14 -0
  68. eva/multimodal/data/datasets/text_image.py +77 -0
  69. eva/multimodal/data/datasets/typings.py +27 -0
  70. eva/multimodal/models/__init__.py +8 -0
  71. eva/multimodal/models/modules/__init__.py +5 -0
  72. eva/multimodal/models/modules/vision_language.py +56 -0
  73. eva/multimodal/models/networks/__init__.py +14 -0
  74. eva/multimodal/models/networks/alibaba.py +40 -0
  75. eva/multimodal/models/networks/api/__init__.py +11 -0
  76. eva/multimodal/models/networks/api/anthropic.py +34 -0
  77. eva/multimodal/models/networks/others.py +48 -0
  78. eva/multimodal/models/networks/registry.py +5 -0
  79. eva/multimodal/models/typings.py +27 -0
  80. eva/multimodal/models/wrappers/__init__.py +13 -0
  81. eva/multimodal/models/wrappers/base.py +48 -0
  82. eva/multimodal/models/wrappers/from_registry.py +54 -0
  83. eva/multimodal/models/wrappers/huggingface.py +193 -0
  84. eva/multimodal/models/wrappers/litellm.py +58 -0
  85. eva/multimodal/utils/__init__.py +1 -0
  86. eva/multimodal/utils/batch/__init__.py +5 -0
  87. eva/multimodal/utils/batch/unpack.py +11 -0
  88. eva/multimodal/utils/image/__init__.py +5 -0
  89. eva/multimodal/utils/image/encode.py +28 -0
  90. eva/multimodal/utils/text/__init__.py +1 -0
  91. eva/multimodal/utils/text/messages.py +79 -0
  92. eva/vision/data/datasets/classification/breakhis.py +5 -8
  93. eva/vision/data/datasets/classification/panda.py +12 -5
  94. eva/vision/data/datasets/classification/patch_camelyon.py +8 -6
  95. eva/vision/data/datasets/segmentation/btcv.py +1 -1
  96. eva/vision/data/datasets/segmentation/consep.py +1 -1
  97. eva/vision/data/datasets/segmentation/lits17.py +1 -1
  98. eva/vision/data/datasets/segmentation/monusac.py +15 -6
  99. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +1 -1
  100. eva/vision/data/transforms/__init__.py +2 -1
  101. eva/vision/data/transforms/base/__init__.py +2 -1
  102. eva/vision/data/transforms/base/monai.py +2 -2
  103. eva/vision/data/transforms/base/torchvision.py +33 -0
  104. eva/vision/data/transforms/common/squeeze.py +6 -3
  105. eva/vision/data/transforms/croppad/crop_foreground.py +8 -7
  106. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +6 -5
  107. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -5
  108. eva/vision/data/transforms/croppad/rand_spatial_crop.py +8 -7
  109. eva/vision/data/transforms/croppad/spatial_pad.py +6 -6
  110. eva/vision/data/transforms/intensity/rand_scale_intensity.py +3 -3
  111. eva/vision/data/transforms/intensity/rand_shift_intensity.py +3 -3
  112. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +5 -5
  113. eva/vision/data/transforms/spatial/__init__.py +2 -1
  114. eva/vision/data/transforms/spatial/flip.py +8 -7
  115. eva/vision/data/transforms/spatial/functional/__init__.py +5 -0
  116. eva/vision/data/transforms/spatial/functional/resize.py +26 -0
  117. eva/vision/data/transforms/spatial/resize.py +63 -0
  118. eva/vision/data/transforms/spatial/rotate.py +8 -7
  119. eva/vision/data/transforms/spatial/spacing.py +7 -6
  120. eva/vision/data/transforms/utility/ensure_channel_first.py +6 -6
  121. eva/vision/models/networks/backbones/universal/vit.py +24 -0
  122. eva/vision/models/wrappers/from_registry.py +6 -5
  123. eva/vision/models/wrappers/from_timm.py +6 -4
  124. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/METADATA +17 -3
  125. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/RECORD +128 -66
  126. eva/core/data/dataloaders/collate_fn/__init__.py +0 -5
  127. eva/core/data/dataloaders/collate_fn/collate.py +0 -24
  128. eva/language/models/modules/text.py +0 -85
  129. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/WHEEL +0 -0
  130. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/entry_points.txt +0 -0
  131. {kaiko_eva-0.3.3.dist-info → kaiko_eva-0.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -9,11 +9,13 @@ from typing import Any, Dict, List
9
9
  import lightning.pytorch as pl
10
10
  import yaml
11
11
  from lightning_fabric.utilities import cloud_io
12
+ from loguru import logger
12
13
  from loguru import logger as cli_logger
13
14
  from omegaconf import OmegaConf
14
15
  from typing_extensions import TypeGuard, override
15
16
 
16
17
  from eva.core import loggers
18
+ from eva.core.utils import distributed as dist_utils
17
19
 
18
20
 
19
21
  class ConfigurationLogger(pl.Callback):
@@ -39,8 +41,14 @@ class ConfigurationLogger(pl.Callback):
39
41
  pl_module: pl.LightningModule,
40
42
  stage: str | None = None,
41
43
  ) -> None:
42
- log_dir = trainer.log_dir
43
- if not _logdir_exists(log_dir):
44
+ if dist_utils.is_distributed():
45
+ logger.info("ConfigurationLogger skipped as not supported in distributed mode.")
46
+ # TODO: Enabling leads to deadlocks in DDP mode, but I could not yet figure out why.
47
+ return
48
+
49
+ if not trainer.is_global_zero or not _logdir_exists(
50
+ log_dir := trainer.log_dir, self._verbose
51
+ ):
44
52
  return
45
53
 
46
54
  configuration = _load_submitted_config()
@@ -51,6 +59,10 @@ class ConfigurationLogger(pl.Callback):
51
59
 
52
60
  save_as = os.path.join(log_dir, self._save_as)
53
61
  fs = cloud_io.get_filesystem(log_dir)
62
+
63
+ if not fs.exists(log_dir):
64
+ fs.makedirs(log_dir)
65
+
54
66
  with fs.open(save_as, "w") as output_file:
55
67
  yaml.dump(configuration, output_file, sort_keys=False)
56
68
 
@@ -126,7 +138,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
126
138
  for key, value in mapping.items():
127
139
  if isinstance(value, dict):
128
140
  formatted_value = _type_resolver(value)
129
- elif isinstance(value, list) and isinstance(value[0], dict):
141
+ elif isinstance(value, list) and value and isinstance(value[0], dict):
130
142
  formatted_value = [_type_resolver(subvalue) for subvalue in value]
131
143
  else:
132
144
  try:
@@ -134,10 +146,7 @@ def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
134
146
  formatted_value = (
135
147
  value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
136
148
  )
137
-
138
149
  except Exception:
139
150
  formatted_value = value
140
-
141
151
  mapping[key] = formatted_value
142
-
143
152
  return mapping
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Sequence
7
7
 
8
8
  import lightning.pytorch as pl
9
9
  import torch
10
+ import torch.distributed as dist
10
11
  from lightning.pytorch import callbacks
11
12
  from loguru import logger
12
13
  from torch import multiprocessing, nn
@@ -15,6 +16,7 @@ from typing_extensions import override
15
16
  from eva.core import utils
16
17
  from eva.core.callbacks.writers.embeddings.typings import QUEUE_ITEM
17
18
  from eva.core.models.modules.typings import INPUT_BATCH
19
+ from eva.core.utils import distributed as dist_utils
18
20
  from eva.core.utils import multiprocessing as eva_multiprocessing
19
21
 
20
22
 
@@ -58,8 +60,9 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
58
60
  self._save_every_n = save_every_n
59
61
  self._metadata_keys = metadata_keys or []
60
62
 
61
- self._write_queue: multiprocessing.Queue
62
- self._write_process: eva_multiprocessing.Process
63
+ self._write_queue: multiprocessing.Queue | None = None
64
+ self._write_process: eva_multiprocessing.Process | None = None
65
+ self._is_rank_zero: bool = False
63
66
 
64
67
  @staticmethod
65
68
  @abc.abstractmethod
@@ -78,9 +81,13 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
78
81
 
79
82
  @override
80
83
  def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
81
- self._check_if_exists()
82
- self._initialize_write_process()
83
- self._write_process.start()
84
+ self._is_rank_zero = trainer.is_global_zero
85
+ if self._is_rank_zero:
86
+ self._check_if_exists()
87
+ self._initialize_write_process()
88
+ if self._write_process is None or self._write_queue is None:
89
+ raise RuntimeError("Failed to initialize embedding writer process.")
90
+ self._write_process.start()
84
91
 
85
92
  if self._backbone is not None:
86
93
  self._backbone = self._backbone.to(pl_module.device)
@@ -106,6 +113,7 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
106
113
  with torch.no_grad():
107
114
  embeddings = self._get_embeddings(prediction)
108
115
 
116
+ queue_items: List[QUEUE_ITEM] = []
109
117
  for local_idx, global_idx in enumerate(batch_indices[: len(embeddings)]):
110
118
  data_name = dataset.filename(global_idx)
111
119
  save_name = os.path.splitext(data_name)[0] + ".pt"
@@ -121,15 +129,41 @@ class EmbeddingsWriter(callbacks.BasePredictionWriter, abc.ABC):
121
129
  split=split,
122
130
  metadata=item_metadata,
123
131
  )
124
- self._write_queue.put(item)
132
+ queue_items.append(item)
125
133
 
126
- self._write_process.check_exceptions()
134
+ gathered_items = self._gather_queue_items(queue_items)
135
+ if self._is_rank_zero:
136
+ for item in gathered_items:
137
+ self._write_queue.put(item) # type: ignore
138
+ self._write_process.check_exceptions() # type: ignore
127
139
 
128
140
  @override
129
141
  def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
130
- self._write_queue.put(None)
131
- self._write_process.join()
132
- logger.info(f"Predictions and manifest saved to {self._output_dir}")
142
+ if dist_utils.is_distributed():
143
+ dist.barrier()
144
+
145
+ if self._is_rank_zero and self._write_queue is not None:
146
+ self._write_queue.put(None)
147
+ if self._write_process is not None:
148
+ self._write_process.join()
149
+ logger.info(f"Predictions and manifest saved to {self._output_dir}")
150
+
151
+ def _gather_queue_items(self, items: List[QUEUE_ITEM]) -> List[QUEUE_ITEM]:
152
+ """Gather queue items across distributed ranks, returning only on rank zero."""
153
+ if not dist_utils.is_distributed():
154
+ return items
155
+
156
+ world_size = dist.get_world_size()
157
+ object_list: List[List[QUEUE_ITEM]] = [[] for _ in range(world_size)]
158
+ dist.all_gather_object(object_list, items)
159
+
160
+ if self._is_rank_zero:
161
+ gathered: List[QUEUE_ITEM] = []
162
+ for rank_items in object_list:
163
+ gathered.extend(rank_items)
164
+ return gathered
165
+
166
+ return []
133
167
 
134
168
  def _initialize_write_process(self) -> None:
135
169
  self._write_queue = multiprocessing.Queue()
eva/core/cli/setup.py CHANGED
@@ -59,7 +59,7 @@ def _initialize_logger() -> None:
59
59
  " :: <bold><level>{level}</level></bold>"
60
60
  " :: {message}",
61
61
  colorize=True,
62
- level="INFO",
62
+ level=os.getenv("LOGURU_LEVEL", "INFO"),
63
63
  )
64
64
 
65
65
 
@@ -1,6 +1,5 @@
1
1
  """Dataloaders API."""
2
2
 
3
- from eva.core.data.dataloaders.collate_fn import text_collate_fn
4
3
  from eva.core.data.dataloaders.dataloader import DataLoader
5
4
 
6
- __all__ = ["text_collate_fn", "DataLoader"]
5
+ __all__ = ["DataLoader"]
@@ -1,9 +1,10 @@
1
1
  """Random class sampler for data loading."""
2
2
 
3
3
  from collections import defaultdict
4
- from typing import Dict, Iterator, List
4
+ from typing import Dict, Iterator, List, Union
5
5
 
6
6
  import numpy as np
7
+ import torch
7
8
  from loguru import logger
8
9
  from typing_extensions import override
9
10
 
@@ -32,7 +33,7 @@ class BalancedSampler(SamplerWithDataSource[int]):
32
33
  """
33
34
  self._num_samples = num_samples
34
35
  self._replacement = replacement
35
- self._class_indices: Dict[int, List[int]] = defaultdict(list)
36
+ self._class_indices: Dict[Union[int, str], List[int]] = defaultdict(list)
36
37
  self._random_generator = np.random.default_rng(seed)
37
38
  self._indices: List[int] = []
38
39
 
@@ -62,20 +63,31 @@ class BalancedSampler(SamplerWithDataSource[int]):
62
63
  super().set_dataset(data_source)
63
64
  self._make_indices()
64
65
 
66
+ def _get_class_idx(self, idx):
67
+ """Load and validate the class index for a given sample index."""
68
+ if hasattr(self.data_source, "load_target"):
69
+ target = self.data_source.load_target(idx) # type: ignore
70
+ else:
71
+ _, target, _ = DataSample(*self.data_source[idx])
72
+
73
+ if target is None:
74
+ raise ValueError("The dataset must return non-empty targets.")
75
+
76
+ if isinstance(target, str):
77
+ return target
78
+
79
+ if isinstance(target, torch.Tensor):
80
+ if target.numel() != 1:
81
+ raise ValueError("The dataset must return a single & scalar target.")
82
+ return int(target.item())
83
+
84
+ raise ValueError("Unsupported target type. Expected str or tensor-like object.")
85
+
65
86
  def _make_indices(self):
66
87
  """Samples the indices for each class in the dataset."""
67
88
  self._class_indices.clear()
68
89
  for idx in tqdm(range(len(self.data_source)), desc="Fetching class indices for sampler"):
69
- if hasattr(self.data_source, "load_target"):
70
- target = self.data_source.load_target(idx) # type: ignore
71
- else:
72
- _, target, _ = DataSample(*self.data_source[idx])
73
- if target is None:
74
- raise ValueError("The dataset must return non-empty targets.")
75
- if target.numel() != 1:
76
- raise ValueError("The dataset must return a single & scalar target.")
77
-
78
- class_idx = int(target.item())
90
+ class_idx = self._get_class_idx(idx)
79
91
  self._class_indices[class_idx].append(idx)
80
92
 
81
93
  if not self._replacement:
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Optional
4
4
 
5
+ import torch
5
6
  from torch.utils import data
6
7
  from typing_extensions import override
7
8
 
@@ -10,30 +11,36 @@ from eva.core.data.samplers.sampler import SamplerWithDataSource
10
11
 
11
12
 
12
13
  class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
13
- """Samples elements randomly."""
14
+ """Samples elements randomly from a MapDataset."""
14
15
 
15
16
  data_source: datasets.MapDataset # type: ignore
16
17
 
17
18
  def __init__(
18
- self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
19
+ self,
20
+ replacement: bool = False,
21
+ num_samples: Optional[int] = None,
22
+ seed: Optional[int] = None,
19
23
  ) -> None:
20
- """Initializes the random sampler.
24
+ """Initialize the random sampler.
21
25
 
22
26
  Args:
23
- data_source: dataset to sample from
24
- replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
25
- num_samples: number of samples to draw, default=`len(dataset)`.
26
- generator: Generator used in sampling.
27
+ replacement: Samples are drawn on-demand with replacement if ``True``, default=``False``
28
+ num_samples: Number of samples to draw, default=``len(dataset)``.
29
+ seed: Optional seed for the random number generator.
27
30
  """
28
31
  self.replacement = replacement
29
32
  self._num_samples = num_samples
30
- self.generator = generator
33
+ self._generator = None
34
+
35
+ if seed is not None:
36
+ self._generator = torch.Generator()
37
+ self._generator.manual_seed(seed)
31
38
 
32
39
  @override
33
40
  def set_dataset(self, data_source: datasets.MapDataset) -> None:
34
41
  super().__init__(
35
42
  data_source,
36
43
  replacement=self.replacement,
37
- num_samples=self.num_samples,
38
- generator=self.generator,
44
+ num_samples=self._num_samples,
45
+ generator=self._generator,
39
46
  )
@@ -132,3 +132,24 @@ class Interface:
132
132
  n_runs=trainer.n_runs,
133
133
  verbose=trainer.n_runs > 1,
134
134
  )
135
+
136
+ def validate_test(
137
+ self,
138
+ trainer: eva_trainer.Trainer,
139
+ model: modules.ModelModule,
140
+ data: datamodules.DataModule,
141
+ ) -> None:
142
+ """Runs validation & test stages."""
143
+ if getattr(data.datasets, "val", None) is None:
144
+ raise ValueError("The provided data module does not contain a validation dataset.")
145
+ if getattr(data.datasets, "test", None) is None:
146
+ raise ValueError("The provided data module does not contain a test dataset.")
147
+
148
+ eva_trainer.run_evaluation_session(
149
+ base_trainer=trainer,
150
+ base_model=model,
151
+ datamodule=data,
152
+ stages=["validate", "test"],
153
+ n_runs=trainer.n_runs,
154
+ verbose=trainer.n_runs > 1,
155
+ )
@@ -5,6 +5,8 @@ from typing import Any, Dict
5
5
 
6
6
  from loguru import logger
7
7
 
8
+ from eva.core.utils import requirements
9
+
8
10
 
9
11
  def rename_active_run(name: str) -> None:
10
12
  """Renames the current run."""
@@ -12,7 +14,8 @@ def rename_active_run(name: str) -> None:
12
14
 
13
15
  if wandb.run:
14
16
  wandb.run.name = name
15
- wandb.run.save()
17
+ if requirements.below("wandb", "0.21.0"):
18
+ wandb.run.save()
16
19
  else:
17
20
  logger.warning("No active wandb run found that could be renamed.")
18
21
 
@@ -33,8 +33,8 @@ class ModelModule(pl.LightningModule):
33
33
  super().__init__()
34
34
 
35
35
  self._metrics = metrics or self.default_metrics
36
- self._postprocess = postprocess or self.default_postprocess
37
36
 
37
+ self.postprocess = postprocess or self.default_postprocess
38
38
  self.metrics = metrics_lib.MetricModule.from_schema(self._metrics)
39
39
 
40
40
  @property
@@ -133,7 +133,7 @@ class ModelModule(pl.LightningModule):
133
133
  Returns:
134
134
  The updated outputs.
135
135
  """
136
- self._postprocess(outputs)
136
+ self.postprocess(outputs)
137
137
  return memory.recursive_detach(outputs, to_cpu=self.metrics_device.type == "cpu")
138
138
 
139
139
  def _forward_and_log_metrics(
@@ -25,7 +25,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
25
25
 
26
26
  self._output_transforms = transforms
27
27
 
28
- self._model: Callable[..., OutputType] | nn.Module
28
+ self.model: Callable[..., OutputType] | nn.Module
29
29
 
30
30
  @override
31
31
  def forward(self, tensor: InputType) -> OutputType:
@@ -43,7 +43,7 @@ class BaseModel(nn.Module, Generic[InputType, OutputType]):
43
43
  Args:
44
44
  tensor: The input tensor to the model.
45
45
  """
46
- return self._model(tensor)
46
+ return self.model(tensor)
47
47
 
48
48
  def _apply_transforms(self, tensor: OutputType) -> OutputType:
49
49
  if self._output_transforms is not None:
@@ -41,12 +41,12 @@ class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
41
41
  self._arguments = arguments
42
42
  self._checkpoint_path = checkpoint_path
43
43
 
44
- self.load_model()
44
+ self.model = self.load_model()
45
45
 
46
46
  @override
47
- def load_model(self) -> None:
47
+ def load_model(self) -> nn.Module:
48
48
  class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
49
49
  model = class_path(**self._arguments or {})
50
50
  if self._checkpoint_path is not None:
51
51
  _utils.load_model_weights(model, self._checkpoint_path)
52
- self._model = model
52
+ return model
@@ -52,12 +52,12 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
52
52
  self._trust_repo = trust_repo
53
53
  self._model_kwargs = model_kwargs or {}
54
54
 
55
- self.load_model()
55
+ self.model = self.load_model()
56
56
 
57
57
  @override
58
- def load_model(self) -> None:
58
+ def load_model(self) -> nn.Module:
59
59
  """Builds and loads the torch.hub model."""
60
- self._model: nn.Module = torch.hub.load(
60
+ model: nn.Module = torch.hub.load(
61
61
  repo_or_dir=self._repo_or_dir,
62
62
  model=self._model_name,
63
63
  trust_repo=self._trust_repo,
@@ -66,21 +66,23 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
66
66
  ) # type: ignore
67
67
 
68
68
  if self._checkpoint_path:
69
- _utils.load_model_weights(self._model, self._checkpoint_path)
69
+ _utils.load_model_weights(model, self._checkpoint_path)
70
70
 
71
71
  TorchHubModel.__name__ = self._model_name
72
72
 
73
+ return model
74
+
73
75
  @override
74
76
  def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
75
77
  if self._out_indices is not None:
76
- if not hasattr(self._model, "get_intermediate_layers"):
78
+ if not hasattr(self.model, "get_intermediate_layers"):
77
79
  raise ValueError(
78
80
  "Only models with `get_intermediate_layers` are supported "
79
81
  "when using `out_indices`."
80
82
  )
81
83
 
82
84
  return list(
83
- self._model.get_intermediate_layers(
85
+ self.model.get_intermediate_layers( # type: ignore
84
86
  tensor,
85
87
  self._out_indices,
86
88
  reshape=True,
@@ -89,4 +91,4 @@ class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
89
91
  )
90
92
  )
91
93
 
92
- return self._model(tensor)
94
+ return self.model(tensor)
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict
4
4
 
5
5
  import torch
6
6
  import transformers
7
+ from torch import nn
7
8
  from typing_extensions import override
8
9
 
9
10
  from eva.core.models.wrappers import base
@@ -33,12 +34,10 @@ class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
33
34
  self._model_name_or_path = model_name_or_path
34
35
  self._model_kwargs = model_kwargs or {}
35
36
 
36
- self.load_model()
37
+ self.model = self.load_model()
37
38
 
38
39
  @override
39
- def load_model(self) -> None:
40
+ def load_model(self) -> nn.Module:
40
41
  # Use safetensors to avoid torch.load security vulnerability
41
42
  model_kwargs = {"use_safetensors": True, **self._model_kwargs}
42
- self._model = transformers.AutoModel.from_pretrained(
43
- self._model_name_or_path, **model_kwargs
44
- )
43
+ return transformers.AutoModel.from_pretrained(self._model_name_or_path, **model_kwargs)
@@ -30,21 +30,21 @@ class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
30
30
  self._path = path
31
31
  self._device = device
32
32
 
33
- self.load_model()
33
+ self.model = self.load_model()
34
34
 
35
35
  @override
36
36
  def load_model(self) -> Any:
37
37
  if self._device == "cuda" and not torch.cuda.is_available():
38
38
  raise ValueError("Device is set to 'cuda', but CUDA is not available.")
39
39
  provider = "CUDAExecutionProvider" if self._device == "cuda" else "CPUExecutionProvider"
40
- self._model = ort.InferenceSession(self._path, providers=[provider]) # type: ignore
40
+ return ort.InferenceSession(self._path, providers=[provider]) # type: ignore
41
41
 
42
42
  @override
43
43
  def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
44
44
  # TODO: Use IO binding to avoid copying the tensor to CPU.
45
45
  # https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
46
- if not isinstance(self._model, ort.InferenceSession):
46
+ if not isinstance(self.model, ort.InferenceSession):
47
47
  raise ValueError("Model is not loaded.")
48
- inputs = {self._model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
49
- outputs = self._model.run(None, inputs)[0]
48
+ inputs = {self.model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
49
+ outputs = self.model.run(None, inputs)[0]
50
50
  return torch.from_numpy(outputs).float().to(tensor.device)
@@ -8,6 +8,7 @@ from lightning.pytorch import loggers as pl_loggers
8
8
  from lightning.pytorch import trainer as pl_trainer
9
9
  from lightning.pytorch.utilities import argparse
10
10
  from lightning_fabric.utilities import cloud_io
11
+ from lightning_utilities.core.rank_zero import rank_zero_only
11
12
  from typing_extensions import override
12
13
 
13
14
  from eva.core import loggers as eva_loggers
@@ -30,6 +31,8 @@ class Trainer(pl_trainer.Trainer):
30
31
  default_root_dir: str = "logs",
31
32
  n_runs: int = 1,
32
33
  checkpoint_type: Literal["best", "last"] = "best",
34
+ accelerator: str = "auto",
35
+ devices: int = 1,
33
36
  **kwargs: Any,
34
37
  ) -> None:
35
38
  """Initializes the trainer.
@@ -44,9 +47,17 @@ class Trainer(pl_trainer.Trainer):
44
47
  n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session.
45
48
  checkpoint_type: Wether to load the "best" or "last" checkpoint saved by the checkpoint
46
49
  callback for evaluations on validation & test sets.
50
+ accelerator: The accelerator to use for training (e.g. "cpu", "gpu").
51
+ devices: The number of devices (GPUs) to use for training.
47
52
  kwargs: Kew-word arguments of ::class::`lightning.pytorch.Trainer`.
48
53
  """
49
- super().__init__(*args, default_root_dir=default_root_dir, **kwargs)
54
+ super().__init__(
55
+ *args,
56
+ default_root_dir=default_root_dir,
57
+ accelerator=accelerator,
58
+ devices=devices,
59
+ **kwargs,
60
+ )
50
61
 
51
62
  self.checkpoint_type = checkpoint_type
52
63
  self.n_runs = n_runs
@@ -66,6 +77,7 @@ class Trainer(pl_trainer.Trainer):
66
77
  def log_dir(self) -> str | None:
67
78
  return self.strategy.broadcast(self._log_dir)
68
79
 
80
+ @rank_zero_only
69
81
  def init_logger_run(self, run_id: int | None) -> None:
70
82
  """Setup the loggers & log directories when starting a new run.
71
83
 
@@ -3,5 +3,6 @@
3
3
  from eva.core.utils.clone import clone
4
4
  from eva.core.utils.memory import to_cpu
5
5
  from eva.core.utils.operations import numeric_sort
6
+ from eva.core.utils.paths import home_dir
6
7
 
7
- __all__ = ["clone", "to_cpu", "numeric_sort"]
8
+ __all__ = ["clone", "to_cpu", "numeric_sort", "home_dir"]
@@ -0,0 +1,12 @@
1
+ """Utility functions for distributed training."""
2
+
3
+ import torch.distributed as dist
4
+
5
+
6
+ def is_distributed() -> bool:
7
+ """Check if current environment is distributed.
8
+
9
+ Returns:
10
+ bool: True if distributed environment (e.g. multiple gpu processes).
11
+ """
12
+ return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
@@ -0,0 +1,14 @@
1
+ """Utility functions for handling paths."""
2
+
3
+ import os
4
+
5
+
6
+ def home_dir():
7
+ """Get eva's home directory for caching."""
8
+ torch_home = os.path.expanduser(
9
+ os.getenv(
10
+ "EVA_HOME",
11
+ os.path.join("~/.cache", "eva"),
12
+ )
13
+ )
14
+ return torch_home
@@ -3,10 +3,58 @@
3
3
  import importlib
4
4
  from typing import Dict
5
5
 
6
- from packaging import version
6
+ import packaging.version
7
7
 
8
8
 
9
- def check_dependencies(requirements: Dict[str, str]) -> None:
9
+ def fetch_version(name: str) -> str | None:
10
+ """Fetch the installed version of a package.
11
+
12
+ Args:
13
+ name: The name of the package.
14
+
15
+ Returns:
16
+ A string representing the installed version of the package, or None if not found.
17
+ """
18
+ try:
19
+ module = importlib.import_module(name)
20
+ return getattr(module, "__version__", None)
21
+ except ImportError:
22
+ return None
23
+
24
+
25
+ def below(name: str, version: str) -> bool:
26
+ """Check if the installed version of a package is below a certain version.
27
+
28
+ Args:
29
+ name: The name of the package.
30
+ version: The version to compare against.
31
+
32
+ Returns:
33
+ True if the installed version is below the specified version, False otherwise.
34
+ """
35
+ actual = fetch_version(name)
36
+ if actual:
37
+ return packaging.version.parse(actual) < packaging.version.parse(version)
38
+ return False
39
+
40
+
41
+ def above_or_equal(name: str, version: str) -> bool:
42
+ """Check if the installed version of a package is above a certain version.
43
+
44
+ Args:
45
+ name: The name of the package.
46
+ version: The version to compare against.
47
+
48
+ Returns:
49
+ True if the installed version is above the specified version, False otherwise.
50
+ """
51
+ actual = fetch_version(name)
52
+ if actual:
53
+ return packaging.version.parse(actual) >= packaging.version.parse(version)
54
+ return False
55
+
56
+
57
+ def check_min_versions(requirements: Dict[str, str]) -> None:
10
58
  """Check installed package versions against requirements dict.
11
59
 
12
60
  Args:
@@ -17,10 +65,8 @@ def check_dependencies(requirements: Dict[str, str]) -> None:
17
65
  ImportError: If any package does not meet the minimum required version.
18
66
  """
19
67
  for package, min_version in requirements.items():
20
- module = importlib.import_module(package)
21
- actual = getattr(module, "__version__", None)
22
- if actual and not (version.parse(actual) >= version.parse(min_version)):
68
+ if below(package, min_version):
23
69
  raise ImportError(
24
- f"Package '{package}' version {actual} does not meet "
70
+ f"Package '{package}' version {fetch_version(package)} does not meet "
25
71
  f"the minimum required version {min_version}."
26
72
  )
eva/language/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """eva language API."""
2
2
 
3
3
  try:
4
+ from eva.language import models
4
5
  from eva.language.data import datasets
5
6
  except ImportError as e:
6
7
  msg = (
@@ -10,4 +11,4 @@ except ImportError as e:
10
11
  )
11
12
  raise ImportError(str(e) + "\n\n" + msg) from e
12
13
 
13
- __all__ = ["datasets"]
14
+ __all__ = ["models", "datasets"]