BETTER_NMA-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BETTER_NMA/__init__.py +15 -0
- BETTER_NMA/adversarial_score.py +19 -0
- BETTER_NMA/change_cluster_name.py +0 -0
- BETTER_NMA/detect_attack.py +108 -0
- BETTER_NMA/find_lca.py +21 -0
- BETTER_NMA/main.py +285 -0
- BETTER_NMA/nma_creator.py +108 -0
- BETTER_NMA/plot.py +131 -0
- BETTER_NMA/query_image.py +22 -0
- BETTER_NMA/train_adversarial_detector.py +21 -0
- BETTER_NMA/utilss/__init__.py +0 -0
- BETTER_NMA/utilss/classes/__init__.py +0 -0
- BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
- BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
- BETTER_NMA/utilss/classes/dendrogram.py +131 -0
- BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
- BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
- BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
- BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
- BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
- BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
- BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
- BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
- BETTER_NMA/utilss/classes/score_calculator.py +165 -0
- BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
- BETTER_NMA/utilss/enums/__init__.py +0 -0
- BETTER_NMA/utilss/enums/explanation_method.py +6 -0
- BETTER_NMA/utilss/enums/heap_types.py +5 -0
- BETTER_NMA/utilss/models_utils.py +18 -0
- BETTER_NMA/utilss/photos_uitls.py +72 -0
- BETTER_NMA/utilss/photos_utils.py +104 -0
- BETTER_NMA/utilss/verbal_explanation.py +15 -0
- BETTER_NMA/utilss/wordnet_utils.py +177 -0
- BETTER_NMA/white_box_testing.py +101 -0
- BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
- BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
- BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
- BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/utilss/classes/preprocessing/z_builder.py
@@ -0,0 +1,93 @@
import numpy as np


class ZBuilder:
    def create_z_matrix_from_tree(self, clustering_builder, labels):
        unique_labels = []
        seen_labels = set()

        for label in labels:
            if label not in seen_labels:
                unique_labels.append(label)
                seen_labels.add(label)

        print(f"Total labels: {len(labels)}, Unique labels: {len(unique_labels)}")
        label_to_idx = {label: i for i, label in enumerate(unique_labels)}

        n = len(unique_labels)
        z_matrix = np.zeros((n - 1, 4), dtype=np.float64)
        if not clustering_builder.forest:
            print("ERROR: No tree found in the forest!")
            return z_matrix, unique_labels

        root = clustering_builder.forest[0]
        processed = {}
        next_z_idx = n
        row_idx = 0

        def process_node(node):
            nonlocal row_idx, next_z_idx
            if node.node_id in processed:
                return processed[node.node_id]

            if len(node) == 0:
                if hasattr(node, "node_name") and node.node_name in label_to_idx:
                    idx = label_to_idx[node.node_name]
                    processed[node.node_id] = idx
                    return idx
                else:
                    return None

            left_idx = None
            right_idx = None

            if len(node) >= 1:
                left_idx = process_node(node[0])
            if len(node) >= 2:
                right_idx = process_node(node[1])
            if left_idx is None or right_idx is None:
                return None
            if left_idx > right_idx:
                left_idx, right_idx = right_idx, left_idx
            left_count = 1 if left_idx < n else z_matrix[left_idx - n][3]
            right_count = 1 if right_idx < n else z_matrix[right_idx - n][3]

            if row_idx < n - 1:
                z_matrix[row_idx] = [
                    left_idx,
                    right_idx,
                    node.weight,
                    left_count + right_count,
                ]
                this_idx = next_z_idx
                processed[node.node_id] = this_idx
                next_z_idx += 1
                row_idx += 1
                return this_idx

            return None

        process_node(root)
        if row_idx < n - 1:
            print(f"WARNING: Only filled {row_idx} of {n-1} rows in Z matrix!")
            # Trim the matrix to the actual number of rows we filled
            z_matrix = z_matrix[:row_idx]

        for i in range(z_matrix.shape[0]):
            left_idx = int(z_matrix[i, 0])
            right_idx = int(z_matrix[i, 1])
            left_size = 1 if left_idx < n else z_matrix[left_idx - n, 3]
            right_size = 1 if right_idx < n else z_matrix[right_idx - n, 3]
            expected_size = left_size + right_size
            if z_matrix[i, 3] != expected_size:
                print(
                    f"Fixing cluster size at row {i}: {z_matrix[i, 3]} → {expected_size}"
                )
                z_matrix[i, 3] = expected_size
            max_size = z_matrix.shape[0] + 1

            if z_matrix[i, 3] > max_size:
                print(
                    f"Capping excessive cluster size at row {i}: {z_matrix[i, 3]} → {max_size}"
                )
                z_matrix[i, 3] = max_size

        return np.asarray(z_matrix, dtype=np.float64)
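The Z matrix produced above follows SciPy's linkage layout ([left_child, right_child, merge_weight, cluster_size] per row), so it can be rendered directly with scipy.cluster.hierarchy. Below is a minimal, hypothetical sketch: _Node and _Builder are stand-ins for the package's TreeNode and HierarchicalClusteringBuilder (whose exact interfaces are not shown in this diff), exposing only the attributes ZBuilder relies on.

import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram


class _Node:
    """Hypothetical stand-in exposing node_id, node_name, weight, len() and indexing."""
    def __init__(self, node_id, children=None, name=None, weight=0.0):
        self.node_id = node_id
        self.node_name = name
        self.weight = weight
        self._children = children or []

    def __len__(self):
        return len(self._children)

    def __getitem__(self, i):
        return self._children[i]


class _Builder:
    """Hypothetical stand-in providing the .forest attribute ZBuilder expects."""
    def __init__(self, root):
        self.forest = [root]


labels = ["cat", "dog", "car"]
leaves = [_Node(i, name=label) for i, label in enumerate(labels)]
animals = _Node(10, children=[leaves[0], leaves[1]], weight=1.0)
root = _Node(11, children=[animals, leaves[2]], weight=2.0)

Z = ZBuilder().create_z_matrix_from_tree(_Builder(root), labels)
dendrogram(Z, labels=labels)  # rows: [0, 1, 1.0, 2] and [2, 3, 2.0, 3]
plt.show()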
BETTER_NMA/utilss/classes/score_calculator.py
@@ -0,0 +1,165 @@
import numpy as np
import pickle


class ScoreCalculator:
    def __init__(self, Z_full, class_names):
        # Check whether Z_full is a file path (string) or a numpy array
        if isinstance(Z_full, str):
            try:
                with open(Z_full, 'rb') as file:
                    self.Z_full = pickle.load(file)
            except (FileNotFoundError, pickle.UnpicklingError) as e:
                print(f"Error loading Z file: {e}")
                self.Z_full = None
        elif isinstance(Z_full, np.ndarray):
            # Z_full is already a numpy array
            self.Z_full = Z_full
        else:
            print(f"Warning: Z_full type {type(Z_full)} not supported. Expected string (file path) or numpy array.")
            self.Z_full = Z_full

        self.class_names = class_names

    def count_ancestors_to_lca(self, label1, label2):
        """
        Count the number of ancestors each node passes through before reaching their
        lowest common ancestor in the hierarchy described by self.Z_full.

        Parameters:
        - label1, label2: Class labels to compare (class names or leaf indices)

        Returns:
        - total_count: Total number of ancestors traversed to reach the LCA
        - lca: Node id of the lowest common ancestor (None if not found)
        """
        # Convert labels to indices if needed
        if isinstance(label1, str):
            idx1 = self.class_names.index(label1)
        else:
            idx1 = label1

        if isinstance(label2, str):
            idx2 = self.class_names.index(label2)
        else:
            idx2 = label2

        # Build the hierarchical structure
        n_samples = len(self.class_names)
        n_nodes = 2 * n_samples - 1

        # Initialize parent mapping
        parent = np.zeros(n_nodes, dtype=np.int64) - 1  # -1 means no parent

        # Fill in the structure from Z
        for i, (left, right, height, _) in enumerate(self.Z_full):
            left = int(left)
            right = int(right)
            node_id = n_samples + i

            parent[left] = node_id
            parent[right] = node_id

        # Trace path from node1 to the root
        path1 = []
        current = idx1
        while parent[current] != -1:
            path1.append(parent[current])
            current = parent[current]

        # Trace path from node2 until it meets path1
        path2 = []
        current = idx2
        lca = None

        while parent[current] != -1:
            current_parent = parent[current]
            path2.append(current_parent)

            if current_parent in path1:
                # Found the LCA
                lca = current_parent
                break

            current = current_parent

        # If no LCA was found (should not happen in a proper hierarchy),
        # return the maximum possible distance
        if lca is None:
            return n_nodes, None

        # Count steps from node1 to the LCA
        steps1 = path1.index(lca) + 1

        # Count steps from node2 to the LCA
        steps2 = path2.index(lca) + 1

        # Total number of ancestors traversed
        total_count = steps1 + steps2

        return total_count, lca

    def calculate_adversarial_score(self, predictions, top_k=5):
        """
        Calculate an adversarial-attack score based on statistical anomaly detection.
        Works regardless of whether attacks increase or decrease semantic distance.

        Parameters:
        - predictions: Model output predictions (logits or probabilities)
        - top_k: Number of top predictions to consider

        Returns:
        - score: Sum of the pairwise ancestor distances between the top-k
          predicted classes (0 if no pairs were formed)
        """
        # Get top-k predictions
        if len(predictions.shape) > 1:
            predictions = predictions[0]  # For batch predictions, take the first item

        top_indices = np.argsort(predictions)[-top_k:][::-1]
        top_probs = [predictions[i] for i in top_indices]
        top_labels = [self.class_names[i] for i in top_indices]

        # Calculate ancestors to LCA for all pairs
        pairwise_distances = []
        distance_prob_products = []
        all_pairs = []

        for i in range(len(top_indices)):
            for j in range(i + 1, len(top_indices)):
                idx1, idx2 = top_indices[i], top_indices[j]
                label1, label2 = top_labels[i], top_labels[j]
                prob1, prob2 = top_probs[i], top_probs[j]

                # Calculate semantic distance
                rank_count, _ = self.count_ancestors_to_lca(idx1, idx2)
                print(f"Rank Count for {label1} and {label2}: {rank_count}")

                # Calculate product of probabilities and distance
                prob_product = prob1 * prob2
                rank_prob = rank_count * prob_product

                # Per-pair details are collected but not currently returned
                pair_info = {
                    'label1': label1,
                    'label2': label2,
                    'probability1': float(prob1),
                    'probability2': float(prob2),
                    'ancestor_distance': rank_count,
                    'prob_product': prob_product,
                    'weighted_distance': rank_prob
                }

                pairwise_distances.append(rank_count)
                distance_prob_products.append(rank_prob)
                all_pairs.append(pair_info)

        # Calculate the final score
        if pairwise_distances:
            score = sum(pairwise_distances)
        else:
            score = 0

        return score
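A hedged usage sketch for ScoreCalculator: the linkage matrix and class names below are synthetic stand-ins built with scipy.cluster.hierarchy.linkage rather than the hierarchy the package actually produces, and the prediction vector is random.

import numpy as np
from scipy.cluster.hierarchy import linkage

rng = np.random.default_rng(0)
class_names = [f"class_{i}" for i in range(8)]

# Synthetic stand-in for the package's Z_full (8 leaves -> 7 merge rows).
Z_full = linkage(rng.normal(size=(8, 4)), method="ward")
calc = ScoreCalculator(Z_full, class_names)

# A softmax-like prediction vector over the 8 classes.
logits = rng.normal(size=8)
probs = np.exp(logits) / np.exp(logits).sum()

# Sum of pairwise ancestor distances among the top-3 predicted classes.
score = calc.calculate_adversarial_score(probs, top_k=3)
print("ancestor-distance score:", score)

# Semantic distance between two individual classes.
steps, lca = calc.count_ancestors_to_lca("class_0", "class_3")
print(f"{steps} ancestor steps to reach LCA node {lca}")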
BETTER_NMA/utilss/classes/whitebox_testing.py
@@ -0,0 +1,35 @@
import pandas as pd


class WhiteBoxTesting:
    def __init__(self, model_name):
        self.model_name = model_name
        self.problematic_img_ids = None
        self.problematic_img_preds = None

    def find_problematic_images(self, source_labels, target_labels, edges_df, explanation_method=None):
        filtered_edges_df = edges_df[
            (edges_df['source'].isin(source_labels)) &
            (edges_df['target'].isin(target_labels))
        ]

        filtered_edges_df_switched = edges_df[
            (edges_df['source'].isin(target_labels)) &
            (edges_df['target'].isin(source_labels))
        ]
        print(filtered_edges_df_switched.head())

        combined_filtered_edges_df = pd.concat([filtered_edges_df, filtered_edges_df_switched])
        print("Combined filtered edges dataset:")
        print(combined_filtered_edges_df)

        unique_ids_list = combined_filtered_edges_df['image_id'].unique().tolist()

        matched_dict = {
            image_id: list(zip(group['source'], group['target'], group['target_probability']))
            for image_id, group in edges_df[edges_df['image_id'].isin(unique_ids_list)].groupby('image_id')
        }
        print("Matched dictionary:")
        print(matched_dict)

        return matched_dict
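A hedged usage sketch for WhiteBoxTesting.find_problematic_images: the edges DataFrame below is made up, but it carries the columns the method actually reads (image_id, source, target, target_probability); the real frame presumably comes from the package's EdgesDataframe class.

import pandas as pd

edges_df = pd.DataFrame({
    "image_id": [1, 1, 2, 3],
    "source": ["cat", "dog", "cat", "truck"],
    "target": ["dog", "cat", "fox", "car"],
    "target_probability": [0.40, 0.35, 0.20, 0.55],
})

wbt = WhiteBoxTesting("my_model")
problematic = wbt.find_problematic_images(
    source_labels=["cat"],
    target_labels=["dog"],
    edges_df=edges_df,
)
# Maps each image_id whose edges cross the cat/dog boundary (in either
# direction) to its (source, target, target_probability) triples,
# here {1: [("cat", "dog", 0.40), ("dog", "cat", 0.35)]}.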
File without changes
BETTER_NMA/utilss/models_utils.py
@@ -0,0 +1,18 @@
import numpy as np

def get_top_k_predictions(model, image, class_names, top_k=5):
    # Get predictions from the model
    predictions = model.predict(image)

    # Flatten the predictions array if it's 2D (e.g., shape (1, num_classes))
    if len(predictions.shape) == 2:
        predictions = predictions[0]  # Extract the first (and only) batch

    # Get the indices of the top-k predictions
    top_indices = np.argsort(predictions)[-top_k:][::-1]

    # Get the top-k probabilities and corresponding labels
    top_probs = predictions[top_indices]
    top_labels = [class_names[i] for i in top_indices]

    return list(zip(top_labels, top_probs))
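A hedged sketch of get_top_k_predictions: the stub below stands in for a real Keras model; anything with a predict(image) method that returns a (1, num_classes) array works.

import numpy as np

class _StubModel:
    """Hypothetical stand-in for a Keras model."""
    def predict(self, image):
        return np.array([[0.05, 0.70, 0.10, 0.15]])

class_names = ["cat", "dog", "fox", "car"]
image = np.zeros((1, 224, 224, 3))  # placeholder input batch

top2 = get_top_k_predictions(_StubModel(), image, class_names, top_k=2)
print(top2)  # dog (0.70) first, then car (0.15)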
BETTER_NMA/utilss/photos_uitls.py
@@ -0,0 +1,72 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess
from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
from PIL import Image
import io

def get_preprocess_function(model):
    print("Determining preprocessing function based on model configuration...")
    preprocess_map = {
        "resnet50": resnet50_preprocess,
        "vgg16": vgg16_preprocess,
        "inception_v3": inception_v3_preprocess,
        "mobilenet": mobilenet_preprocess,
        "efficientnet": efficientnet_preprocess,
        "xception": xception_preprocess,
    }

    model_config = model.get_config()
    if "name" in model_config:
        model_name = model_config["name"].lower()
        print(f"Model name: {model_name}")
        for key in preprocess_map.keys():
            if key in model_name:
                print(f"Detected model type: {key}")
                return preprocess_map[key]

    for layer in model.layers:
        layer_name = layer.name.lower()
        print(f"Checking layer: {layer_name}")
        for model_name in preprocess_map.keys():
            if model_name in layer_name:
                print(f"Detected model type: {model_name}")
                return preprocess_map[model_name]

    print("No supported model type found in the configuration. Falling back to generic normalization.")
    return lambda x: x / 255.0  # Generic normalization to [0, 1]


_cached_preprocess_function = {}

def get_cached_preprocess_function(model):
    """
    Get the cached preprocessing function for the given model.
    If not cached, fetch it and store it in the cache.
    """
    global _cached_preprocess_function
    model_id = id(model)  # Use the model's unique ID as the cache key
    if model_id not in _cached_preprocess_function:
        _cached_preprocess_function[model_id] = get_preprocess_function(model)
    return _cached_preprocess_function[model_id]

def preprocess_loaded_image(model, image):
    expected_shape = model.input_shape
    input_height, input_width = expected_shape[1], expected_shape[2]
    pil_image = Image.open(io.BytesIO(image)).convert("RGB")
    pil_image = pil_image.resize((input_width, input_height))
    preprocess_input = get_cached_preprocess_function(model)
    image_array = preprocess_input(np.array(pil_image))
    image_preprocessed = np.expand_dims(image_array, axis=0)
    return image_preprocessed

def preprocess_image(model, image):
    preprocess_input = get_cached_preprocess_function(model)
    image_array = preprocess_input(np.array(image))
    image_preprocessed = np.expand_dims(image_array, axis=0)
    return image_preprocessed
BETTER_NMA/utilss/photos_utils.py
@@ -0,0 +1,104 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.vgg16 import preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess
from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
from PIL import Image
import io
import base64

def get_preprocess_function(model):
    print("Determining preprocessing function based on model configuration...")
    preprocess_map = {
        "resnet50": resnet50_preprocess,
        "vgg16": vgg16_preprocess,
        "inception_v3": inception_v3_preprocess,
        "mobilenet": mobilenet_preprocess,
        "efficientnet": efficientnet_preprocess,
        "xception": xception_preprocess,
    }

    model_config = model.get_config()
    if "name" in model_config:
        model_name = model_config["name"].lower()
        print(f"Model name: {model_name}")
        for key in preprocess_map.keys():
            if key in model_name:
                print(f"Detected model type: {key}")
                return preprocess_map[key]

    for layer in model.layers:
        layer_name = layer.name.lower()
        print(f"Checking layer: {layer_name}")
        for model_name in preprocess_map.keys():
            if model_name in layer_name:
                print(f"Detected model type: {model_name}")
                return preprocess_map[model_name]

    print("No supported model type found in the configuration. Falling back to generic normalization.")
    return lambda x: x / 255.0  # Generic normalization to [0, 1]


_cached_preprocess_function = {}

def get_cached_preprocess_function(model):
    global _cached_preprocess_function
    model_id = id(model)
    if model_id not in _cached_preprocess_function:
        _cached_preprocess_function[model_id] = get_preprocess_function(model)
    return _cached_preprocess_function[model_id]

def preprocess_loaded_image(model, image):
    expected_shape = model.input_shape
    input_height, input_width = expected_shape[1], expected_shape[2]

    # Handle different input types
    if isinstance(image, bytes):
        # If image is bytes, convert to PIL Image
        pil_image = Image.open(io.BytesIO(image)).convert("RGB")
    elif isinstance(image, Image.Image):
        # If image is already a PIL Image, just convert to RGB
        pil_image = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        # If image is a numpy array, convert to PIL Image
        if image.dtype == np.uint8:
            pil_image = Image.fromarray(image)
        else:
            # Normalize to 0-255 if not uint8
            image_normalized = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
            pil_image = Image.fromarray(image_normalized)
    else:
        raise ValueError(f"Unsupported image type: {type(image)}")

    pil_image = pil_image.resize((input_width, input_height))
    preprocess_input = get_cached_preprocess_function(model)
    image_array = preprocess_input(np.array(pil_image))
    image_preprocessed = np.expand_dims(image_array, axis=0)
    return image_preprocessed, pil_image

def preprocess_image(model, image):
    preprocess_input = get_cached_preprocess_function(model)
    image_array = preprocess_input(np.array(image))
    image_preprocessed = np.expand_dims(image_array, axis=0)
    return image_preprocessed

def encode_image_to_base64(image):
    if isinstance(image, np.ndarray):
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)

        pil_image = Image.fromarray(image)
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str
    else:
        raise ValueError("Input must be a numpy array")
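A hedged sketch for the helpers above: encode_image_to_base64 only needs NumPy and Pillow, while the preprocess_* helpers additionally need a loaded Keras model so the matching preprocess_input can be looked up.

import numpy as np

# Base64-encode a random RGB image as PNG.
rgb = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
b64 = encode_image_to_base64(rgb)
print(b64[:32], "...")

# With a real model the full pipeline would look roughly like this
# (assumption: any tf.keras application model works, e.g. ResNet50):
#   model = tf.keras.applications.ResNet50(weights=None)
#   batch, pil_img = preprocess_loaded_image(model, rgb)
#   preds = model.predict(batch)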
BETTER_NMA/utilss/verbal_explanation.py
@@ -0,0 +1,15 @@
import re

def get_verbal_explanation(explanation):
    result = []
    for i, word in enumerate(explanation):
        sanitized_word = re.sub(r'[0-9_]', '', word)
        if i == 0:
            result.append(f"**{sanitized_word}**")  # You can use bold markdown or just the word
        else:
            prev_sanitized = re.sub(r'[0-9_]', '', explanation[i - 1])
            if sanitized_word == prev_sanitized:
                continue
            result.append(f" is a part of **{sanitized_word}**")
    return ''.join(result)
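A hedged example of get_verbal_explanation: the input is assumed to be an ordered list of node names from leaf to ancestor, as the dendrogram utilities would produce; digits and underscores are stripped, and consecutive duplicates are collapsed.

path = ["tabby_cat_01", "cat", "feline", "carnivore"]
print(get_verbal_explanation(path))
# **tabbycat** is a part of **cat** is a part of **feline** is a part of **carnivore**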