kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (90)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/functional.py +40 -43
  17. eva/core/utils/factory.py +66 -0
  18. eva/core/utils/registry.py +42 -0
  19. eva/core/utils/requirements.py +26 -0
  20. eva/language/__init__.py +13 -0
  21. eva/language/data/__init__.py +5 -0
  22. eva/language/data/datasets/__init__.py +9 -0
  23. eva/language/data/datasets/classification/__init__.py +7 -0
  24. eva/language/data/datasets/classification/base.py +63 -0
  25. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  26. eva/language/data/datasets/language.py +13 -0
  27. eva/language/models/__init__.py +25 -0
  28. eva/language/models/modules/__init__.py +5 -0
  29. eva/language/models/modules/text.py +85 -0
  30. eva/language/models/modules/typings.py +16 -0
  31. eva/language/models/wrappers/__init__.py +11 -0
  32. eva/language/models/wrappers/huggingface.py +69 -0
  33. eva/language/models/wrappers/litellm.py +77 -0
  34. eva/language/models/wrappers/vllm.py +149 -0
  35. eva/language/utils/__init__.py +5 -0
  36. eva/language/utils/str_to_int_tensor.py +95 -0
  37. eva/vision/data/dataloaders/__init__.py +2 -1
  38. eva/vision/data/dataloaders/worker_init.py +35 -0
  39. eva/vision/data/datasets/__init__.py +5 -5
  40. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  41. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  42. eva/vision/data/datasets/segmentation/consep.py +5 -4
  43. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  44. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  45. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  46. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  47. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  48. eva/vision/data/transforms/__init__.py +11 -2
  49. eva/vision/data/transforms/base/__init__.py +5 -0
  50. eva/vision/data/transforms/base/monai.py +27 -0
  51. eva/vision/data/transforms/common/__init__.py +2 -1
  52. eva/vision/data/transforms/common/squeeze.py +24 -0
  53. eva/vision/data/transforms/croppad/__init__.py +4 -0
  54. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  56. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  57. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  58. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  59. eva/vision/models/modules/semantic_segmentation.py +18 -7
  60. eva/vision/models/networks/backbones/__init__.py +2 -3
  61. eva/vision/models/networks/backbones/_utils.py +1 -1
  62. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  63. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  64. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  65. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  66. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  67. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  68. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  72. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  73. eva/vision/models/networks/backbones/registry.py +2 -44
  74. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  75. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  76. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  77. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  78. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  80. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  81. eva/vision/models/wrappers/from_registry.py +14 -9
  82. eva/vision/models/wrappers/from_timm.py +6 -5
  83. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
  84. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
  85. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
  86. eva/vision/data/datasets/segmentation/lits.py +0 -199
  87. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  88. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  89. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
  90. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0

eva/core/data/dataloaders/__init__.py
@@ -1,5 +1,6 @@
 """Dataloaders API."""
 
+from eva.core.data.dataloaders.collate_fn import text_collate_fn
 from eva.core.data.dataloaders.dataloader import DataLoader
 
-__all__ = ["DataLoader"]
+__all__ = ["text_collate_fn", "DataLoader"]

eva/core/data/dataloaders/collate_fn/__init__.py
@@ -0,0 +1,5 @@
+"""Collate functions API."""
+
+from eva.core.data.dataloaders.collate_fn.collate import text_collate_fn
+
+__all__ = ["text_collate_fn"]

eva/core/data/dataloaders/collate_fn/collate.py
@@ -0,0 +1,24 @@
+"""Collate functions for text data."""
+
+from typing import Dict, List, Tuple
+
+import torch
+
+
+def text_collate_fn(
+    batch: List[Tuple[str, torch.Tensor, Dict]],
+) -> Tuple[List[str], torch.Tensor, List[Dict]]:
+    """Collate function for text data that keeps texts as separate strings.
+
+    Args:
+        batch: List of tuples containing (text, target, metadata) from the dataset
+
+    Returns:
+        Tuple containing:
+            - List of text strings
+            - Batched tensor of targets
+            - List of metadata dictionaries
+    """
+    texts, targets, metadata = zip(*batch, strict=False)
+    targets = torch.stack(targets)
+    return list(texts), targets, list(metadata)
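
For illustration, a minimal sketch of the new collate function (assuming kaiko-eva >= 0.3.1 and torch are installed; the sample batch is made up):

import torch

from eva.core.data.dataloaders import text_collate_fn

# Each dataset item is a (text, target, metadata) tuple, matching the signature above.
batch = [
    ("first sample text", torch.tensor(0), {"id": 1}),
    ("second sample text", torch.tensor(1), {"id": 2}),
]
texts, targets, metadata = text_collate_fn(batch)
print(texts)     # ['first sample text', 'second sample text']
print(targets)   # tensor([0, 1])
print(metadata)  # [{'id': 1}, {'id': 2}]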

eva/core/data/dataloaders/dataloader.py
@@ -56,6 +56,9 @@ class DataLoader:
     persistent_workers: bool = True
     """Will keep the worker processes after a dataset has been consumed once."""
 
+    worker_init_fn: Callable | None = None
+    """Function to call on each worker process before data loading."""
+
     prefetch_factor: int | None = 2
     """Number of batches loaded in advance by each worker."""
 
@@ -80,4 +83,5 @@ class DataLoader:
             drop_last=self.drop_last,
             persistent_workers=self.persistent_workers,
             prefetch_factor=self.prefetch_factor,
+            worker_init_fn=self.worker_init_fn,
         )
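
A sketch of wiring a worker initialisation function into the new field (assuming kaiko-eva >= 0.3.1; `seed_worker` is an illustrative helper, not part of the package, and the remaining DataLoader fields are assumed to keep their defaults):

import random

import numpy as np
import torch

from eva.core.data.dataloaders import DataLoader


def seed_worker(worker_id: int) -> None:
    """Illustrative helper: re-seeds python and numpy in every worker process."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


dataloader_config = DataLoader(worker_init_fn=seed_worker)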

eva/core/interface/interface.py
@@ -34,7 +34,14 @@ class Interface:
             model: The model module to use but not modify.
             data: The data module.
         """
-        trainer.run_evaluation_session(model=model, datamodule=data)
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["fit", "validate", "test"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )
 
     def predict(
         self,
@@ -77,3 +84,29 @@ class Interface:
         """
         self.predict(trainer=trainer, model=model, data=data)
        self.fit(trainer=trainer, model=model, data=data)
+
+    def validate(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Perform model validation out-of-place without running fit.
+
+        This method is useful when the model is already trained or does not
+        require further training (e.g., large language models) and you only
+        want to measure performance.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module containing validation data.
+        """
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["validate"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )

eva/core/metrics/defaults/classification/multiclass.py
@@ -17,6 +17,7 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
         ignore_index: int | None = None,
         prefix: str | None = None,
         postfix: str | None = None,
+        input_type: Literal["logits", "discrete"] = "logits",
     ) -> None:
         """Initializes the multi-class classification metrics.
 
@@ -27,46 +28,55 @@
                 contribute to the metric calculation.
             prefix: A string to append in front of the keys of the output dict.
             postfix: A string to append after the keys of the output dict.
+            input_type: Type of input predictions - "logits" for probabilities/logits
+                or "discrete" for discrete class predictions. Determines which metrics
+                are applicable.
         """
-        super().__init__(
-            metrics=[
+        metrics = [
+            classification.MulticlassAccuracy(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassF1Score(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassPrecision(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassRecall(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+        ]
+
+        compute_groups = [
+            [
+                "MulticlassAccuracy",
+                "MulticlassF1Score",
+                "MulticlassPrecision",
+                "MulticlassRecall",
+            ]
+        ]
+
+        if input_type == "logits":
+            metrics.append(
                 classification.MulticlassAUROC(
                     num_classes=num_classes,
                     average=average,
                     ignore_index=ignore_index,
-                ),
-                classification.MulticlassAccuracy(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassF1Score(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassPrecision(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassRecall(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-            ],
+                )
+            )
+            compute_groups.append(["MulticlassAUROC"])
+
+        super().__init__(
+            metrics=metrics,
             prefix=prefix,
             postfix=postfix,
-            compute_groups=[
-                [
-                    "MulticlassAccuracy",
-                    "MulticlassF1Score",
-                    "MulticlassPrecision",
-                    "MulticlassRecall",
-                ],
-                [
-                    "MulticlassAUROC",
-                ],
-            ],
+            compute_groups=compute_groups,
         )
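
For illustration, a minimal sketch of the new `input_type` switch (assuming kaiko-eva >= 0.3.1; the import path follows the file layout listed above):

import torch

from eva.core.metrics.defaults.classification.multiclass import (
    MulticlassClassificationMetrics,
)

# With "discrete" predictions (e.g. parsed LLM answers), AUROC is left out,
# so integer class predictions can be scored directly.
metrics = MulticlassClassificationMetrics(num_classes=3, input_type="discrete")
metrics.update(torch.tensor([0, 2, 1, 2]), torch.tensor([0, 1, 1, 2]))
print(metrics.compute())  # accuracy / F1 / precision / recall only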

eva/core/models/modules/__init__.py
@@ -3,5 +3,6 @@
 from eva.core.models.modules.head import HeadModule
 from eva.core.models.modules.inference import InferenceModule
 from eva.core.models.modules.module import ModelModule
+from eva.core.models.modules.scheduler import SchedulerConfiguration
 
-__all__ = ["HeadModule", "ModelModule", "InferenceModule"]
+__all__ = ["HeadModule", "ModelModule", "InferenceModule", "SchedulerConfiguration"]

eva/core/models/modules/scheduler.py
@@ -0,0 +1,51 @@
+"""Learning Rate scheduler configuration."""
+
+import dataclasses
+from typing import Any, Literal
+
+from lightning.pytorch.cli import LRSchedulerCallable
+from torch import optim
+
+
+@dataclasses.dataclass
+class SchedulerConfiguration:
+    """Initializes and builds the learning rate scheduler configuration."""
+
+    scheduler: LRSchedulerCallable
+    """The learning rate scheduler instance."""
+
+    interval: Literal["step", "epoch"] = "epoch"
+    """The unit of the scheduler's step size.
+
+    It can be 'step' or 'epoch', to update the scheduler on step or epoch end respectively.
+    """
+
+    frequency: int = 1
+    """How many epochs/steps should pass between calls to `scheduler.step()`.
+
+    Value `1` corresponds to updating the learning rate after every epoch/step.
+    """
+
+    monitor: str = "val_loss"
+    """Metric to to monitor for schedulers like `ReduceLROnPlateau`."""
+
+    strict: bool = True
+    """Whether to enforce that the value specified 'monitor' must be available.
+
+    If the values is not available when the scheduler is updated it will stop the
+    training. With `False`, it will only produce a warning.
+    """
+
+    name: str | None = None
+    """Specifies a custom logged name for the `LearningRateMonitor` callback."""
+
+    def __call__(self, optimizer: optim.Optimizer) -> dict[str, Any]:
+        """Returns Lightning's lr_scheduler_config configuration."""
+        return {
+            "scheduler": self.scheduler(optimizer),
+            "interval": self.interval,
+            "frequency": self.frequency,
+            "monitor": self.monitor,
+            "strict": self.strict,
+            "name": self.name,
+        }
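
A minimal usage sketch (assuming kaiko-eva >= 0.3.1; the linear model and the StepLR choice are only illustrative):

import functools

from torch import nn, optim

from eva.core.models.modules import SchedulerConfiguration

model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# `scheduler` is a callable that builds the scheduler from an optimizer,
# matching Lightning's LRSchedulerCallable.
lr_config = SchedulerConfiguration(
    scheduler=functools.partial(optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    interval="epoch",
)
print(lr_config(optimizer))  # Lightning-style lr_scheduler_config dict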

eva/core/models/transforms/extract_cls_features.py
@@ -31,7 +31,7 @@ class ExtractCLSFeatures:
             tensor: The tensor representing the model output.
         """
         if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            tensor = tensor.last_hidden_state
+            tensor = tensor.last_hidden_state  # type: ignore
 
         cls_token = tensor[:, self._cls_index, :]
         if self._include_patch_tokens:

eva/core/models/transforms/extract_patch_features.py
@@ -43,7 +43,7 @@ class ExtractPatchFeatures:
         """
         num_skip = int(self._has_cls_token) + self._num_register_tokens
         if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)
+            features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)  # type: ignore
         else:
             features = tensor[:, num_skip:, :].permute(0, 2, 1)
 

eva/core/models/wrappers/base.py
@@ -1,40 +1,43 @@
 """Base class for model wrappers."""
 
 import abc
-from typing import Callable
+from typing import Callable, Generic, TypeVar
 
-import torch
 import torch.nn as nn
 from typing_extensions import override
 
+InputType = TypeVar("InputType")
+"""The input data type."""
+OutputType = TypeVar("OutputType")
+"""The output data type."""
 
-class BaseModel(nn.Module):
+
+class BaseModel(nn.Module, Generic[InputType, OutputType]):
     """Base class for model wrappers."""
 
-    def __init__(self, tensor_transforms: Callable | None = None) -> None:
+    def __init__(self, transforms: Callable | None = None) -> None:
         """Initializes the model.
 
         Args:
-            tensor_transforms: The transforms to apply to the output
-                tensor produced by the model.
+            transforms: The transforms to apply to the output produced by the model.
         """
         super().__init__()
 
-        self._output_transforms = tensor_transforms
+        self._output_transforms = transforms
 
-        self._model: Callable[..., torch.Tensor] | nn.Module
+        self._model: Callable[..., OutputType] | nn.Module
 
     @override
-    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        tensor = self.model_forward(tensor)
-        return self._apply_transforms(tensor)
+    def forward(self, tensor: InputType) -> OutputType:
+        out = self.model_forward(tensor)
+        return self._apply_transforms(out)
 
     @abc.abstractmethod
-    def load_model(self) -> Callable[..., torch.Tensor]:
+    def load_model(self) -> Callable[..., OutputType]:
         """Loads the model."""
         raise NotImplementedError
 
-    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+    def model_forward(self, tensor: InputType) -> OutputType:
         """Implements the forward pass of the model.
 
         Args:
@@ -42,7 +45,7 @@ class BaseModel(nn.Module):
         """
         return self._model(tensor)
 
-    def _apply_transforms(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _apply_transforms(self, tensor: OutputType) -> OutputType:
         if self._output_transforms is not None:
             tensor = self._output_transforms(tensor)
         return tensor
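
To illustrate the new generic base, a toy wrapper parameterised over its input and output types (a sketch only; `ToyTextModel` is not part of the package and assumes kaiko-eva >= 0.3.1):

from typing import Callable, List

import torch
from typing_extensions import override

from eva.core.models.wrappers import base


class ToyTextModel(base.BaseModel[List[str], torch.Tensor]):
    """Illustrative wrapper that maps a list of strings to their lengths."""

    def __init__(self, transforms: Callable | None = None) -> None:
        super().__init__(transforms=transforms)
        self.load_model()

    @override
    def load_model(self) -> None:
        # Any callable works as the wrapped "model"; here a plain lambda.
        self._model = lambda texts: torch.tensor([len(text) for text in texts])


model = ToyTextModel()
print(model(["short", "a bit longer"]))  # tensor([ 5, 12])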

eva/core/models/wrappers/from_function.py
@@ -3,13 +3,14 @@
 from typing import Any, Callable, Dict
 
 import jsonargparse
+import torch
 from torch import nn
 from typing_extensions import override
 
 from eva.core.models.wrappers import _utils, base
 
 
-class ModelFromFunction(base.BaseModel):
+class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for models which are initialized from functions.
 
     This is helpful for initializing models in a `.yaml` configuration file.
@@ -20,7 +21,7 @@ class ModelFromFunction(base.BaseModel):
         path: Callable[..., nn.Module],
         arguments: Dict[str, Any] | None = None,
         checkpoint_path: str | None = None,
-        tensor_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes and constructs the model.
 
@@ -31,10 +32,10 @@
                 weights from. This is currently only supported for torch
                 model checkpoints. For other formats, the checkpoint loading
                 should be handled within the provided callable object in <path>.
-            tensor_transforms: The transforms to apply to the output tensor
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__(tensor_transforms=tensor_transforms)
+        super().__init__(transforms=transforms)
 
         self._path = path
         self._arguments = arguments

eva/core/models/wrappers/from_torchhub.py
@@ -6,11 +6,10 @@ import torch
 import torch.nn as nn
 from typing_extensions import override
 
-from eva.core.models import wrappers
-from eva.core.models.wrappers import _utils
+from eva.core.models.wrappers import _utils, base
 
 
-class TorchHubModel(wrappers.BaseModel):
+class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Model wrapper for `torch.hub` models."""
 
     def __init__(
@@ -23,7 +22,7 @@ class TorchHubModel(wrappers.BaseModel):
         norm: bool = False,
         trust_repo: bool = True,
         model_kwargs: Dict[str, Any] | None = None,
-        tensor_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the encoder.
 
@@ -39,10 +38,10 @@
             trust_repo: If set to `False`, a prompt will ask the user whether the
                 repo should be trusted.
             model_kwargs: Extra model arguments.
-            tensor_transforms: The transforms to apply to the output tensor
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__(tensor_transforms=tensor_transforms)
+        super().__init__(transforms=transforms)
 
         self._model_name = model_name
         self._repo_or_dir = repo_or_dir

eva/core/models/wrappers/huggingface.py
@@ -2,19 +2,20 @@
 
 from typing import Any, Callable, Dict
 
+import torch
 import transformers
 from typing_extensions import override
 
 from eva.core.models.wrappers import base
 
 
-class HuggingFaceModel(base.BaseModel):
+class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for loading HuggingFace `transformers` models."""
 
     def __init__(
         self,
         model_name_or_path: str,
-        tensor_transforms: Callable | None = None,
+        transforms: Callable | None = None,
         model_kwargs: Dict[str, Any] | None = None,
     ) -> None:
         """Initializes the model.
@@ -23,11 +24,11 @@ class HuggingFaceModel(base.BaseModel):
             model_name_or_path: The model name or path to load the model from.
                 This can be a local path or a model name from the `HuggingFace`
                 model hub.
-            tensor_transforms: The transforms to apply to the output tensor
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
             model_kwargs: The arguments used for instantiating the model.
         """
-        super().__init__(tensor_transforms=tensor_transforms)
+        super().__init__(transforms=transforms)
 
         self._model_name_or_path = model_name_or_path
         self._model_kwargs = model_kwargs or {}
@@ -36,6 +37,8 @@ class HuggingFaceModel(base.BaseModel):
 
     @override
     def load_model(self) -> None:
+        # Use safetensors to avoid torch.load security vulnerability
+        model_kwargs = {"use_safetensors": True, **self._model_kwargs}
         self._model = transformers.AutoModel.from_pretrained(
-            self._model_name_or_path, **self._model_kwargs
+            self._model_name_or_path, **model_kwargs
         )

eva/core/models/wrappers/onnx.py
@@ -9,23 +9,23 @@ from typing_extensions import override
 from eva.core.models.wrappers import base
 
 
-class ONNXModel(base.BaseModel):
+class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for loading ONNX models."""
 
     def __init__(
         self,
         path: str,
         device: Literal["cpu", "cuda"] | None = "cpu",
-        tensor_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ):
         """Initializes the model.
 
         Args:
             path: The path to the .onnx model file.
             device: The device to run the model on. This can be either "cpu" or "cuda".
-            tensor_transforms: The transforms to apply to the output tensor produced by the model.
+            transforms: The transforms to apply to the output tensor produced by the model.
         """
-        super().__init__(tensor_transforms=tensor_transforms)
+        super().__init__(transforms=transforms)
 
         self._path = path
         self._device = device

eva/core/trainers/functional.py
@@ -1,6 +1,6 @@
 """Fit session related functions."""
 
-from typing import Tuple
+from typing import List, Literal, Tuple
 
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
 
@@ -16,11 +16,12 @@ def run_evaluation_session(
     datamodule: datamodules.DataModule,
     *,
     n_runs: int = 1,
+    stages: List[Literal["fit", "validate", "test"]] | None = None,
     verbose: bool = True,
 ) -> None:
     """Runs a downstream evaluation session out-of-place.
 
-    It performs an evaluation run (fit and evaluate) on the model
+    It performs an evaluation run (with configurable stages) on the model
     multiple times. Note that as the input `base_trainer` and
     `base_model` would be cloned, the input object would not
     be modified.
@@ -29,10 +30,13 @@
         base_trainer: The base trainer module to use.
         base_model: The base model module to use.
         datamodule: The data module.
-        n_runs: The amount of runs (fit and evaluate) to perform.
+        n_runs: The number of runs to perform.
+        stages: List of stages to execute. Options: "fit", "validate", "test".
         verbose: Whether to verbose the session metrics instead of
-            these of each individual runs and vice-versa.
+            those of each individual run and vice-versa.
     """
+    if not stages:
+        stages = ["fit", "validate", "test"]
     recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
     for run_index in range(n_runs):
         validation_scores, test_scores = run_evaluation(
@@ -40,9 +44,11 @@
             base_model,
             datamodule,
             run_id=run_index,
+            stages=stages,
             verbose=not verbose,
         )
-        recorder.update(validation_scores, test_scores)
+        if validation_scores:
+            recorder.update(validation_scores, test_scores)
     recorder.save()
 
 
@@ -52,61 +58,52 @@ def run_evaluation(
     datamodule: datamodules.DataModule,
     *,
     run_id: int | None = None,
+    stages: List[Literal["fit", "validate", "test"]] | None = None,
     verbose: bool = True,
-) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
-    """Fits and evaluates a model out-of-place.
+) -> Tuple[_EVALUATE_OUTPUT | None, _EVALUATE_OUTPUT | None]:
+    """Runs the specified evaluation stages out-of-place.
 
     Args:
         base_trainer: The base trainer to use but not modify.
         base_model: The model module to use but not modify.
         datamodule: The data module.
         run_id: The run id to be appended to the output log directory.
+            If `None`, it will use the log directory of the trainer as is.
+        stages: List of stages to execute. Options: "fit", "validate", "test".
         verbose: Whether to print the validation and test metrics
             in the end of the training.
 
     Returns:
-        A tuple of with the validation and the test metrics (if exists).
+        A tuple with the validation and the test metrics (if executed).
+        If a stage is not executed, its value will be None.
     """
+    if not stages:
+        stages = ["fit", "validate", "test"]
    trainer, model = _utils.clone(base_trainer, base_model)
    model.configure_model()
 
    trainer.init_logger_run(run_id)
-    results = fit_and_validate(trainer, model, datamodule, verbose=verbose)
-    trainer.finish_logger_run(run_id)
-
-    return results
-
-
-def fit_and_validate(
-    trainer: eva_trainer.Trainer,
-    model: modules.ModelModule,
-    datamodule: datamodules.DataModule,
-    verbose: bool = True,
-) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
-    """Fits and evaluates a model in-place.
-
-    If the test set is set in the datamodule, it will evaluate the model
-    on the test set as well.
 
-    Args:
-        trainer: The trainer module to use and update in-place.
-        model: The model module to use and update in-place.
-        datamodule: The data module.
-        verbose: Whether to print the validation and test metrics
-            in the end of the training.
-
-    Returns:
-        A tuple of with the validation and the test metrics (if exists).
-    """
-    trainer.fit(model, datamodule=datamodule)
-    validation_scores = trainer.validate(
-        datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type
-    )
-    test_scores = (
-        None
-        if datamodule.datasets.test is None
-        else trainer.test(datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type)
-    )
+    validation_scores = None
+    test_scores = None
+
+    if "fit" in stages:
+        trainer.fit(model, datamodule=datamodule)
+    if "validate" in stages:
+        validation_scores = trainer.validate(
+            model=model,
+            datamodule=datamodule,
+            verbose=verbose,
+            ckpt_path=trainer.checkpoint_type,
+        )
+    if "test" in stages and getattr(datamodule.datasets, "test", None) is not None:
+        test_scores = trainer.test(
+            model=model,
+            datamodule=datamodule,
+            verbose=verbose,
+            ckpt_path=trainer.checkpoint_type,
+        )
+    trainer.finish_logger_run(run_id)
     return validation_scores, test_scores
 