dragon-ml-toolbox 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,218 @@
1
+ import os
2
+ import imghdr
3
+ from PIL import Image, ImageOps
4
+ from typing import Literal
5
+ from torchvision import transforms
6
+ import torch
7
+
8
+
9
+ # --- Helper Functions ---
10
def inspect_images(path: str):
    """
    Print the types, sizes and channels of image files found in a directory tree.

    Every file under `path` (subdirectories included) is probed with
    `imghdr.what`. Files it does not recognise are collected and reported as
    non-image files; recognised files are opened with PIL to read their size
    and band names.

    Possible band names (channels):
        * "R": Red channel
        * "G": Green channel
        * "B": Blue channel
        * "A": Alpha (transparency) channel
        * "L": Luminance (grayscale) channel
        * "P": Palette channel
        * "I": Integer channel
        * "F": Floating point channel

    Args:
        path (string): path to target directory.

    Returns:
        None. The summary is written to standard output.
    """
    # NOTE(review): `imghdr` is deprecated since Python 3.11 and removed in 3.13;
    # it is kept here so the reported type names stay exactly as before.
    non_image = set()     # names of files that are not recognised as images
    img_types = set()     # image types reported by imghdr
    img_sizes = set()     # distinct (width, height) tuples
    img_channels = set()  # distinct band names
    img_counter = 0       # number of recognised image files

    # Walk the directory and all of its subdirectories
    for root, _subdirs, files in os.walk(path):
        for filename in files:
            filepath = os.path.join(root, filename)
            img_type = imghdr.what(filepath)
            if img_type is None:
                non_image.add(filename)
                continue

            img_types.add(img_type)
            img_counter += 1

            # Context manager closes the underlying file handle; without it
            # one descriptor per image stays open while the walk continues.
            with Image.open(filepath) as img:
                img_sizes.add(img.size)
                img_channels.update(img.getbands())

    if non_image:
        print(f"⚠️ Non-image files found: {non_image}")
    # Print results
    print(f"Image types found: {img_types}\nImage sizes found: {img_sizes}\nImage channels found: {img_channels}\nImages found: {img_counter}")
64
+
65
+
66
def image_augmentation(path: str, samples: int=100, size: int=256, mode: Literal["RGB", "L"]="RGB", jitter_ratio: float=0.0,
                       rotation_deg=270, output: Literal["jpeg", "png", "tiff", "bmp"]="jpeg"):
    """
    Perform image augmentation on a directory containing image files.

    A new directory "temp_augmented_images" is created in the current working
    directory and receives `samples` augmented copies of every image found
    directly inside `path` (subdirectories are not traversed). Files that
    cannot be opened as images are skipped and listed at the end.

    Args:
        path (str): Path to target directory.
        samples (int, optional): Number of images to create per image in the directory. Defaults to 100.
        size (int, optional): Image size to resize to. Defaults to 256.
        mode (str, optional): 'RGB' for 3 channels, 'L' for 1 grayscale channel.
        jitter_ratio (float, optional): Brightness and Contrast factor to use in the ColorJitter transform. Defaults to 0.
        rotation_deg (int, optional): Range for the rotation transformation. Defaults to 270.
        output (str, optional): output image format. Defaults to 'jpeg'.

    Raises:
        FileExistsError: If "temp_augmented_images" already exists.
    """
    # Define the transformations: upscale by 20% and center-crop back to `size`
    # before the colour, flip and rotation transforms are applied.
    transform = transforms.Compose([
        transforms.Resize(size=(int(size*1.2), int(size*1.2))),
        transforms.CenterCrop(size=size),
        transforms.ColorJitter(brightness=jitter_ratio, contrast=jitter_ratio),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=rotation_deg),
    ])

    # Create container folder (fails loudly rather than mixing with old output)
    dir_name = "temp_augmented_images"
    os.makedirs(dir_name, exist_ok=False)

    # Keep track of non-image files
    non_image = set()

    # Anything other than "RGB" falls back to single-channel grayscale
    target_mode = "RGB" if mode == "RGB" else "L"

    # Apply transformation to each image in path
    for filename in os.listdir(path):
        filepath = os.path.join(path, filename)

        # Probe the full path: the bare filename would be resolved against the
        # current working directory and real images would be rejected.
        if not is_image(filepath):
            non_image.add(filename)
            continue

        # convert() returns an independent, fully loaded copy, so the source
        # file handle can be released immediately.
        with Image.open(filepath) as source:
            img = source.convert(target_mode)

        # Create and save images
        filename_no_ext = os.path.splitext(filename)[0]
        for i in range(1, samples + 1):
            new_img = transform(img)
            new_img.save(os.path.join(dir_name, f"{filename_no_ext}_{i}.{output}"))

    # Print non-image files
    if non_image:
        print(f"Files not processed: {non_image}")
127
+
128
+
129
class ResizeAspectFill:
    """
    Callable transform that pads a PIL image until it is square.

    The shorter dimension is extended with a `pad_color` border so that it
    matches the longer one: a wide image is padded at the bottom, a tall image
    is padded on the right. An image that is already square gets a zero-width
    border.
    """
    def __init__(self, pad_color: Literal["black", "white"]="black") -> None:
        # Fill colour handed to ImageOps.expand
        self.pad_color = pad_color

    def __call__(self, image: Image.Image):
        # Reject anything that is not a PIL image
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected PIL.Image.Image, got {type(image).__name__}")

        width, height = image.width, image.height
        missing = abs(width - height)

        # Border order for 4-tuples: left, top, right, bottom
        if width == height:
            border = (0, 0)
        elif width > height:
            border = (0, 0, 0, missing)
        else:
            border = (0, 0, missing, 0)

        return ImageOps.expand(image=image, border=border, fill=self.pad_color)
156
+
157
+
158
def is_image(file: str):
    """
    Returns `True` if the file can be opened as an image by PIL, `False` otherwise.

    Args:
        `file`, path to the file (it is opened as given, so a bare filename is
        resolved against the current working directory).
    """
    try:
        # Image.open keeps the file handle open; the `with` block releases it
        # as soon as the probe has succeeded.
        with Image.open(file):
            pass
    except IOError:
        # Missing/unreadable paths and unidentifiable content both land here.
        return False
    return True
171
+
172
+
173
def model_predict(model: torch.nn.Module, kind: Literal["regression", "classification"], samples_list: list[torch.Tensor],
                  device: Literal["cpu", "cuda", "mps"]='cpu', view_as: tuple[int,int]=(1,-1), add_batch_dimension: bool=True):
    """
    Returns a list containing lists of predicted values, one for each input sample.

    Each sample must be a tensor and have the same shape and normalization expected by the model.
    The model is switched to evaluation mode and moved to the selected device;
    inference runs with gradient tracking disabled.

    Args:
        `model`: A trained PyTorch model.

        `kind`: Regression or Classification task. For "classification" the index of the
        largest output along dimension 1 is returned; any other value returns the raw output.

        `samples_list`: A list of input tensors.

        `device`: Device to use, default is CPU. Falls back to CPU (with a printed notice)
        when "cuda" or "mps" is requested but not available.

        `view_as`: Reshape each model output, default is (1,-1).

        `add_batch_dimension`: Automatically adds the batch dimension to each sample shape.

    Returns:
        A list with one entry per sample; each entry is the reshaped prediction converted
        to nested Python lists.
    """
    # Validate device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, switching to CPU.")
        device = "cpu"
    elif device == "mps" and not torch.backends.mps.is_available():
        print("MPS not available, switching to CPU.")
        device = "cpu"

    # Inputs are moved to `device` below, so the model has to live there too;
    # otherwise the forward pass fails with a device mismatch.
    model.to(device)
    model.eval()

    results = []
    with torch.no_grad():
        for data_point in samples_list:
            if add_batch_dimension:
                data_point = data_point.unsqueeze(0)
            output = model(data_point.to(device))

            if kind == "classification":
                # Predicted class = index of the highest score per row
                output = output.argmax(dim=1)
            results.append(output.view(view_as).cpu().tolist())

    return results