kaiko-eva 0.0.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Default metric collection for binary classification tasks."""
|
|
2
|
+
|
|
3
|
+
from torchmetrics import classification
|
|
4
|
+
|
|
5
|
+
from eva.core.metrics import binary_balanced_accuracy, structs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BinaryClassificationMetrics(structs.MetricCollection):
    """Default metrics for binary classification tasks."""

    def __init__(
        self,
        threshold: float = 0.5,
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Initializes the binary classification metrics.

        The collection contains, in this order:

        - BinaryAUROC
        - BinaryAccuracy
        - BinaryBalancedAccuracy
        - BinaryF1Score
        - BinaryPrecision
        - BinaryRecall

        Args:
            threshold: Threshold for transforming probability to binary (0,1) predictions.
                It is not passed to BinaryAUROC, which works on the raw scores.
            ignore_index: Specifies a target value that is ignored and does not
                contribute to the metric calculation.
            prefix: A string to append in front of the keys of the output dict.
            postfix: A string to append after the keys of the output dict.
        """
        # Arguments shared by every metric that binarizes the predictions.
        thresholded_kwargs = {"threshold": threshold, "ignore_index": ignore_index}

        metric_list = [
            classification.BinaryAUROC(ignore_index=ignore_index),
            classification.BinaryAccuracy(**thresholded_kwargs),
            binary_balanced_accuracy.BinaryBalancedAccuracy(**thresholded_kwargs),
            classification.BinaryF1Score(**thresholded_kwargs),
            classification.BinaryPrecision(**thresholded_kwargs),
            classification.BinaryRecall(**thresholded_kwargs),
        ]

        # Metric names grouped via `compute_groups` (see torchmetrics
        # `MetricCollection`); AUROC is kept in a group of its own.
        thresholded_group = [
            "BinaryAccuracy",
            "BinaryBalancedAccuracy",
            "BinaryF1Score",
            "BinaryPrecision",
            "BinaryRecall",
        ]
        auroc_group = ["BinaryAUROC"]

        super().__init__(
            metrics=metric_list,
            prefix=prefix,
            postfix=postfix,
            compute_groups=[thresholded_group, auroc_group],
        )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Default metric collection for multiclass classification tasks."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from torchmetrics import classification
|
|
6
|
+
|
|
7
|
+
from eva.core.metrics import structs
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MulticlassClassificationMetrics(structs.MetricCollection):
    """Default metrics for multi-class classification tasks."""

    def __init__(
        self,
        num_classes: int,
        average: Literal["macro", "weighted", "none"] = "macro",
        ignore_index: int | None = None,
        prefix: str | None = None,
        postfix: str | None = None,
    ) -> None:
        """Initializes the multi-class classification metrics.

        The collection contains, in this order:

        - MulticlassAUROC
        - MulticlassAccuracy
        - MulticlassF1Score
        - MulticlassPrecision
        - MulticlassRecall

        Args:
            num_classes: Integer specifying the number of classes.
            average: Defines the reduction that is applied over labels.
            ignore_index: Specifies a target value that is ignored and does not
                contribute to the metric calculation.
            prefix: A string to append in front of the keys of the output dict.
            postfix: A string to append after the keys of the output dict.
        """
        # Every metric of the collection is configured identically.
        shared_kwargs = {
            "num_classes": num_classes,
            "average": average,
            "ignore_index": ignore_index,
        }

        metric_list = [
            classification.MulticlassAUROC(**shared_kwargs),
            classification.MulticlassAccuracy(**shared_kwargs),
            classification.MulticlassF1Score(**shared_kwargs),
            classification.MulticlassPrecision(**shared_kwargs),
            classification.MulticlassRecall(**shared_kwargs),
        ]

        # Metric names grouped via `compute_groups` (see torchmetrics
        # `MetricCollection`); AUROC is kept in a group of its own.
        label_based_group = [
            "MulticlassAccuracy",
            "MulticlassF1Score",
            "MulticlassPrecision",
            "MulticlassRecall",
        ]
        auroc_group = ["MulticlassAUROC"]

        super().__init__(
            metrics=metric_list,
            prefix=prefix,
            postfix=postfix,
            compute_groups=[label_based_group, auroc_group],
        )
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Core metrics modules API."""
|
|
2
|
+
|
|
3
|
+
from eva.core.metrics.structs.collection import MetricCollection
|
|
4
|
+
from eva.core.metrics.structs.metric import Metric
|
|
5
|
+
from eva.core.metrics.structs.module import MetricModule
|
|
6
|
+
from eva.core.metrics.structs.schemas import MetricsSchema
|
|
7
|
+
from eva.core.metrics.structs.typings import MetricModuleType
|
|
8
|
+
|
|
9
|
+
# Public API of the metric structures package.
__all__ = [
    "MetricCollection",
    "Metric",
    "MetricModule",
    "MetricsSchema",
    "MetricModuleType",
]
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Metrics module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
from eva.core.metrics.structs import collection, schemas
|
|
8
|
+
from eva.core.metrics.structs.typings import MetricModuleType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MetricModule(nn.Module):
    """Holds the metric collections of the `train`, `val` and `test` stages.

    Each stage keeps its own collection, so the metrics of the different
    stages can be tracked separately.
    """

    def __init__(
        self,
        train: collection.MetricCollection | None,
        val: collection.MetricCollection | None,
        test: collection.MetricCollection | None,
    ) -> None:
        """Initializes the metrics for the Trainer.

        A falsy argument (`None`, or presumably an empty collection) is
        replaced by a freshly created, empty default collection.

        Args:
            train: The training metric collection.
            val: The validation metric collection.
            test: The test metric collection.
        """
        super().__init__()

        self._train = train if train else self.default_metric_collection
        self._val = val if val else self.default_metric_collection
        self._test = test if test else self.default_metric_collection

    @property
    def default_metric_collection(self) -> collection.MetricCollection:
        """Returns a new, empty metric collection on every access."""
        return collection.MetricCollection([])

    @classmethod
    def from_metrics(
        cls,
        train: MetricModuleType | None,
        val: MetricModuleType | None,
        test: MetricModuleType | None,
        *,
        separator: str = "/",
    ) -> MetricModule:
        """Builds a metric module from per-stage metrics.

        The stage name followed by `separator` (e.g. `"val/"`) is used as the
        prefix of each stage's collection.

        Args:
            train: Metrics for the training stage.
            val: Metrics for the validation stage.
            test: Metrics for the test stage.
            separator: The separator between the group name of the metric
                and the metric itself.
        """
        metrics_per_stage = {"train": train, "val": val, "test": test}
        stage_collections = {
            stage: _create_collection_from_metrics(stage_metrics, prefix=stage + separator)
            for stage, stage_metrics in metrics_per_stage.items()
        }
        return cls(**stage_collections)

    @classmethod
    def from_schema(
        cls,
        schema: schemas.MetricsSchema,
        *,
        separator: str = "/",
    ) -> MetricModule:
        """Builds a metric module from a metrics schema.

        The schema's training metrics are used for the `train` stage and its
        evaluation metrics for both the `val` and the `test` stages.

        Args:
            schema: The dataclass metric schema.
            separator: The separator between the group name of the metric
                and the metric itself.
        """
        training = schema.training_metrics
        validation = schema.evaluation_metrics
        testing = schema.evaluation_metrics
        return cls.from_metrics(
            train=training,
            val=validation,
            test=testing,
            separator=separator,
        )

    @property
    def training_metrics(self) -> collection.MetricCollection:
        """The metric collection of the train stage."""
        return self._train

    @property
    def validation_metrics(self) -> collection.MetricCollection:
        """The metric collection of the validation stage."""
        return self._val

    @property
    def test_metrics(self) -> collection.MetricCollection:
        """The metric collection of the test stage."""
        return self._test
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _create_collection_from_metrics(
    metrics: MetricModuleType | None, *, prefix: str | None = None
) -> collection.MetricCollection:
    """Creates a standalone metric collection from the given metrics.

    Args:
        metrics: The desired metrics. A falsy value (e.g. `None`) results in
            an empty collection.
        prefix: A prefix to add to the keys of the collection outputs.

    Returns:
        The resulting metrics collection. It is always a collection, never `None`.
    """
    metrics_collection = collection.MetricCollection(metrics or [], prefix=prefix)  # type: ignore
    # `MetricModule.from_schema` passes the same evaluation metrics objects to
    # both the `val` and `test` stages. Cloning (expected to deep-copy, per the
    # torchmetrics `clone` API) keeps the stages from sharing metric state.
    return metrics_collection.clone()
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Metrics related helper schemas."""
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
|
|
5
|
+
from eva.core.metrics.structs.typings import MetricModuleType
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclasses.dataclass(frozen=True)
class MetricsSchema:
    """Metrics schema.

    Splits the metrics to track into the ones shared by training and
    evaluation (`common`) and the ones exclusive to either of them.
    """

    common: MetricModuleType | None = None
    """Holds the common train and evaluation metrics."""

    train: MetricModuleType | None = None
    """The exclusive training metrics."""

    evaluation: MetricModuleType | None = None
    """The exclusive evaluation metrics."""

    @property
    def training_metrics(self) -> MetricModuleType | None:
        """Returns the training metrics (common + exclusive training ones)."""
        return self._join_with_common(self.train)

    @property
    def evaluation_metrics(self) -> MetricModuleType | None:
        """Returns the evaluation metrics (common + exclusive evaluation ones)."""
        return self._join_with_common(self.evaluation)

    def _join_with_common(self, metrics: MetricModuleType | None) -> MetricModuleType | None:
        """Joins the provided metrics with the common ones.

        If either `metrics` or `self.common` is `None`, `self.common` is
        returned when it is truthy and `metrics` otherwise. When both are set,
        they are merged into one flat list:
        - List/tuple inputs contribute their elements.
        - Any other input (a single metric, a metric collection or a dict)
          contributes itself as one element.

        Duplicates between `metrics` and `self.common` are not removed here.
        NOTE(review): the `MetricCollection` built from the result decides
        whether duplicates are accepted; it may reject two metrics with the
        same name.

        Args:
            metrics: The metrics to join.

        Returns:
            The resulting metrics after joining with the common ones.
        """
        if metrics is None or self.common is None:
            return self.common or metrics

        # A nested list such as `[[m1, m2], [m3]]` is not a valid metric
        # sequence: every element must be a metric or a metric collection.
        # List/tuple inputs are therefore flattened by one level.
        return [*self._as_list(self.common), *self._as_list(metrics)]

    @staticmethod
    def _as_list(metrics: MetricModuleType) -> list:
        """Returns list/tuple inputs as a new list, anything else wrapped in one."""
        if isinstance(metrics, (list, tuple)):
            return list(metrics)
        return [metrics]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Metric typings."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Sequence, Union
|
|
4
|
+
|
|
5
|
+
from eva.core.metrics.structs import collection, metric
|
|
6
|
+
|
|
7
|
+
BaseMetricModuleType = Union[
    metric.Metric,
    collection.MetricCollection,
]
"""A single metric or a collection of metrics."""

MetricModuleType = Union[
    BaseMetricModuleType, Sequence[BaseMetricModuleType], Dict[str, BaseMetricModuleType]
]
"""A base metric type, a sequence of them, or a name-to-metric mapping."""
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Models API."""
|
|
2
|
+
|
|
3
|
+
from eva.core.models.modules import HeadModule, InferenceModule
|
|
4
|
+
from eva.core.models.networks import MLP, HuggingFaceModel, ModelFromFunction, ONNXModel
|
|
5
|
+
|
|
6
|
+
# Public API of the models package: the lightning modules first, then the networks.
__all__ = [
    "HeadModule", "InferenceModule",
    "MLP", "HuggingFaceModel", "ModelFromFunction", "ONNXModel",
]
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
""""Neural Network Head Module."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
|
|
7
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
8
|
+
from torch import optim
|
|
9
|
+
from torch.optim import lr_scheduler
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from eva.core.metrics import structs as metrics_lib
|
|
13
|
+
from eva.core.models.modules import module
|
|
14
|
+
from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
|
|
15
|
+
from eva.core.models.modules.utils import batch_postprocess, grad
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class HeadModule(module.ModelModule):
    """Neural Net Head Module for training on features.

    Only the `head` network is optimized. An optional `backbone` first turns
    the batch data into features; its gradients are switched off for the
    duration of the fit. It can be used for supervised (mini-batch)
    stochastic gradient descent downstream tasks such as classification,
    regression and segmentation.
    """

    def __init__(
        self,
        head: MODEL_TYPE,
        criterion: Callable[..., torch.Tensor],
        backbone: MODEL_TYPE | None = None,
        optimizer: OptimizerCallable = optim.Adam,
        lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
        metrics: metrics_lib.MetricsSchema | None = None,
        postprocess: batch_postprocess.BatchPostProcess | None = None,
    ) -> None:
        """Initializes the neural net head module.

        Args:
            head: The neural network that is trained on the features.
            criterion: The loss function, called as `criterion(predictions, targets)`.
            backbone: The feature extractor. If `None`, the input batch data is
                expected to already contain the features.
            optimizer: Factory that builds the optimizer from the head parameters.
            lr_scheduler: Factory that builds the learning rate scheduler from
                the optimizer.
            metrics: The metric groups to track.
            postprocess: Optional post-processing to apply to the model
                predictions and targets after the loss and before the
                metrics calculation.
        """
        super().__init__(metrics=metrics, postprocess=postprocess)

        # Assignment order is kept so that sub-modules are registered in the
        # same order (head before backbone).
        self.head = head
        self.criterion = criterion
        self.backbone = backbone
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    @override
    def configure_optimizers(self) -> Any:
        # Only the head parameters are handed to the optimizer.
        head_optimizer = self.optimizer(list(self.head.parameters()))
        scheduler = self.lr_scheduler(head_optimizer)
        return {"optimizer": head_optimizer, "lr_scheduler": scheduler}

    @override
    def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        if self.backbone is not None:
            tensor = self.backbone(tensor)
        # Drops a trailing singleton dimension, e.g. (N, 1) -> (N,).
        return self.head(tensor).squeeze(-1)

    @override
    def on_fit_start(self) -> None:
        if self.backbone is None:
            return
        grad.deactivate_requires_grad(self.backbone)

    @override
    def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        return self._batch_step(batch)

    @override
    def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Returns the features (the backbone output, or the raw data when
        # there is no backbone), not the head predictions.
        inputs = INPUT_BATCH(*batch).data
        if self.backbone is None:
            return inputs
        return self.backbone(inputs)

    @override
    def on_fit_end(self) -> None:
        if self.backbone is None:
            return
        grad.activate_requires_grad(self.backbone)

    def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
        """Runs the forward pass on a batch and computes its loss.

        Args:
            batch: The desired batch to process.

        Returns:
            A dict with the loss, the targets, the predictions and the metadata.
        """
        inputs, targets, metadata = INPUT_BATCH(*batch)
        outputs = self(inputs)
        batch_loss = self.criterion(outputs, targets)
        return {
            "loss": batch_loss,
            "targets": targets,
            "predictions": outputs,
            "metadata": metadata,
        }
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Model inference module."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
|
|
7
|
+
from eva.core.models.modules import module
|
|
8
|
+
from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InferenceModule(module.ModelModule):
    """A lightweight model module that only runs a backbone for inference."""

    def __init__(self, backbone: MODEL_TYPE) -> None:
        """Initializes the module.

        Args:
            backbone: The network to be used for inference.
        """
        super().__init__(metrics=None)

        self.backbone = backbone

    @override
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.backbone(tensor)

    @override
    def predict_step(
        self,
        batch: INPUT_BATCH,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> STEP_OUTPUT:
        # Only the first field of the batch is fed through the backbone; the
        # remaining fields are ignored.
        inputs, *_unused = INPUT_BATCH(*batch)
        return self(inputs)
|