kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/dataloaders/__init__.py +2 -1
- eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/core/data/dataloaders/collate_fn/collate.py +24 -0
- eva/core/data/dataloaders/dataloader.py +4 -0
- eva/core/interface/interface.py +34 -1
- eva/core/metrics/defaults/classification/multiclass.py +45 -35
- eva/core/models/modules/__init__.py +2 -1
- eva/core/models/modules/scheduler.py +51 -0
- eva/core/models/transforms/extract_cls_features.py +1 -1
- eva/core/models/transforms/extract_patch_features.py +1 -1
- eva/core/models/wrappers/base.py +17 -14
- eva/core/models/wrappers/from_function.py +5 -4
- eva/core/models/wrappers/from_torchhub.py +5 -6
- eva/core/models/wrappers/huggingface.py +8 -5
- eva/core/models/wrappers/onnx.py +4 -4
- eva/core/trainers/functional.py +40 -43
- eva/core/utils/factory.py +66 -0
- eva/core/utils/registry.py +42 -0
- eva/core/utils/requirements.py +26 -0
- eva/language/__init__.py +13 -0
- eva/language/data/__init__.py +5 -0
- eva/language/data/datasets/__init__.py +9 -0
- eva/language/data/datasets/classification/__init__.py +7 -0
- eva/language/data/datasets/classification/base.py +63 -0
- eva/language/data/datasets/classification/pubmedqa.py +149 -0
- eva/language/data/datasets/language.py +13 -0
- eva/language/models/__init__.py +25 -0
- eva/language/models/modules/__init__.py +5 -0
- eva/language/models/modules/text.py +85 -0
- eva/language/models/modules/typings.py +16 -0
- eva/language/models/wrappers/__init__.py +11 -0
- eva/language/models/wrappers/huggingface.py +69 -0
- eva/language/models/wrappers/litellm.py +77 -0
- eva/language/models/wrappers/vllm.py +149 -0
- eva/language/utils/__init__.py +5 -0
- eva/language/utils/str_to_int_tensor.py +95 -0
- eva/vision/data/dataloaders/__init__.py +2 -1
- eva/vision/data/dataloaders/worker_init.py +35 -0
- eva/vision/data/datasets/__init__.py +5 -5
- eva/vision/data/datasets/segmentation/__init__.py +4 -4
- eva/vision/data/datasets/segmentation/btcv.py +3 -0
- eva/vision/data/datasets/segmentation/consep.py +5 -4
- eva/vision/data/datasets/segmentation/lits17.py +231 -0
- eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
- eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
- eva/vision/data/transforms/__init__.py +11 -2
- eva/vision/data/transforms/base/__init__.py +5 -0
- eva/vision/data/transforms/base/monai.py +27 -0
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/squeeze.py +24 -0
- eva/vision/data/transforms/croppad/__init__.py +4 -0
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
- eva/vision/models/modules/semantic_segmentation.py +18 -7
- eva/vision/models/networks/backbones/__init__.py +2 -3
- eva/vision/models/networks/backbones/_utils.py +1 -1
- eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
- eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
- eva/vision/models/networks/backbones/pathology/histai.py +3 -3
- eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
- eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
- eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
- eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
- eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
- eva/vision/models/networks/backbones/pathology/paige.py +3 -3
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
- eva/vision/models/networks/backbones/radiology/voco.py +5 -5
- eva/vision/models/networks/backbones/registry.py +2 -44
- eva/vision/models/networks/backbones/timm/backbones.py +2 -2
- eva/vision/models/networks/backbones/universal/__init__.py +8 -1
- eva/vision/models/networks/backbones/universal/vit.py +53 -3
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
- eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
- eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
- eva/vision/models/wrappers/from_registry.py +14 -9
- eva/vision/models/wrappers/from_timm.py +6 -5
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/segmentation/lits.py +0 -199
- eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
- /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
eva/core/data/dataloaders/collate_fn/collate.py
ADDED
@@ -0,0 +1,24 @@
+"""Collate functions for text data."""
+
+from typing import Dict, List, Tuple
+
+import torch
+
+
+def text_collate_fn(
+    batch: List[Tuple[str, torch.Tensor, Dict]],
+) -> Tuple[List[str], torch.Tensor, List[Dict]]:
+    """Collate function for text data that keeps texts as separate strings.
+
+    Args:
+        batch: List of tuples containing (text, target, metadata) from the dataset
+
+    Returns:
+        Tuple containing:
+            - List of text strings
+            - Batched tensor of targets
+            - List of metadata dictionaries
+    """
+    texts, targets, metadata = zip(*batch, strict=False)
+    targets = torch.stack(targets)
+    return list(texts), targets, list(metadata)
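For context, a minimal usage sketch of the new collate function (the toy samples below are hypothetical; a plain list acts as a map-style dataset for torch's DataLoader):

import torch
from torch.utils.data import DataLoader

from eva.core.data.dataloaders.collate_fn.collate import text_collate_fn

# Hypothetical (text, target, metadata) samples.
samples = [
    ("the scan shows a lesion", torch.tensor(1), {"id": 0}),
    ("no abnormality detected", torch.tensor(0), {"id": 1}),
]
loader = DataLoader(samples, batch_size=2, collate_fn=text_collate_fn)

texts, targets, metadata = next(iter(loader))
# texts    -> ["the scan shows a lesion", "no abnormality detected"]
# targets  -> tensor([1, 0])
# metadata -> [{"id": 0}, {"id": 1}]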
eva/core/data/dataloaders/dataloader.py
CHANGED
@@ -56,6 +56,9 @@ class DataLoader:
     persistent_workers: bool = True
     """Will keep the worker processes after a dataset has been consumed once."""

+    worker_init_fn: Callable | None = None
+    """Function to call on each worker process before data loading."""
+
     prefetch_factor: int | None = 2
     """Number of batches loaded in advance by each worker."""

@@ -80,4 +83,5 @@ class DataLoader:
             drop_last=self.drop_last,
             persistent_workers=self.persistent_workers,
             prefetch_factor=self.prefetch_factor,
+            worker_init_fn=self.worker_init_fn,
         )
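The new field follows torch.utils.data's worker_init_fn contract. A minimal sketch of a seeding initializer (illustrative only; the package's own eva/vision/data/dataloaders/worker_init.py, added in this release, may differ):

import numpy as np
import torch

def seed_worker(worker_id: int) -> None:
    # Standard PyTorch recipe: torch already seeds each worker differently;
    # propagate that seed to numpy so augmentations stay per-worker reproducible.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

loader = DataLoader(worker_init_fn=seed_worker)  # eva's dataclass-style DataLoader above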
eva/core/interface/interface.py
CHANGED
@@ -34,7 +34,14 @@ class Interface:
             model: The model module to use but not modify.
             data: The data module.
         """
-
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["fit", "validate", "test"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )

     def predict(
         self,
@@ -77,3 +84,29 @@
         """
         self.predict(trainer=trainer, model=model, data=data)
         self.fit(trainer=trainer, model=model, data=data)
+
+    def validate(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Perform model validation out-of-place without running fit.
+
+        This method is useful when the model is already trained or does not
+        require further training (e.g., large language models) and you only
+        want to measure performance.
+
+        Args:
+            trainer: The base trainer to use but not modify.
+            model: The model module to use but not modify.
+            data: The data module containing validation data.
+        """
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["validate"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )
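A hedged sketch of driving the new stage (trainer, model and data are assumed to be configured elsewhere; the import path mirrors the file location):

from eva.core.interface.interface import Interface

session = Interface()
# Runs only the "validate" stage with no fitting, e.g. for an
# already-trained or zero-shot model such as an LLM wrapper.
session.validate(trainer=trainer, model=model, data=data)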
eva/core/metrics/defaults/classification/multiclass.py
CHANGED
@@ -17,6 +17,7 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
         ignore_index: int | None = None,
         prefix: str | None = None,
         postfix: str | None = None,
+        input_type: Literal["logits", "discrete"] = "logits",
     ) -> None:
         """Initializes the multi-class classification metrics.

@@ -27,46 +28,55 @@ class MulticlassClassificationMetrics(structs.MetricCollection):
                 contribute to the metric calculation.
             prefix: A string to append in front of the keys of the output dict.
             postfix: A string to append after the keys of the output dict.
+            input_type: Type of input predictions - "logits" for probabilities/logits
+                or "discrete" for discrete class predictions. Determines which metrics
+                are applicable.
         """
-        super().__init__(
-            metrics=[
+        metrics = [
+            classification.MulticlassAccuracy(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassF1Score(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassPrecision(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            classification.MulticlassRecall(
+                num_classes=num_classes,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+        ]
+
+        compute_groups = [
+            [
+                "MulticlassAccuracy",
+                "MulticlassF1Score",
+                "MulticlassPrecision",
+                "MulticlassRecall",
+            ]
+        ]
+
+        if input_type == "logits":
+            metrics.append(
                 classification.MulticlassAUROC(
                     num_classes=num_classes,
                     average=average,
                     ignore_index=ignore_index,
-                ),
-                classification.MulticlassAccuracy(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassF1Score(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassPrecision(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-                classification.MulticlassRecall(
-                    num_classes=num_classes,
-                    average=average,
-                    ignore_index=ignore_index,
-                ),
-            ],
+                )
+            )
+            compute_groups.append(["MulticlassAUROC"])
+
+        super().__init__(
+            metrics=metrics,
             prefix=prefix,
             postfix=postfix,
-            compute_groups=[
-                [
-                    "MulticlassAccuracy",
-                    "MulticlassF1Score",
-                    "MulticlassPrecision",
-                    "MulticlassRecall",
-                ],
-                [
-                    "MulticlassAUROC",
-                ],
-            ],
+            compute_groups=compute_groups,
         )
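A sketch of the effect of the new flag, assuming the class is imported from the module above (the collection follows torchmetrics' update/compute API):

import torch

# Default input_type="logits": MulticlassAUROC is part of the collection.
metrics = MulticlassClassificationMetrics(num_classes=3)

# Discrete predictions, e.g. class labels parsed from generated text:
# AUROC requires scores, so it is left out of metrics and compute groups.
discrete = MulticlassClassificationMetrics(num_classes=3, input_type="discrete")
discrete.update(torch.tensor([0, 2, 1]), torch.tensor([0, 1, 1]))
print(discrete.compute())  # accuracy, F1, precision and recall only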
eva/core/models/modules/__init__.py
CHANGED
@@ -3,5 +3,6 @@
 from eva.core.models.modules.head import HeadModule
 from eva.core.models.modules.inference import InferenceModule
 from eva.core.models.modules.module import ModelModule
+from eva.core.models.modules.scheduler import SchedulerConfiguration

-__all__ = ["HeadModule", "ModelModule", "InferenceModule"]
+__all__ = ["HeadModule", "ModelModule", "InferenceModule", "SchedulerConfiguration"]
eva/core/models/modules/scheduler.py
ADDED
@@ -0,0 +1,51 @@
+"""Learning Rate scheduler configuration."""
+
+import dataclasses
+from typing import Any, Literal
+
+from lightning.pytorch.cli import LRSchedulerCallable
+from torch import optim
+
+
+@dataclasses.dataclass
+class SchedulerConfiguration:
+    """Initializes and builds the learning rate scheduler configuration."""
+
+    scheduler: LRSchedulerCallable
+    """The learning rate scheduler instance."""
+
+    interval: Literal["step", "epoch"] = "epoch"
+    """The unit of the scheduler's step size.
+
+    It can be 'step' or 'epoch', to update the scheduler on step or epoch end respectively.
+    """
+
+    frequency: int = 1
+    """How many epochs/steps should pass between calls to `scheduler.step()`.
+
+    Value `1` corresponds to updating the learning rate after every epoch/step.
+    """
+
+    monitor: str = "val_loss"
+    """Metric to monitor for schedulers like `ReduceLROnPlateau`."""
+
+    strict: bool = True
+    """Whether to enforce that the value specified in 'monitor' must be available.
+
+    If the value is not available when the scheduler is updated, it will stop the
+    training. With `False`, it will only produce a warning.
+    """
+
+    name: str | None = None
+    """Specifies a custom logged name for the `LearningRateMonitor` callback."""
+
+    def __call__(self, optimizer: optim.Optimizer) -> dict[str, Any]:
+        """Returns Lightning's lr_scheduler_config configuration."""
+        return {
+            "scheduler": self.scheduler(optimizer),
+            "interval": self.interval,
+            "frequency": self.frequency,
+            "monitor": self.monitor,
+            "strict": self.strict,
+            "name": self.name,
+        }
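A hedged usage sketch (SchedulerConfiguration imported from eva.core.models.modules, as the new export above suggests):

from functools import partial

import torch
from torch import nn, optim

config = SchedulerConfiguration(
    scheduler=partial(optim.lr_scheduler.CosineAnnealingLR, T_max=100),
    interval="step",
)
optimizer = optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.1)
lr_config = config(optimizer)
# lr_config["scheduler"] is a CosineAnnealingLR bound to the optimizer; the
# remaining keys ("interval", "frequency", "monitor", ...) mirror the fields.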
eva/core/models/transforms/extract_cls_features.py
CHANGED
@@ -31,7 +31,7 @@ class ExtractCLSFeatures:
             tensor: The tensor representing the model output.
         """
         if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            tensor = tensor.last_hidden_state
+            tensor = tensor.last_hidden_state  # type: ignore

         cls_token = tensor[:, self._cls_index, :]
         if self._include_patch_tokens:
eva/core/models/transforms/extract_patch_features.py
CHANGED
@@ -43,7 +43,7 @@ class ExtractPatchFeatures:
         """
         num_skip = int(self._has_cls_token) + self._num_register_tokens
         if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
-            features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)
+            features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)  # type: ignore
         else:
             features = tensor[:, num_skip:, :].permute(0, 2, 1)

eva/core/models/wrappers/base.py
CHANGED
@@ -1,40 +1,43 @@
 """Base class for model wrappers."""

 import abc
-from typing import Callable
+from typing import Callable, Generic, TypeVar

-import torch
 import torch.nn as nn
 from typing_extensions import override

+InputType = TypeVar("InputType")
+"""The input data type."""
+OutputType = TypeVar("OutputType")
+"""The output data type."""

-class BaseModel(nn.Module):
+
+class BaseModel(nn.Module, Generic[InputType, OutputType]):
     """Base class for model wrappers."""

-    def __init__(self,
+    def __init__(self, transforms: Callable | None = None) -> None:
         """Initializes the model.

         Args:
-
-                tensor produced by the model.
+            transforms: The transforms to apply to the output produced by the model.
         """
         super().__init__()

-        self._output_transforms =
+        self._output_transforms = transforms

-        self._model: Callable[...,
+        self._model: Callable[..., OutputType] | nn.Module

     @override
-    def forward(self, tensor:
-
-        return self._apply_transforms(
+    def forward(self, tensor: InputType) -> OutputType:
+        out = self.model_forward(tensor)
+        return self._apply_transforms(out)

     @abc.abstractmethod
-    def load_model(self) -> Callable[...,
+    def load_model(self) -> Callable[..., OutputType]:
         """Loads the model."""
         raise NotImplementedError

-    def model_forward(self, tensor:
+    def model_forward(self, tensor: InputType) -> OutputType:
         """Implements the forward pass of the model.

         Args:
@@ -42,7 +45,7 @@ class BaseModel(nn.Module):
         """
         return self._model(tensor)

-    def _apply_transforms(self, tensor:
+    def _apply_transforms(self, tensor: OutputType) -> OutputType:
         if self._output_transforms is not None:
             tensor = self._output_transforms(tensor)
         return tensor
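To illustrate the new generics, a hypothetical subclass (not part of the package) binding the two type parameters:

import torch

from eva.core.models.wrappers import base

class CharCountModel(base.BaseModel[str, torch.Tensor]):
    """Hypothetical wrapper with InputType=str and OutputType=torch.Tensor."""

    def __init__(self) -> None:
        super().__init__(transforms=None)
        self._model = self.load_model()

    def load_model(self):
        # Stand-in "model": maps a string to a 1-element length tensor.
        return lambda text: torch.tensor([len(text)])

# CharCountModel()("hello") -> tensor([5])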
eva/core/models/wrappers/from_function.py
CHANGED
@@ -3,13 +3,14 @@
 from typing import Any, Callable, Dict

 import jsonargparse
+import torch
 from torch import nn
 from typing_extensions import override

 from eva.core.models.wrappers import _utils, base


-class ModelFromFunction(base.BaseModel):
+class ModelFromFunction(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for models which are initialized from functions.

     This is helpful for initializing models in a `.yaml` configuration file.
@@ -20,7 +21,7 @@ class ModelFromFunction(base.BaseModel):
         path: Callable[..., nn.Module],
         arguments: Dict[str, Any] | None = None,
         checkpoint_path: str | None = None,
-
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes and constructs the model.

@@ -31,10 +32,10 @@ class ModelFromFunction(base.BaseModel):
                 weights from. This is currently only supported for torch
                 model checkpoints. For other formats, the checkpoint loading
                 should be handled within the provided callable object in <path>.
-
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__(
+        super().__init__(transforms=transforms)

         self._path = path
         self._arguments = arguments
eva/core/models/wrappers/from_torchhub.py
CHANGED
@@ -6,11 +6,10 @@ import torch
 import torch.nn as nn
 from typing_extensions import override

-from eva.core.models import wrappers
-from eva.core.models.wrappers import _utils
+from eva.core.models.wrappers import _utils, base


-class TorchHubModel(wrappers.BaseModel):
+class TorchHubModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Model wrapper for `torch.hub` models."""

     def __init__(
@@ -23,7 +22,7 @@ class TorchHubModel(wrappers.BaseModel):
         norm: bool = False,
         trust_repo: bool = True,
         model_kwargs: Dict[str, Any] | None = None,
-
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the encoder.

@@ -39,10 +38,10 @@ class TorchHubModel(wrappers.BaseModel):
             trust_repo: If set to `False`, a prompt will ask the user whether the
                 repo should be trusted.
             model_kwargs: Extra model arguments.
-
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
         """
-        super().__init__(
+        super().__init__(transforms=transforms)

         self._model_name = model_name
         self._repo_or_dir = repo_or_dir
eva/core/models/wrappers/huggingface.py
CHANGED
@@ -2,19 +2,20 @@

 from typing import Any, Callable, Dict

+import torch
 import transformers
 from typing_extensions import override

 from eva.core.models.wrappers import base


-class HuggingFaceModel(base.BaseModel):
+class HuggingFaceModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for loading HuggingFace `transformers` models."""

     def __init__(
         self,
         model_name_or_path: str,
-
+        transforms: Callable | None = None,
         model_kwargs: Dict[str, Any] | None = None,
     ) -> None:
         """Initializes the model.
@@ -23,11 +24,11 @@ class HuggingFaceModel(base.BaseModel):
             model_name_or_path: The model name or path to load the model from.
                 This can be a local path or a model name from the `HuggingFace`
                 model hub.
-
+            transforms: The transforms to apply to the output tensor
                 produced by the model.
             model_kwargs: The arguments used for instantiating the model.
         """
-        super().__init__(
+        super().__init__(transforms=transforms)

         self._model_name_or_path = model_name_or_path
         self._model_kwargs = model_kwargs or {}
@@ -36,6 +37,8 @@ class HuggingFaceModel(base.BaseModel):

     @override
     def load_model(self) -> None:
+        # Use safetensors to avoid torch.load security vulnerability
+        model_kwargs = {"use_safetensors": True, **self._model_kwargs}
         self._model = transformers.AutoModel.from_pretrained(
-            self._model_name_or_path, **self._model_kwargs
+            self._model_name_or_path, **model_kwargs
         )
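Worth noting: because the user-supplied kwargs are unpacked after the safetensors default, a caller can still opt out (e.g. for a checkpoint without safetensors weights). This is plain dict-merge semantics:

defaults = {"use_safetensors": True}
user_kwargs = {"use_safetensors": False}  # hypothetical caller override
merged = {**defaults, **user_kwargs}
assert merged["use_safetensors"] is False  # later keys win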
eva/core/models/wrappers/onnx.py
CHANGED
@@ -9,23 +9,23 @@ from typing_extensions import override
 from eva.core.models.wrappers import base


-class ONNXModel(base.BaseModel):
+class ONNXModel(base.BaseModel[torch.Tensor, torch.Tensor]):
     """Wrapper class for loading ONNX models."""

     def __init__(
         self,
         path: str,
         device: Literal["cpu", "cuda"] | None = "cpu",
-
+        transforms: Callable | None = None,
     ):
         """Initializes the model.

         Args:
             path: The path to the .onnx model file.
             device: The device to run the model on. This can be either "cpu" or "cuda".
-
+            transforms: The transforms to apply to the output tensor produced by the model.
         """
-        super().__init__(
+        super().__init__(transforms=transforms)

         self._path = path
         self._device = device
eva/core/trainers/functional.py
CHANGED
@@ -1,6 +1,6 @@
 """Fit session related functions."""

-from typing import Tuple
+from typing import List, Literal, Tuple

 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT

@@ -16,11 +16,12 @@ def run_evaluation_session(
     datamodule: datamodules.DataModule,
     *,
     n_runs: int = 1,
+    stages: List[Literal["fit", "validate", "test"]] | None = None,
     verbose: bool = True,
 ) -> None:
     """Runs a downstream evaluation session out-of-place.

-    It performs an evaluation run (
+    It performs an evaluation run (with configurable stages) on the model
     multiple times. Note that as the input `base_trainer` and
     `base_model` would be cloned, the input object would not
     be modified.
@@ -29,10 +30,13 @@
         base_trainer: The base trainer module to use.
         base_model: The base model module to use.
         datamodule: The data module.
-        n_runs: The
+        n_runs: The number of runs to perform.
+        stages: List of stages to execute. Options: "fit", "validate", "test".
         verbose: Whether to verbose the session metrics instead of
-
+            those of each individual run and vice-versa.
     """
+    if not stages:
+        stages = ["fit", "validate", "test"]
     recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
     for run_index in range(n_runs):
         validation_scores, test_scores = run_evaluation(
@@ -40,9 +44,11 @@
             base_model,
             datamodule,
             run_id=run_index,
+            stages=stages,
             verbose=not verbose,
         )
-        recorder.update(validation_scores, test_scores)
+        if validation_scores:
+            recorder.update(validation_scores, test_scores)
     recorder.save()


@@ -52,61 +58,52 @@ def run_evaluation(
     datamodule: datamodules.DataModule,
     *,
     run_id: int | None = None,
+    stages: List[Literal["fit", "validate", "test"]] | None = None,
     verbose: bool = True,
-) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
-    """
+) -> Tuple[_EVALUATE_OUTPUT | None, _EVALUATE_OUTPUT | None]:
+    """Runs the specified evaluation stages out-of-place.

     Args:
         base_trainer: The base trainer to use but not modify.
         base_model: The model module to use but not modify.
         datamodule: The data module.
         run_id: The run id to be appended to the output log directory.
+            If `None`, it will use the log directory of the trainer as is.
+        stages: List of stages to execute. Options: "fit", "validate", "test".
         verbose: Whether to print the validation and test metrics
             in the end of the training.

     Returns:
-        A tuple
+        A tuple with the validation and the test metrics (if executed).
+        If a stage is not executed, its value will be None.
     """
+    if not stages:
+        stages = ["fit", "validate", "test"]
     trainer, model = _utils.clone(base_trainer, base_model)
     model.configure_model()

     trainer.init_logger_run(run_id)
-    results = fit_and_validate(trainer, model, datamodule, verbose=verbose)
-    trainer.finish_logger_run(run_id)
-
-    return results
-
-
-def fit_and_validate(
-    trainer: eva_trainer.Trainer,
-    model: modules.ModelModule,
-    datamodule: datamodules.DataModule,
-    verbose: bool = True,
-) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
-    """Fits and evaluates a model in-place.
-
-    If the test set is set in the datamodule, it will evaluate the model
-    on the test set as well.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    validation_scores = None
+    test_scores = None
+
+    if "fit" in stages:
+        trainer.fit(model, datamodule=datamodule)
+    if "validate" in stages:
+        validation_scores = trainer.validate(
+            model=model,
+            datamodule=datamodule,
+            verbose=verbose,
+            ckpt_path=trainer.checkpoint_type,
+        )
+    if "test" in stages and getattr(datamodule.datasets, "test", None) is not None:
+        test_scores = trainer.test(
+            model=model,
+            datamodule=datamodule,
+            verbose=verbose,
+            ckpt_path=trainer.checkpoint_type,
+        )
+    trainer.finish_logger_run(run_id)
     return validation_scores, test_scores

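A hedged sketch of calling the refactored session runner for a validation-only pass (trainer, model and datamodule are assumed to be configured elsewhere; the module path matches the file above):

from eva.core.trainers import functional

# Clones the trainer and model per run, skips fitting, and records only
# the validation scores; the test stage runs only if a test split exists.
functional.run_evaluation_session(
    base_trainer=trainer,
    base_model=model,
    datamodule=datamodule,
    n_runs=1,
    stages=["validate"],
    verbose=False,
)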