wsi-toolbox 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +122 -0
- wsi_toolbox/app.py +874 -0
- wsi_toolbox/cli.py +599 -0
- wsi_toolbox/commands/__init__.py +66 -0
- wsi_toolbox/commands/clustering.py +198 -0
- wsi_toolbox/commands/data_loader.py +219 -0
- wsi_toolbox/commands/dzi.py +160 -0
- wsi_toolbox/commands/patch_embedding.py +196 -0
- wsi_toolbox/commands/pca.py +206 -0
- wsi_toolbox/commands/preview.py +394 -0
- wsi_toolbox/commands/show.py +171 -0
- wsi_toolbox/commands/umap_embedding.py +174 -0
- wsi_toolbox/commands/wsi.py +223 -0
- wsi_toolbox/common.py +148 -0
- wsi_toolbox/models.py +30 -0
- wsi_toolbox/utils/__init__.py +109 -0
- wsi_toolbox/utils/analysis.py +174 -0
- wsi_toolbox/utils/hdf5_paths.py +232 -0
- wsi_toolbox/utils/plot.py +227 -0
- wsi_toolbox/utils/progress.py +207 -0
- wsi_toolbox/utils/seed.py +26 -0
- wsi_toolbox/utils/st.py +55 -0
- wsi_toolbox/utils/white.py +121 -0
- wsi_toolbox/watcher.py +256 -0
- wsi_toolbox/wsi_files.py +619 -0
- wsi_toolbox-0.2.0.dist-info/METADATA +253 -0
- wsi_toolbox-0.2.0.dist-info/RECORD +30 -0
- wsi_toolbox-0.2.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.2.0.dist-info/entry_points.txt +3 -0
- wsi_toolbox-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
import multiprocessing
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def reorder_clusters_by_pca(clusters: np.ndarray, pca_values: np.ndarray) -> np.ndarray:
    """
    Reorder cluster IDs based on PCA distribution for consistent visualization.

    Cluster IDs are reassigned so that, when clusters are drawn left to right
    in ID order (e.g. in a violin plot), their mean PCA1 values form a
    monotonic sequence whose orientation does not depend on the arbitrary sign
    of the first principal component.

    Algorithm:
        1. Compute the mean PCA1 value of every non-negative cluster and sort
           the clusters by it in ascending order.
        2. Compare the median of those means with the midpoint of their range.
        3. If the median is above the midpoint, reverse the order, so the new
           IDs run from the highest mean to the lowest. This is equivalent to
           flipping the sign of PCA1, which makes the ordering invariant to
           PCA sign ambiguity.

    Args:
        clusters: Cluster labels array [N]. Negative labels (e.g. -1 for
            filtered-out samples) are kept unchanged.
        pca_values: PCA1 values array [N] (first principal component)

    Returns:
        Reordered cluster labels with the same shape (and dtype) as
        ``clusters``. If there is at most one non-negative cluster,
        ``clusters`` is returned unchanged.
    """
    clusters = np.asarray(clusters)
    pca_values = np.asarray(pca_values)

    unique_clusters = [c for c in np.unique(clusters) if c >= 0]
    if len(unique_clusters) <= 1:
        return clusters

    # 1. Mean PCA1 for each cluster, then clusters sorted by that mean
    cluster_means = {c: pca_values[clusters == c].mean() for c in unique_clusters}
    sorted_clusters = sorted(unique_clusters, key=lambda c: cluster_means[c])
    sorted_means = [cluster_means[c] for c in sorted_clusters]

    # 2-3. Flip when most cluster means sit on the high side of the range
    midpoint = (sorted_means[0] + sorted_means[-1]) / 2
    if np.median(sorted_means) > midpoint:
        sorted_clusters.reverse()

    # 4-5. Vectorized remapping. Masks are always computed against the
    # original labels, so an entry that was already reassigned is never
    # remapped a second time. Negative labels are untouched.
    reordered = clusters.copy()
    for new_id, old_id in enumerate(sorted_clusters):
        reordered[clusters == old_id] = new_id
    return reordered
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def find_optimal_components(features, threshold=0.95):
    """
    Choose how many PCA components are needed to reach ``threshold`` explained variance.

    A full PCA is fitted on ``features``. The function returns the smallest
    number of leading components whose cumulative explained-variance ratio is
    at least ``threshold``.

    Args:
        features: Feature matrix (n_samples, n_features)
        threshold: Target cumulative explained-variance ratio (default 0.95)

    Returns:
        int: Number of components, capped at ``len(features) - 1`` and never
        less than 1. If the threshold is never reached, every fitted component
        is counted (before capping).
    """
    # Lazy import: sklearn is slow to load (~600ms), defer until needed
    from sklearn.decomposition import PCA  # noqa: PLC0415

    pca = PCA()
    pca.fit(features)
    # Cumulative explained-variance ratio of the leading components
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
    reached = cumulative_variance >= threshold
    if reached.any():
        optimal_n = int(np.argmax(reached)) + 1
    else:
        # np.argmax of an all-False mask is 0, which would silently pick a
        # single component (e.g. threshold=1.0 with floating-point rounding).
        # Keep every component instead.
        optimal_n = len(cumulative_variance)
    return max(1, min(optimal_n, len(features) - 1))
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def process_edges_batch(batch_indices, all_indices, h, use_umap_embs, pca=None):
    """
    Build weighted kNN edges for a batch of nodes.

    For every node ``i`` in ``batch_indices``, one edge ``(i, j)`` is emitted
    per neighbour ``j`` in ``all_indices[i]``. Self loops (``j == i``) are
    skipped.

    Edge weights:
        - ``use_umap_embs=True``: ``exp(-||h[i] - h[j]||)``.
        - otherwise: each difference component is scaled by
          ``sqrt(pca.explained_variance_ratio_)``. The resulting distances are
          divided by their mean over node ``i``'s neighbours, and the weight is
          ``exp(-d)``. If all of node ``i``'s neighbour distances are zero
          (duplicate points), the distances are used as-is, giving weight 1.

    Args:
        batch_indices: Node indices to process.
        all_indices: Neighbour index lists; ``all_indices[i]`` holds the
            neighbours of node ``i``.
        h: Node embedding matrix; row ``i`` belongs to node ``i``.
        use_umap_embs: Use plain Euclidean distances instead of PCA weighting.
        pca: Fitted PCA object exposing ``explained_variance_ratio_``;
            required when ``use_umap_embs`` is False.

    Returns:
        (edges, weights): list of ``(i, j)`` tuples and the matching list of
        float weights.
    """
    h = np.asarray(h)
    edges = []
    weights = []

    component_scale = None
    if not use_umap_embs:
        # The per-component scaling is the same for every node: compute once.
        ratio = np.asarray(pca.explained_variance_ratio_)[: h.shape[1]]
        component_scale = np.sqrt(ratio)

    for i in batch_indices:
        neighbours = np.asarray(all_indices[i])
        neighbours = neighbours[neighbours != i]  # skip self loops
        if neighbours.size == 0:
            continue

        diffs = h[i] - h[neighbours]
        if component_scale is None:
            distances = np.linalg.norm(diffs, axis=1)
        else:
            distances = np.linalg.norm(diffs * component_scale, axis=1)
            # Normalise by the mean distance to this node's neighbours so the
            # weights are scale-free. Dividing a single scalar distance by its
            # own mean would always give exp(-1), or NaN for duplicate points.
            mean_distance = distances.mean()
            if mean_distance > 0:
                distances = distances / mean_distance

        edges.extend((i, int(j)) for j in neighbours)
        weights.extend(np.exp(-distances).tolist())

    return edges, weights
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def leiden_cluster(
    features: np.ndarray,
    resolution: float = 1.0,
    n_jobs: int = -1,
    on_progress: Callable[[str], None] | None = None,
    seed: int | None = None,
) -> np.ndarray:
    """
    Perform Leiden clustering on feature embeddings.

    Pipeline:
        1. PCA, with the component count chosen by find_optimal_components.
        2. kNN with k = floor(sqrt(n_samples)).
        3. A weighted undirected graph built in parallel by process_edges_batch.
        4. A Leiden partition using RBConfigurationVertexPartition.

    Args:
        features: Feature matrix (n_samples, n_features)
        resolution: Leiden clustering resolution parameter
        n_jobs: Number of parallel jobs for edge construction; any negative
            value means "all CPUs"
        on_progress: Optional callback for progress updates, receives message string
        seed: Optional random seed forwarded to ``leidenalg.find_partition``.
            Set it to make the partition reproducible. ``None`` keeps
            leidenalg's default behaviour.

    Returns:
        np.ndarray: Cluster labels (community index) for each sample

    Raises:
        ValueError: If ``n_jobs`` is 0.
    """
    # Lazy import: sklearn/igraph/networkx are slow to load, defer until needed
    import igraph as ig  # noqa: PLC0415
    import leidenalg as la  # noqa: PLC0415
    import networkx as nx  # noqa: PLC0415
    from joblib import Parallel, delayed  # noqa: PLC0415
    from sklearn.decomposition import PCA  # noqa: PLC0415
    from sklearn.neighbors import NearestNeighbors  # noqa: PLC0415

    if n_jobs == 0:
        # Would otherwise surface as a ZeroDivisionError when sizing batches.
        raise ValueError("n_jobs must be a positive integer or negative (all CPUs), got 0")
    if n_jobs < 0:
        n_jobs = multiprocessing.cpu_count()
    n_samples = features.shape[0]

    def _progress(msg: str):
        if on_progress:
            on_progress(msg)

    # 1. PCA
    _progress("Processing PCA")
    n_components = find_optimal_components(features)
    pca = PCA(n_components)
    target_features = pca.fit_transform(features)

    # 2. KNN. Querying with the fitted data, each neighbour list typically
    #    contains the sample itself; process_edges_batch drops self loops.
    _progress("Processing KNN")
    k = int(np.sqrt(len(target_features)))
    nn = NearestNeighbors(n_neighbors=k).fit(target_features)
    indices = nn.kneighbors(target_features, return_distance=False)

    # 3. Build graph. nx.Graph is undirected: when both (i, j) and (j, i) are
    #    produced, the later add_edge call overwrites the weight.
    _progress("Building graph")
    G = nx.Graph()
    G.add_nodes_from(range(n_samples))

    batch_size = max(1, n_samples // n_jobs)
    batches = [list(range(i, min(i + batch_size, n_samples))) for i in range(0, n_samples, batch_size)]
    results = Parallel(n_jobs=n_jobs)(
        [delayed(process_edges_batch)(batch, indices, target_features, False, pca) for batch in batches]
    )

    for batch_edges, batch_weights in results:
        for (i, j), weight in zip(batch_edges, batch_weights):
            G.add_edge(i, j, weight=weight)

    # 4. Leiden clustering
    _progress("Leiden clustering")
    edges = list(G.edges())
    weights = [G[u][v]["weight"] for u, v in edges]
    ig_graph = ig.Graph(n=n_samples, edges=edges, edge_attrs={"weight": weights})

    partition = la.find_partition(
        ig_graph,
        la.RBConfigurationVertexPartition,
        weights="weight",
        resolution_parameter=resolution,
        seed=seed,
    )

    # 5. Finalize: label each node with the index of its community
    _progress("Finalizing")
    clusters = np.full(n_samples, -1)
    for i, community in enumerate(partition):
        for node in community:
            clusters[node] = i

    return clusters
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HDF5 path utilities for consistent namespace and filter handling
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import h5py
|
|
8
|
+
|
|
9
|
+
# Reserved namespace names that cannot be used
# These conflict with existing HDF5 structure (e.g., model/features, model/metadata)
RESERVED_NAMESPACES = frozenset({"features", "metadata", "latent_features"})


def validate_namespace(namespace: str) -> bool:
    """
    Validate namespace string.

    A namespace becomes a single HDF5 path component
    (``<model>/<namespace>/...``). It is valid when it is non-empty, does not
    contain the path separator "/", and is not a reserved name.

    Args:
        namespace: Namespace to validate

    Returns:
        True if valid, False if invalid
    """
    if not namespace:
        return False

    # A "/" would split the namespace into nested groups, and
    # parse_cluster_path() reads the namespace from a single path segment.
    if "/" in namespace:
        return False

    return namespace not in RESERVED_NAMESPACES
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def normalize_filename(path: str) -> str:
    """
    Turn a file path into a token usable inside an auto-generated namespace.

    Only the stem (file name without directories and without its last suffix)
    is kept. "+" is rewritten because it separates file names within a
    namespace, and "/" because it is the HDF5 path separator.

    Args:
        path: File path

    Returns:
        Stem with every "+" and "/" replaced by "_"
    """
    stem = Path(path).stem
    substitutions = str.maketrans({"+": "_", "/": "_"})
    return stem.translate(substitutions)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def build_namespace(input_paths: list[str]) -> str:
    """
    Derive the namespace under which results for ``input_paths`` are stored.

    No validation is done here, because auto-generated multi-file namespaces
    always contain "+". Validation happens in build_cluster_path(), which
    assembles the final HDF5 path.

    Args:
        input_paths: List of HDF5 file paths

    Returns:
        Namespace string
        - Single file: "default"
        - Multiple files: "file1+file2+..." (normalized stems, sorted)
    """
    if len(input_paths) != 1:
        stems = sorted(map(normalize_filename, input_paths))
        return "+".join(stems)
    return "default"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def build_cluster_path(
    model_name: str,
    namespace: str = "default",
    filters: list[list[int]] | None = None,
    dataset: str = "clusters",
) -> str:
    """
    Build HDF5 path for clustering data

    Args:
        model_name: Model name (e.g., "uni", "gigapath")
        namespace: Namespace (e.g., "default", "001+002")
        filters: Nested list of cluster filters, e.g., [[1,2,3], [0,1]].
            Each inner list is sorted and joined with "+".
        dataset: Dataset name ("clusters", "umap", "pca1", "pca2", "pca3")

    Returns:
        Full HDF5 path

    Raises:
        ValueError: If namespace is invalid or reserved, or if any filter in
            ``filters`` is empty. An empty filter would produce a "filter//"
            segment that parse_cluster_path() cannot parse.

    Examples:
        >>> build_cluster_path("uni", "default")
        'uni/default/clusters'

        >>> build_cluster_path("uni", "default", [[1,2,3]])
        'uni/default/filter/1+2+3/clusters'

        >>> build_cluster_path("uni", "default", [[1,2,3], [0,1]])
        'uni/default/filter/1+2+3/filter/0+1/clusters'

        >>> build_cluster_path("uni", "001+002", [[5]])
        'uni/001+002/filter/5/clusters'
    """
    # Validate namespace
    if not validate_namespace(namespace):
        raise ValueError(f"Invalid namespace '{namespace}'. Reserved names: {', '.join(sorted(RESERVED_NAMESPACES))}")

    segments = [model_name, namespace]
    for filter_ids in filters or []:
        # len() instead of truthiness so numpy arrays are accepted too
        if len(filter_ids) == 0:
            raise ValueError("Empty cluster filter: each filter must contain at least one cluster ID")
        segments += ["filter", "+".join(map(str, sorted(filter_ids)))]
    segments.append(dataset)

    return "/".join(segments)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def parse_cluster_path(path: str) -> dict:
    """
    Parse cluster path into components

    Args:
        path: HDF5 path (e.g., "uni/default/filter/1+2+3/clusters").
            Leading and trailing "/" are ignored, so absolute names such as
            "/uni/default/clusters" are accepted.

    Returns:
        Dict with keys: model_name, namespace, filters, dataset

    Raises:
        ValueError: If the path has fewer than three components
            (model/namespace/dataset), or a "filter" segment is not followed
            by "+"-joined integers.

    Examples:
        >>> parse_cluster_path("uni/default/clusters")
        {'model_name': 'uni', 'namespace': 'default', 'filters': [], 'dataset': 'clusters'}

        >>> parse_cluster_path("uni/default/filter/1+2+3/clusters")
        {'model_name': 'uni', 'namespace': 'default', 'filters': [[1, 2, 3]], 'dataset': 'clusters'}
    """
    parts = path.strip("/").split("/")
    if len(parts) < 3:
        raise ValueError(f"Invalid cluster path '{path}': expected '<model>/<namespace>/.../<dataset>'")

    result = {"model_name": parts[0], "namespace": parts[1], "filters": [], "dataset": parts[-1]}

    # Parse the filter hierarchy between the namespace and the dataset name
    i = 2
    while i < len(parts) - 1:
        if parts[i] == "filter":
            filter_ids = [int(x) for x in parts[i + 1].split("+")]
            result["filters"].append(filter_ids)
            i += 2
        else:
            i += 1

    return result
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def list_namespaces(h5_file, model_name: str) -> list[str]:
    """
    Collect the namespaces of ``model_name`` that hold clustering results.

    Only direct children of ``h5_file[model_name]`` are considered. A child is
    reported when it is an h5py.Group that contains a "clusters" member;
    everything else is ignored.

    Args:
        h5_file: h5py.File object (opened)
        model_name: Model name

    Returns:
        List of namespace strings in the group's iteration order; empty when
        ``model_name`` is not present in the file.
    """
    if model_name not in h5_file:
        return []

    model_group = h5_file[model_name]
    found = []
    for name in model_group:
        member = model_group[name]
        if isinstance(member, h5py.Group) and "clusters" in member:
            found.append(name)
    return found
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def list_filters(h5_file, model_name: str, namespace: str) -> list[str]:
    """
    List all filter paths under a namespace

    Recursively walks ``<model_name>/<namespace>/filter`` and reports every
    group that directly contains a "clusters" member.

    Args:
        h5_file: h5py.File object (opened)
        model_name: Model name
        namespace: Namespace

    Returns:
        List of filter strings relative to the namespace's "filter" group.
        Top-level filters look like "1+2+3" or "5". Nested filters are joined
        with "/", e.g. "1+2+3/0+1" for ``<namespace>/filter/1+2+3/filter/0+1``.
        Empty when the "filter" group does not exist.
    """
    base_path = f"{model_name}/{namespace}/filter"
    if base_path not in h5_file:
        return []

    filters = []

    def collect(name, obj):
        # visititems() passes ``name`` relative to the visited group
        # (h5_file[base_path]), e.g. "1+2+3/filter/0+1". Collapsing the
        # intermediate "/filter/" levels leaves only the cluster-ID segments.
        # Returning None keeps the traversal going.
        if isinstance(obj, h5py.Group) and "clusters" in obj:
            filters.append(name.replace("/filter/", "/"))

    h5_file[base_path].visititems(collect)

    return filters
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def ensure_groups(h5file: h5py.File, path: str) -> None:
    """
    Create any missing parent groups of ``path``.

    Every prefix of ``path`` except the last component (the dataset name) is
    checked from the shallowest to the deepest. Any prefix not present in
    ``h5file`` is created with ``create_group``.

    Args:
        h5file: Open h5py.File object
        path: Full path to dataset (e.g., "model/namespace/clusters")

    Example:
        >>> with h5py.File("data.h5", "a") as f:
        ...     ensure_groups(f, "uni/default/filter/1+2/clusters")
        ...     f.create_dataset("uni/default/filter/1+2/clusters", data=clusters)
    """
    *parent_segments, _dataset_name = path.split("/")

    prefix = ""
    for segment in parent_segments:
        prefix = prefix + "/" + segment if prefix else segment
        if prefix in h5file:
            continue
        h5file.create_group(prefix)
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plotting utilities for 2D scatter plots and 1D violin plots
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from matplotlib import pyplot as plt
|
|
7
|
+
|
|
8
|
+
from ..common import _get_cluster_color
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def plot_scatter_2d(
    coords_list: list[np.ndarray],
    clusters_list: list[np.ndarray],
    filenames: list[str],
    title: str = "2D Projection",
    figsize: tuple = (12, 8),
    xlabel: str = "Dimension 1",
    ylabel: str = "Dimension 2",
):
    """
    Plot 2D scatter plot from single or multiple files

    Points are coloured by cluster. Noise (cluster -1) is always drawn in
    black.
    - Single file: one legend entry per cluster, with noise listed as "Noise".
    - Multiple files: each file gets its own marker shape, and two legends are
      drawn (cluster colours and file markers). Noise is left out of the
      cluster legend.
    Each non-negative cluster is labelled with its ID at the centroid of its
    points across all files.

    Args:
        coords_list: List of coordinate arrays (one per file); columns 0 and 1
            are used as x and y
        clusters_list: List of cluster arrays (one per file), aligned with
            ``coords_list``
        filenames: List of file names for the marker legend (multi-file only).
            Every file's points are plotted even if this list is shorter.
        title: Plot title
        figsize: Figure size
        xlabel: X-axis label
        ylabel: Y-axis label

    Returns:
        matplotlib Figure
    """
    all_coords = np.concatenate(coords_list)
    all_clusters = np.concatenate(clusters_list)

    # Get all unique clusters (same namespace = same clusters)
    cluster_ids = sorted(np.unique(all_clusters))
    cluster_to_color = {cluster_id: _get_cluster_color(cluster_id) for cluster_id in cluster_ids}

    fig, ax = plt.subplots(figsize=figsize)

    if len(coords_list) == 1:
        _scatter_single_file(ax, coords_list[0], clusters_list[0], cluster_ids, cluster_to_color)
    else:
        _scatter_multi_file(ax, coords_list, clusters_list, filenames, cluster_ids, cluster_to_color)

    _label_cluster_centroids(ax, all_coords, all_clusters, cluster_ids)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # fig.tight_layout() rather than plt.tight_layout(): the latter acts on
    # pyplot's *current* figure, which may not be this one.
    fig.tight_layout()

    return fig


def _scatter_single_file(ax, coords, clusters, cluster_ids, cluster_to_color):
    """Draw one file's points coloured by cluster and add a per-cluster legend."""
    for cluster_id in cluster_ids:
        mask = clusters == cluster_id
        if not mask.any():
            continue
        if cluster_id == -1:
            color, label, size = "black", "Noise", 12
        else:
            color, label, size = cluster_to_color[cluster_id], f"Cluster {cluster_id}", 7
        ax.scatter(coords[mask, 0], coords[mask, 1], s=size, c=[color], label=label, alpha=0.8)

    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")


def _scatter_multi_file(ax, coords_list, clusters_list, filenames, cluster_ids, cluster_to_color):
    """Draw several files (one marker shape per file) and add cluster and file legends."""
    markers = ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h"]

    # Legend handles for cluster colours (noise excluded)
    cluster_handles = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=cluster_to_color[cluster_id],
            markersize=8,
            label=f"Cluster {cluster_id}",
        )
        for cluster_id in cluster_ids
        if cluster_id >= 0
    ]

    # Legend handles for file markers
    file_handles = [
        plt.Line2D(
            [0], [0], marker=markers[i % len(markers)], color="w", markerfacecolor="gray", markersize=8, label=filename
        )
        for i, filename in enumerate(filenames)
    ]

    # Cluster-first drawing order, then one marker shape per file. Iterate over
    # the data only (not filenames) so no file's points are silently dropped.
    for cluster_id in cluster_ids:
        # Noise is drawn in black, matching the single-file view
        color = "black" if cluster_id == -1 else cluster_to_color[cluster_id]
        for i, (coords, clusters) in enumerate(zip(coords_list, clusters_list)):
            mask = clusters == cluster_id
            if not mask.any():
                continue
            ax.scatter(
                coords[mask, 0],
                coords[mask, 1],
                marker=markers[i % len(markers)],
                c=[color],
                s=10,
                alpha=0.6,
            )

    # The first legend must be re-added as an artist, otherwise the second
    # ax.legend() call replaces it.
    cluster_legend = ax.legend(handles=cluster_handles, title="Clusters", loc="upper left", bbox_to_anchor=(1.02, 1))
    ax.add_artist(cluster_legend)
    ax.legend(handles=file_handles, title="Sources", loc="upper left", bbox_to_anchor=(1.02, 0.5))


def _label_cluster_centroids(ax, coords, clusters, cluster_ids):
    """Write each non-negative cluster ID at the mean position of its points."""
    for cluster_id in cluster_ids:
        if cluster_id < 0:  # Skip noise cluster
            continue
        cluster_points = coords[clusters == cluster_id]
        if len(cluster_points) < 1:
            continue
        ax.text(
            np.mean(cluster_points[:, 0]),
            np.mean(cluster_points[:, 1]),
            str(cluster_id),
            fontsize=12,
            fontweight="bold",
            ha="center",
            va="center",
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def plot_violin_1d(
    values_list: list[np.ndarray],
    clusters_list: list[np.ndarray],
    title: str = "Distribution by Cluster",
    ylabel: str = "Value",
    figsize: tuple = (12, 8),
):
    """
    Plot 1D violin plot with cluster distribution

    The first violin ("All", gray) shows every value. It is followed by one
    violin per non-negative cluster in ascending ID order; noise (-1) is only
    included in "All". Jittered points are drawn on top of each violin.

    Args:
        values_list: List of 1D value arrays (one per file)
        clusters_list: List of cluster arrays (one per file), aligned with
            ``values_list``
        title: Plot title
        ylabel: Y-axis label
        figsize: Figure size

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If the inputs contain no values at all.
    """
    # Combine all data
    all_values = np.concatenate(values_list)
    all_clusters = np.concatenate(clusters_list)
    if all_values.size == 0:
        raise ValueError("No data for specified clusters")

    # Show all clusters except noise (-1). Every ID here comes from
    # all_clusters, so each one has at least one value.
    cluster_ids = sorted(c for c in np.unique(all_clusters) if c >= 0)

    # "All" first (gray), then one entry per cluster
    data = [all_values] + [all_values[all_clusters == cluster_id] for cluster_id in cluster_ids]
    labels = ["All"] + [f"Cluster {cluster_id}" for cluster_id in cluster_ids]
    palette = ["gray"] + [_get_cluster_color(cluster_id) for cluster_id in cluster_ids]

    # Lazy import: seaborn is slow to load (~500ms), defer until needed
    import seaborn as sns  # noqa: PLC0415

    # Scoped style: sns.set_style() would change the style of every later
    # plot in the process.
    with sns.axes_style("whitegrid"):
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

        sns.violinplot(data=data, ax=ax, inner="box", cut=0, zorder=1, alpha=0.5, palette=palette)

        # Jittered points. A local, fixed-seed generator keeps the figure
        # reproducible and leaves the global NumPy RNG state untouched.
        rng = np.random.default_rng(0)
        for i, (values, color) in enumerate(zip(data, palette)):
            x = rng.normal(i, 0.05, size=len(values))
            ax.scatter(x, values, alpha=0.8, s=5, color=color, zorder=2)

        ax.set_xticks(np.arange(0, len(labels)))
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(axis="y", linestyle="--", alpha=0.7)
        fig.tight_layout()

    return fig
|