hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__main__.py +2 -2
- cli/config.py +17 -4
- cli/dataset_cmds.py +60 -0
- cli/runc_cmds.py +1 -1
- hafnia/data/__init__.py +2 -2
- hafnia/data/factory.py +12 -56
- hafnia/dataset/dataset_helpers.py +91 -0
- hafnia/dataset/dataset_names.py +72 -0
- hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
- hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
- hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
- hafnia/dataset/dataset_upload_helper.py +468 -0
- hafnia/dataset/hafnia_dataset.py +624 -0
- hafnia/dataset/operations/dataset_stats.py +15 -0
- hafnia/dataset/operations/dataset_transformations.py +82 -0
- hafnia/dataset/operations/table_transformations.py +183 -0
- hafnia/dataset/primitives/__init__.py +16 -0
- hafnia/dataset/primitives/bbox.py +137 -0
- hafnia/dataset/primitives/bitmask.py +182 -0
- hafnia/dataset/primitives/classification.py +56 -0
- hafnia/dataset/primitives/point.py +25 -0
- hafnia/dataset/primitives/polygon.py +100 -0
- hafnia/dataset/primitives/primitive.py +44 -0
- hafnia/dataset/primitives/segmentation.py +51 -0
- hafnia/dataset/primitives/utils.py +51 -0
- hafnia/experiment/hafnia_logger.py +7 -7
- hafnia/helper_testing.py +108 -0
- hafnia/http.py +5 -3
- hafnia/platform/__init__.py +2 -2
- hafnia/platform/datasets.py +197 -0
- hafnia/platform/download.py +85 -23
- hafnia/torch_helpers.py +180 -95
- hafnia/utils.py +21 -2
- hafnia/visualizations/colors.py +267 -0
- hafnia/visualizations/image_visualizations.py +202 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
- hafnia-0.2.1.dist-info/RECORD +50 -0
- cli/data_cmds.py +0 -53
- hafnia-0.1.27.dist-info/RECORD +0 -27
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
- {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0
hafnia/platform/download.py
CHANGED

@@ -1,15 +1,87 @@
 from pathlib import Path
-from typing import
+from typing import Dict
 
 import boto3
 from botocore.exceptions import ClientError
+from pydantic import BaseModel, field_validator
 from tqdm import tqdm
 
 from hafnia.http import fetch
 from hafnia.log import sys_logger, user_logger
 
-
-
+ARN_PREFIX = "arn:aws:s3:::"
+
+
+class ResourceCredentials(BaseModel):
+    access_key: str
+    secret_key: str
+    session_token: str
+    s3_arn: str
+    region: str
+
+    @staticmethod
+    def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+        """
+        The endpoint returns a payload with a key called 's3_path', but it
+        is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+        """
+        if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+            payload["s3_arn"] = payload.pop("s3_path")
+
+        if "region" not in payload:
+            payload["region"] = "eu-west-1"
+        return ResourceCredentials(**payload)
+
+    @field_validator("s3_arn")
+    @classmethod
+    def validate_s3_arn(cls, value: str) -> str:
+        """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+        if not value.startswith("arn:aws:s3:::"):
+            raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+        return value
+
+    def s3_path(self) -> str:
+        """
+        Extracts the S3 path from the ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+        """
+        return self.s3_arn[len(ARN_PREFIX) :]
+
+    def s3_uri(self) -> str:
+        """
+        Converts the S3 ARN to a URI format.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+        """
+        return f"s3://{self.s3_path()}"
+
+    def bucket_name(self) -> str:
+        """
+        Extracts the bucket name from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+        """
+        return self.s3_path().split("/")[0]
+
+    def object_key(self) -> str:
+        """
+        Extracts the object key from the S3 ARN.
+        Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+        """
+        return "/".join(self.s3_path().split("/")[1:])
+
+    def aws_credentials(self) -> Dict[str, str]:
+        """
+        Returns the AWS credentials as a dictionary.
+        """
+        environment_vars = {
+            "AWS_ACCESS_KEY_ID": self.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.secret_key,
+            "AWS_SESSION_TOKEN": self.session_token,
+            "AWS_REGION": self.region,
+        }
+        return environment_vars
+
+
+def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
     """
     Retrieve credentials for accessing the recipe stored in S3 (or another resource)
     by calling a DIP endpoint with the API key.
@@ -18,21 +90,16 @@ def get_resource_creds(endpoint: str, api_key: str) -> Dict[str, Any]:
         endpoint (str): The endpoint URL to fetch credentials from.
 
     Returns:
-
-        {
-            "access_key": str,
-            "secret_key": str,
-            "session_token": str,
-            "s3_path": str
-        }
+        ResourceCredentials
 
     Raises:
        RuntimeError: If the call to fetch the credentials fails for any reason.
    """
    try:
-
+        credentials_dict = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+        credentials = ResourceCredentials.fix_naming(credentials_dict)
        sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
-        return
+        return credentials
    except Exception as e:
        sys_logger.error(f"Failed to fetch credentials from endpoint: {e}")
        raise RuntimeError(f"Failed to retrieve credentials: {e}") from e
@@ -76,23 +143,18 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
        ValueError: If the S3 ARN is invalid or no objects found under prefix.
        RuntimeError: If S3 calls fail with an unexpected error.
    """
-
-    s3_arn = res_creds["s3_path"]
-    arn_prefix = "arn:aws:s3:::"
-    if not s3_arn.startswith(arn_prefix):
-        raise ValueError(f"Invalid S3 ARN: {s3_arn}")
+    res_credentials = get_resource_credentials(resource_url, api_key)
 
-
-
-    key = "/".join(key_parts)
+    bucket_name = res_credentials.bucket_name()
+    key = res_credentials.object_key()
 
    output_path = Path(destination)
    output_path.mkdir(parents=True, exist_ok=True)
    s3_client = boto3.client(
        "s3",
-        aws_access_key_id=
-        aws_secret_access_key=
-        aws_session_token=
+        aws_access_key_id=res_credentials.access_key,
+        aws_secret_access_key=res_credentials.secret_key,
+        aws_session_token=res_credentials.session_token,
    )
    downloaded_files = []
    try:
hafnia/torch_helpers.py
CHANGED

@@ -1,80 +1,126 @@
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple, Type, Union
 
-import
+import numpy as np
 import torch
 import torchvision
-from flatten_dict import flatten
+from flatten_dict import flatten, unflatten
 from PIL import Image, ImageDraw, ImageFont
 from torchvision import tv_tensors
 from torchvision import utils as tv_utils
 from torchvision.transforms import v2
 
+from hafnia.dataset.dataset_names import FieldName
+from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+from hafnia.dataset.primitives import (
+    PRIMITIVE_COLUMN_NAMES,
+    class_color_by_name,
+)
+from hafnia.dataset.primitives.bbox import Bbox
+from hafnia.dataset.primitives.bitmask import Bitmask
+from hafnia.dataset.primitives.classification import Classification
+from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.dataset.primitives.segmentation import Segmentation
+from hafnia.log import user_logger
+
+
+def get_primitives_per_task_name_for_primitive(
+    sample: Sample, PrimitiveType: Type[Primitive], split_by_task_name: bool = True
+) -> Dict[str, List[Primitive]]:
+    if not hasattr(sample, PrimitiveType.column_name()):
+        return {}
+
+    primitives = getattr(sample, PrimitiveType.column_name())
+    if primitives is None:
+        return {}
+
+    primitives_by_task_name: Dict[str, List[Primitive]] = {}
+    for primitive in primitives:
+        if primitive.task_name not in primitives_by_task_name:
+            primitives_by_task_name[primitive.task_name] = []
+        primitives_by_task_name[primitive.task_name].append(primitive)
+    return primitives_by_task_name
+
 
 class TorchvisionDataset(torch.utils.data.Dataset):
     def __init__(
         self,
-
+        dataset: HafniaDataset,
         transforms=None,
-        classification_tasks: Optional[List] = None,
-        object_tasks: Optional[List] = None,
-        segmentation_tasks: Optional[List] = None,
         keep_metadata: bool = False,
     ):
-        self.dataset =
+        self.dataset = dataset
+
         self.transforms = transforms
-        self.object_tasks = object_tasks or ["objects"]
-        self.segmentation_tasks = segmentation_tasks or ["segmentation"]
-        self.classification_tasks = classification_tasks or ["classification"]
         self.keep_metadata = keep_metadata
 
-    def __getitem__(self, idx):
-
-
-
-
-
-
-
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict]:
+        sample_dict = self.dataset[idx]
+        sample = Sample(**sample_dict)
+        image = tv_tensors.Image(sample.read_image_pillow())
+        h, w = image.shape[-2:]
+        target_flat = {}
+        mask_tasks: Dict[str, List[Segmentation]] = get_primitives_per_task_name_for_primitive(sample, Segmentation)
+        for task_name, masks in mask_tasks.items():
+            raise NotImplementedError("Segmentation tasks are not yet implemented")
+            # target[f"{mask.task_name}.mask"] = tv_tensors.Mask(mask.mask)
+
+        class_tasks: Dict[str, List] = get_primitives_per_task_name_for_primitive(sample, Classification)
+        for task_name, classifications in class_tasks.items():
+            assert len(classifications) == 1, "Expected exactly one classification task per sample"
+            target_flat[f"{Classification.column_name()}.{task_name}"] = {
+                FieldName.CLASS_IDX: classifications[0].class_idx,
+                FieldName.CLASS_NAME: classifications[0].class_name,
+            }
+
+        bbox_tasks: Dict[str, List[Bbox]] = get_primitives_per_task_name_for_primitive(sample, Bbox)
+        for task_name, bboxes in bbox_tasks.items():
+            bboxes_list = [bbox.to_coco(image_height=h, image_width=w) for bbox in bboxes]
+            bboxes_tensor = torch.as_tensor(bboxes_list).reshape(-1, 4)
+            target_flat[f"{Bbox.column_name()}.{task_name}"] = {
+                FieldName.CLASS_IDX: [bbox.class_idx for bbox in bboxes],
+                FieldName.CLASS_NAME: [bbox.class_name for bbox in bboxes],
+                "bbox": tv_tensors.BoundingBoxes(bboxes_tensor, format="XYWH", canvas_size=(h, w)),
+            }
+
+        bitmask_tasks: Dict[str, List[Bitmask]] = get_primitives_per_task_name_for_primitive(sample, Bitmask)
+        for task_name, bitmasks in bitmask_tasks.items():
+            bitmasks_np = np.array([bitmask.to_mask(img_height=h, img_width=w) for bitmask in bitmasks])
+            target_flat[f"{Bitmask.column_name()}.{task_name}"] = {
+                FieldName.CLASS_IDX: [bitmask.class_idx for bitmask in bitmasks],
+                FieldName.CLASS_NAME: [bitmask.class_name for bitmask in bitmasks],
+                "mask": tv_tensors.Mask(bitmasks_np),
+            }
 
-
-
-                target[f"{segmentation_task}.mask"] = tv_tensors.Mask(sample[segmentation_task].pop("mask"))
-
-        for classification_task in self.classification_tasks:
-            if classification_task in sample:
-                target[f"{classification_task}.class_idx"] = sample[classification_task].pop("class_idx")
-
-        for object_task in self.object_tasks:
-            if object_task in sample:
-                bboxes_list = sample[object_task].pop("bbox")
-                bboxes = tv_tensors.BoundingBoxes(bboxes_list, format="XYWH", canvas_size=img_shape)
-                if bboxes.numel() == 0:
-                    bboxes = bboxes.reshape(-1, 4)
-                target[f"{object_task}.bbox"] = bboxes
-                target[f"{object_task}.class_idx"] = torch.tensor(sample[object_task].pop("class_idx"))
+        if self.transforms:
+            image, target_flat = self.transforms(image, target_flat)
 
         if self.keep_metadata:
-
-
-
-
+            sample_dict = sample_dict.copy()
+            drop_columns = PRIMITIVE_COLUMN_NAMES
+            for column in drop_columns:
+                if column in sample_dict:
+                    sample_dict.pop(column)
 
+        target = flatten(target_flat, reducer="dot")
         return image, target
 
     def __len__(self):
         return len(self.dataset)
 
 
-def draw_image_classification(visualize_image: torch.Tensor,
+def draw_image_classification(visualize_image: torch.Tensor, text_labels: Union[str, List[str]]) -> torch.Tensor:
+    if isinstance(text_labels, str):
+        text_labels = [text_labels]
+    text = "\n".join(text_labels)
     max_dim = max(visualize_image.shape[-2:])
-    font_size = max(int(max_dim * 0.
+    font_size = max(int(max_dim * 0.06), 10)  # Minimum font size of 10
     txt_font = ImageFont.load_default(font_size)
     dummie_draw = ImageDraw.Draw(Image.new("RGB", (10, 10)))
-    _, _, w, h = dummie_draw.textbbox((0, 0), text=
+    _, _, w, h = dummie_draw.textbbox((0, 0), text=text, font=txt_font)  # type: ignore[arg-type]
 
     text_image = Image.new("RGB", (int(w), int(h)))
     draw = ImageDraw.Draw(text_image)
-    draw.text((0, 0), text=
+    draw.text((0, 0), text=text, font=txt_font)  # type: ignore[arg-type]
     text_tensor = v2.functional.to_image(text_image)
 
     height = text_tensor.shape[-2] + visualize_image.shape[-2]
@@ -94,77 +140,116 @@ def draw_image_classification(visualize_image: torch.Tensor, text_label: str) ->
 def draw_image_and_targets(
     image: torch.Tensor,
     targets,
-    detection_tasks: Optional[List[str]] = None,
-    segmentation_tasks: Optional[List[str]] = None,
-    classification_tasks: Optional[List[str]] = None,
 ) -> torch.Tensor:
-    detection_tasks = detection_tasks or ["objects"]
-    segmentation_tasks = segmentation_tasks or ["segmentation"]
-    classification_tasks = classification_tasks or ["classification"]
-
     visualize_image = image.clone()
     if visualize_image.is_floating_point():
         visualize_image = image - torch.min(image)
         visualize_image = visualize_image / visualize_image.max()
 
     visualize_image = v2.functional.to_dtype(visualize_image, torch.uint8, scale=True)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    targets = unflatten(targets, splitter="dot")  # Nested dictionary format
+    # NOTE: Order of drawing is important so visualizations are not overlapping in an undesired way
+    if Segmentation.column_name() in targets:
+        primitive_annotations = targets[Segmentation.column_name()]
+        for task_name, task_annotations in primitive_annotations.items():
+            raise NotImplementedError("Segmentation tasks are not yet implemented")
+            # mask = targets[mask_field].squeeze(0)
+            # masks_list = [mask == value for value in mask.unique()]
+            # masks = torch.stack(masks_list, dim=0).to(torch.bool)
+            # visualize_image = tv_utils.draw_segmentation_masks(visualize_image, masks=masks, alpha=0.5)
+
+    if Bitmask.column_name() in targets:
+        primitive_annotations = targets[Bitmask.column_name()]
+        for task_name, task_annotations in primitive_annotations.items():
+            colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+            visualize_image = tv_utils.draw_segmentation_masks(
+                image=visualize_image,
+                masks=task_annotations["mask"],
+                colors=colors,
+            )
+
+    if Bbox.column_name() in targets:
+        primitive_annotations = targets[Bbox.column_name()]
+        for task_name, task_annotations in primitive_annotations.items():
+            bboxes = torchvision.ops.box_convert(task_annotations["bbox"], in_fmt="xywh", out_fmt="xyxy")
+            colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+            visualize_image = tv_utils.draw_bounding_boxes(
+                image=visualize_image,
+                boxes=bboxes,
+                labels=task_annotations[FieldName.CLASS_NAME],
+                width=2,
+                colors=colors,
+            )
+
+    # Important that classification is drawn last as it will change image dimensions
+    if Classification.column_name() in targets:
+        primitive_annotations = targets[Classification.column_name()]
+        text_labels = []
+        for task_name, task_annotations in primitive_annotations.items():
+            if task_name == Classification.default_task_name():
+                text_label = task_annotations[FieldName.CLASS_NAME]
+            else:
+                text_label = f"{task_name}: {task_annotations[FieldName.CLASS_NAME]}"
+            text_labels.append(text_label)
+        visualize_image = draw_image_classification(visualize_image, text_labels)
     return visualize_image
 
 
 class TorchVisionCollateFn:
     def __init__(self, skip_stacking: Optional[List] = None):
         if skip_stacking is None:
-            skip_stacking = []
-
+            skip_stacking = [f"{Bbox.column_name()}.*", f"{Bitmask.column_name()}.*"]
+
+        self.wild_card_skip_stacking = []
+        self.skip_stacking_list = []
+        for skip_name in skip_stacking:
+            if skip_name.endswith("*"):
+                self.wild_card_skip_stacking.append(skip_name[:-1])  # Remove the trailing '*'
+            else:
+                self.skip_stacking_list.append(skip_name)
+
+    def skip_key_name(self, key_name: str) -> bool:
+        if key_name in self.skip_stacking_list:
+            return True
+        if any(key_name.startswith(wild_card) for wild_card in self.wild_card_skip_stacking):
+            return True
+        return False
 
     def __call__(self, batch):
         images, targets = tuple(zip(*batch, strict=False))
         if "image" not in self.skip_stacking_list:
             images = torch.stack(images)
 
-
+        keys_min = set(targets[0])
+        keys_max = set(targets[0])
+        for target in targets:
+            keys_min = keys_min.intersection(target)
+            keys_max = keys_max.union(target)
+
+        if keys_min != keys_max:
+            user_logger.warning(
+                "Not all images in the batch contain the same targets. To solve for missing targets "
+                f"the following keys {keys_max - keys_min} are dropped from the batch "
+            )
+
+        targets_modified = {k: [d[k] for d in targets] for k in keys_min}
         for key_name, item_values in targets_modified.items():
-            if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.skip_key_name(key_name):
+                continue
+            first_element = item_values[0]
+            if isinstance(first_element, torch.Tensor):
+                item_values = torch.stack(item_values)
+            elif isinstance(first_element, (int, float)):
+                item_values = torch.tensor(item_values)
+            elif isinstance(first_element, (str, list)):
+                # Skip stacking for certain types such as strings and lists
+                pass
+            if isinstance(first_element, tv_tensors.Mask):
+                item_values = tv_tensors.Mask(item_values)
+            elif isinstance(first_element, tv_tensors.Image):
+                item_values = tv_tensors.Image(item_values)
+            elif isinstance(first_element, tv_tensors.BoundingBoxes):
+                item_values = tv_tensors.BoundingBoxes(item_values)
+            targets_modified[key_name] = item_values
 
         return images, targets_modified
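
For orientation, a minimal sketch of the new TorchVisionCollateFn applied to a hand-built batch. The flat target keys mimic the "<primitive column>.<task name>.<field>" layout produced by TorchvisionDataset.__getitem__; the literal key strings below are illustrative placeholders, since the real column and field names come from Classification.column_name() and FieldName, which are defined outside this diff.

```python
import torch
from hafnia.torch_helpers import TorchVisionCollateFn

# Two (image, target) pairs mimicking the flattened target layout.
# Key names are placeholders for f"{Classification.column_name()}.{task_name}.{FieldName.*}".
batch = [
    (torch.zeros(3, 4, 4), {"classifications.scene.class_idx": 1, "classifications.scene.class_name": "indoor"}),
    (torch.zeros(3, 4, 4), {"classifications.scene.class_idx": 0, "classifications.scene.class_name": "outdoor"}),
]

collate_fn = TorchVisionCollateFn()  # default skips stacking of bbox/bitmask entries via wildcard patterns
images, targets = collate_fn(batch)

assert images.shape == (2, 3, 4, 4)                                           # images are stacked
assert targets["classifications.scene.class_idx"].tolist() == [1, 0]          # ints become a tensor
assert targets["classifications.scene.class_name"] == ["indoor", "outdoor"]   # strings stay as a list
```

When samples in a batch expose different target keys, only the keys shared by every sample are kept and a warning is logged via user_logger.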
hafnia/utils.py
CHANGED

@@ -1,3 +1,4 @@
+import hashlib
 import os
 import time
 import zipfile
@@ -14,7 +15,7 @@ from rich import print as rprint
 from hafnia.log import sys_logger, user_logger
 
 PATH_DATA = Path("./.data")
-
+PATH_DATASETS = PATH_DATA / "datasets"
 PATH_RECIPES = PATH_DATA / "recipes"
 FILENAME_HAFNIAIGNORE = ".hafniaignore"
 DEFAULT_IGNORE_SPECIFICATION = [
@@ -132,6 +133,24 @@ def show_recipe_content(recipe_path: Path, style: str = "emoji", depth_limit: in
     user_logger.info(f"Recipe size: {size_human_readable(os.path.getsize(recipe_path))}. Max size 800 MiB")
 
 
-def
+def is_hafnia_cloud_job() -> bool:
     """Check if the current job is running in HAFNIA cloud environment."""
     return os.getenv("HAFNIA_CLOUD", "false").lower() == "true"
+
+
+def pascal_to_snake_case(name: str) -> str:
+    """
+    Convert PascalCase to snake_case.
+    """
+    return "".join(["_" + char.lower() if char.isupper() else char for char in name]).lstrip("_")
+
+
+def snake_to_pascal_case(name: str) -> str:
+    """
+    Convert snake_case to PascalCase.
+    """
+    return "".join(word.capitalize() for word in name.split("_"))
+
+
+def hash_from_string(s: str) -> str:
+    return hashlib.md5(s.encode("utf-8")).hexdigest()