wsi-toolbox 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,174 @@
+ import multiprocessing
+ from typing import Callable
+
+ import numpy as np
+
+
+ def reorder_clusters_by_pca(clusters: np.ndarray, pca_values: np.ndarray) -> np.ndarray:
+     """
+     Reorder cluster IDs based on PCA distribution for consistent visualization.
+
+     The goal is to ensure that when clusters are plotted in a violin plot (left to right),
+     the distribution rises gradually from the left and steeply on the right.
+
+     Algorithm:
+     1. Sort clusters by their mean PCA1 value
+     2. Check whether the median of the sorted means is below or above the midpoint
+     3. If median > midpoint, flip the order so lower cluster IDs have lower PCA values
+
+     This ensures consistent ordering regardless of PCA sign ambiguity.
+
+     Args:
+         clusters: Cluster labels array [N]
+         pca_values: PCA1 values array [N] (first principal component)
+
+     Returns:
+         Reordered cluster labels with the same shape
+     """
+     unique_clusters = [c for c in np.unique(clusters) if c >= 0]
+     if len(unique_clusters) <= 1:
+         return clusters
+
+     # 1. Compute mean PCA1 for each cluster
+     cluster_means = {}
+     for c in unique_clusters:
+         cluster_means[c] = np.mean(pca_values[clusters == c])
+
+     # 2. Sort clusters by mean PCA1
+     sorted_clusters = sorted(unique_clusters, key=lambda c: cluster_means[c])
+     sorted_means = [cluster_means[c] for c in sorted_clusters]
+
+     # 3. Check distribution: flip if the median is on the higher side
+     midpoint = (sorted_means[0] + sorted_means[-1]) / 2
+     median_mean = np.median(sorted_means)
+
+     if median_mean > midpoint:
+         sorted_clusters = sorted_clusters[::-1]
+
+     # 4. Build remapping
+     old_to_new = {old: new for new, old in enumerate(sorted_clusters)}
+
+     # 5. Apply remapping (preserve -1 for filtered)
+     return np.array([old_to_new.get(c, c) for c in clusters])
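
As a quick illustration of the reordering rule described in the docstring (toy arrays, not taken from the package):

    clusters = np.array([0, 0, 1, 1, 2, 2, -1])
    pca1 = np.array([0.9, 1.1, -0.2, 0.0, 0.4, 0.5, 0.3])

    # Cluster means on PCA1: 0 -> 1.0, 1 -> -0.1, 2 -> 0.45
    # New IDs follow PCA1 from low to high; -1 (filtered) is left untouched
    print(reorder_clusters_by_pca(clusters, pca1))  # -> [2 2 0 0 1 1 -1]
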
+
+
+ def find_optimal_components(features, threshold=0.95):
+     """Pick the smallest number of PCA components whose cumulative explained variance reaches threshold (capped at n_samples - 1)."""
+     # Lazy import: sklearn is slow to load (~600ms), defer until needed
+     from sklearn.decomposition import PCA  # noqa: PLC0415
+
+     pca = PCA()
+     pca.fit(features)
+     explained_variance = pca.explained_variance_ratio_
+     # Select the number of dimensions at which the cumulative explained variance ratio exceeds the threshold (e.g. 95%)
+     cumulative_variance = np.cumsum(explained_variance)
+     optimal_n = np.argmax(cumulative_variance >= threshold) + 1
+     return min(optimal_n, len(features) - 1)
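
A worked example of the cumulative-variance rule (the variance ratios below are made up):

    explained = np.array([0.60, 0.25, 0.08, 0.05, 0.02])
    cumulative = np.cumsum(explained)        # [0.60, 0.85, 0.93, 0.98, 1.00]
    n = np.argmax(cumulative >= 0.95) + 1    # first index reaching 0.95 -> 4 components
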
+
+
+ def process_edges_batch(batch_indices, all_indices, h, use_umap_embs, pca=None):
+     """Process a batch of nodes and their edges"""
+     edges = []
+     weights = []
+
+     for i in batch_indices:
+         for j in all_indices[i]:
+             if i == j:  # skip self loop
+                 continue
+
+             if use_umap_embs:
+                 distance = np.linalg.norm(h[i] - h[j])
+                 weight = np.exp(-distance)
+             else:
+                 # Weight each PCA dimension by the square root of its explained variance ratio
+                 explained_variance_ratio = pca.explained_variance_ratio_
+                 weighted_diff = (h[i] - h[j]) * np.sqrt(explained_variance_ratio[: len(h[i])])
+                 distance = np.linalg.norm(weighted_diff)
+                 weight = np.exp(-distance / distance.mean())
+
+             edges.append((i, j))
+             weights.append(weight)
+
+     return edges, weights
+
+
+ def leiden_cluster(
+     features: np.ndarray,
+     resolution: float = 1.0,
+     n_jobs: int = -1,
+     on_progress: Callable[[str], None] | None = None,
+ ) -> np.ndarray:
+     """
+     Perform Leiden clustering on feature embeddings.
+
+     Args:
+         features: Feature matrix (n_samples, n_features)
+         resolution: Leiden clustering resolution parameter
+         n_jobs: Number of parallel jobs (-1 = all CPUs)
+         on_progress: Optional callback for progress updates, receives message string
+
+     Returns:
+         np.ndarray: Cluster labels for each sample
+     """
+     # Lazy import: sklearn/igraph/networkx are slow to load, defer until needed
+     import igraph as ig  # noqa: PLC0415
+     import leidenalg as la  # noqa: PLC0415
+     import networkx as nx  # noqa: PLC0415
+     from joblib import Parallel, delayed  # noqa: PLC0415
+     from sklearn.decomposition import PCA  # noqa: PLC0415
+     from sklearn.neighbors import NearestNeighbors  # noqa: PLC0415
+
+     if n_jobs < 0:
+         n_jobs = multiprocessing.cpu_count()
+     n_samples = features.shape[0]
+
+     def _progress(msg: str):
+         if on_progress:
+             on_progress(msg)
+
+     # 1. PCA
+     _progress("Processing PCA")
+     n_components = find_optimal_components(features)
+     pca = PCA(n_components)
+     target_features = pca.fit_transform(features)
+
+     # 2. KNN
+     _progress("Processing KNN")
+     k = int(np.sqrt(len(target_features)))
+     nn = NearestNeighbors(n_neighbors=k).fit(target_features)
+     distances, indices = nn.kneighbors(target_features)
+
+     # 3. Build graph
+     _progress("Building graph")
+     G = nx.Graph()
+     G.add_nodes_from(range(n_samples))
+
+     batch_size = max(1, n_samples // n_jobs)
+     batches = [list(range(i, min(i + batch_size, n_samples))) for i in range(0, n_samples, batch_size)]
+     results = Parallel(n_jobs=n_jobs)(
+         [delayed(process_edges_batch)(batch, indices, target_features, False, pca) for batch in batches]
+     )
+
+     for batch_edges, batch_weights in results:
+         for (i, j), weight in zip(batch_edges, batch_weights):
+             G.add_edge(i, j, weight=weight)
+
+     # 4. Leiden clustering
+     _progress("Leiden clustering")
+     edges = list(G.edges())
+     weights = [G[u][v]["weight"] for u, v in edges]
+     ig_graph = ig.Graph(n=n_samples, edges=edges, edge_attrs={"weight": weights})
+
+     partition = la.find_partition(
+         ig_graph,
+         la.RBConfigurationVertexPartition,
+         weights="weight",
+         resolution_parameter=resolution,
+     )
+
+     # 5. Finalize
+     _progress("Finalizing")
+     clusters = np.full(n_samples, -1)
+     for i, community in enumerate(partition):
+         for node in community:
+             clusters[node] = i
+
+     return clusters
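
A minimal usage sketch for this module (the feature matrix is random data with made-up shapes; only leiden_cluster and reorder_clusters_by_pca come from the file above):

    import numpy as np
    from sklearn.decomposition import PCA

    features = np.random.rand(1000, 768).astype(np.float32)  # e.g. 1,000 patch embeddings
    clusters = leiden_cluster(features, resolution=1.0, on_progress=print)

    # Optional: make cluster IDs follow the PCA1 axis so plots are ordered consistently
    pca1 = PCA(n_components=1).fit_transform(features)[:, 0]
    clusters = reorder_clusters_by_pca(clusters, pca1)
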
@@ -0,0 +1,232 @@
+ """
+ HDF5 path utilities for consistent namespace and filter handling
+ """
+
+ from pathlib import Path
+
+ import h5py
+
+ # Reserved namespace names that cannot be used
+ # These conflict with existing HDF5 structure (e.g., model/features, model/metadata)
+ RESERVED_NAMESPACES = frozenset({"features", "metadata", "latent_features"})
+
+
+ def validate_namespace(namespace: str) -> bool:
+     """
+     Validate namespace string
+
+     Args:
+         namespace: Namespace to validate
+
+     Returns:
+         True if valid, False if invalid
+     """
+     if not namespace:
+         return False
+
+     if namespace in RESERVED_NAMESPACES:
+         return False
+
+     return True
+
+
+ def normalize_filename(path: str) -> str:
+     """
+     Normalize filename for use in namespace
+
+     Args:
+         path: File path
+
+     Returns:
+         Normalized name (stem only, forbidden chars replaced)
+     """
+     name = Path(path).stem
+     # Replace forbidden characters
+     name = name.replace("+", "_")  # + is reserved for separator
+     name = name.replace("/", "_")  # path separator
+     return name
+
+
+ def build_namespace(input_paths: list[str]) -> str:
+     """
+     Build namespace from input file paths
+
+     Note: No validation here - auto-generated namespaces always contain "+".
+     Validation happens in build_cluster_path(), which performs the final path assembly.
+
+     Args:
+         input_paths: List of HDF5 file paths
+
+     Returns:
+         Namespace string
+         - Single file: "default"
+         - Multiple files: "file1+file2+..." (sorted, normalized)
+     """
+     if len(input_paths) == 1:
+         return "default"
+
+     # Normalize and sort filenames
+     names = sorted([normalize_filename(p) for p in input_paths])
+     return "+".join(names)
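
For example, with hypothetical input files:

    build_namespace(["/data/001.h5"])                    # -> "default"
    build_namespace(["/data/002.h5", "/data/001+x.h5"])  # -> "001_x+002" (normalized and sorted)
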
+
+
+ def build_cluster_path(
+     model_name: str,
+     namespace: str = "default",
+     filters: list[list[int]] | None = None,
+     dataset: str = "clusters",
+ ) -> str:
+     """
+     Build HDF5 path for clustering data
+
+     Args:
+         model_name: Model name (e.g., "uni", "gigapath")
+         namespace: Namespace (e.g., "default", "001+002")
+         filters: Nested list of cluster filters, e.g., [[1,2,3], [0,1]]
+         dataset: Dataset name ("clusters", "umap", "pca1", "pca2", "pca3")
+
+     Returns:
+         Full HDF5 path
+
+     Raises:
+         ValueError: If namespace is invalid or reserved
+
+     Examples:
+         >>> build_cluster_path("uni", "default")
+         'uni/default/clusters'
+
+         >>> build_cluster_path("uni", "default", [[1,2,3]])
+         'uni/default/filter/1+2+3/clusters'
+
+         >>> build_cluster_path("uni", "default", [[1,2,3], [0,1]])
+         'uni/default/filter/1+2+3/filter/0+1/clusters'
+
+         >>> build_cluster_path("uni", "001+002", [[5]])
+         'uni/001+002/filter/5/clusters'
+     """
+     # Validate namespace
+     if not validate_namespace(namespace):
+         raise ValueError(f"Invalid namespace '{namespace}'. Reserved names: {', '.join(sorted(RESERVED_NAMESPACES))}")
+
+     path = f"{model_name}/{namespace}"
+
+     if filters:
+         for filter_ids in filters:
+             filter_str = "+".join(map(str, sorted(filter_ids)))
+             path += f"/filter/{filter_str}"
+
+     path += f"/{dataset}"
+     return path
+
+
+ def parse_cluster_path(path: str) -> dict:
+     """
+     Parse cluster path into components
+
+     Args:
+         path: HDF5 path (e.g., "uni/default/filter/1+2+3/clusters")
+
+     Returns:
+         Dict with keys: model_name, namespace, filters, dataset
+
+     Examples:
+         >>> parse_cluster_path("uni/default/clusters")
+         {'model_name': 'uni', 'namespace': 'default', 'filters': [], 'dataset': 'clusters'}
+
+         >>> parse_cluster_path("uni/default/filter/1+2+3/clusters")
+         {'model_name': 'uni', 'namespace': 'default', 'filters': [[1,2,3]], 'dataset': 'clusters'}
+     """
+     parts = path.split("/")
+
+     result = {"model_name": parts[0], "namespace": parts[1], "filters": [], "dataset": parts[-1]}
+
+     # Parse filter hierarchy
+     i = 2
+     while i < len(parts) - 1:
+         if parts[i] == "filter":
+             filter_str = parts[i + 1]
+             filter_ids = [int(x) for x in filter_str.split("+")]
+             result["filters"].append(filter_ids)
+             i += 2
+         else:
+             i += 1
+
+     return result
+
+
+ def list_namespaces(h5_file, model_name: str) -> list[str]:
+     """
+     List all namespaces in HDF5 file for given model
+
+     Args:
+         h5_file: h5py.File object (opened)
+         model_name: Model name
+
+     Returns:
+         List of namespace strings
+     """
+     if model_name not in h5_file:
+         return []
+
+     namespaces = []
+     for key in h5_file[model_name].keys():
+         if isinstance(h5_file[f"{model_name}/{key}"], h5py.Group):
+             # Check if it contains a 'clusters' dataset
+             if "clusters" in h5_file[f"{model_name}/{key}"]:
+                 namespaces.append(key)
+
+     return namespaces
+
+
+ def list_filters(h5_file, model_name: str, namespace: str) -> list[str]:
+     """
+     List all filter paths under a namespace
+
+     Args:
+         h5_file: h5py.File object (opened)
+         model_name: Model name
+         namespace: Namespace
+
+     Returns:
+         List of filter strings (e.g., ["1+2+3", "5"])
+     """
+     base_path = f"{model_name}/{namespace}/filter"
+     if base_path not in h5_file:
+         return []
+
+     filters = []
+
+     def visit_filters(name, obj):
+         if isinstance(obj, h5py.Group) and "clusters" in obj:
+             # Extract filter string from full path
+             rel_path = name.replace(base_path + "/", "")
+             # Remove '/filter/' segments to get just the IDs
+             filter_str = rel_path.replace("/filter/", "/")
+             filters.append(filter_str)
+
+     h5_file[base_path].visititems(visit_filters)
+
+     return filters
+
+
+ def ensure_groups(h5file: h5py.File, path: str) -> None:
+     """
+     Ensure all parent groups exist for a given path.
+
+     Args:
+         h5file: Open h5py.File object
+         path: Full path to dataset (e.g., "model/namespace/clusters")
+
+     Example:
+         >>> with h5py.File("data.h5", "a") as f:
+         ...     ensure_groups(f, "uni/default/filter/1+2/clusters")
+         ...     f.create_dataset("uni/default/filter/1+2/clusters", data=clusters)
+     """
+     parts = path.split("/")
+     group_parts = parts[:-1]  # Exclude the dataset name
+
+     current = ""
+     for part in group_parts:
+         current = f"{current}/{part}" if current else part
+         if current not in h5file:
+             h5file.create_group(current)
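
A short sketch of how these helpers fit together (the file name and cluster labels are hypothetical):

    import h5py
    import numpy as np

    clusters = np.array([0, 0, 1, 2, 1])
    path = build_cluster_path("uni", "default")   # 'uni/default/clusters'

    with h5py.File("example.h5", "a") as f:
        ensure_groups(f, path)
        f.create_dataset(path, data=clusters)
        print(list_namespaces(f, "uni"))          # ['default']
        print(parse_cluster_path(path))           # {'model_name': 'uni', 'namespace': 'default', ...}
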
@@ -0,0 +1,227 @@
+ """
+ Plotting utilities for 2D scatter plots and 1D violin plots
+ """
+
+ import numpy as np
+ from matplotlib import pyplot as plt
+
+ from ..common import _get_cluster_color
+
+
+ def plot_scatter_2d(
+     coords_list: list[np.ndarray],
+     clusters_list: list[np.ndarray],
+     filenames: list[str],
+     title: str = "2D Projection",
+     figsize: tuple = (12, 8),
+     xlabel: str = "Dimension 1",
+     ylabel: str = "Dimension 2",
+ ):
+     """
+     Plot 2D scatter plot from single or multiple files
+
+     Unified plotting logic that works for both single and multiple files.
+
+     Args:
+         coords_list: List of coordinate arrays (one per file)
+         clusters_list: List of cluster arrays (one per file)
+         filenames: List of file names for legend
+         title: Plot title
+         figsize: Figure size
+         xlabel: X-axis label
+         ylabel: Y-axis label
+
+     Returns:
+         matplotlib Figure
+     """
+
+     markers = ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h"]
+
+     # Get all unique clusters (same namespace = same clusters)
+     all_unique_clusters = sorted(np.unique(np.concatenate(clusters_list)))
+     cluster_to_color = {cluster_id: _get_cluster_color(cluster_id) for cluster_id in all_unique_clusters}
+
+     fig, ax = plt.subplots(figsize=figsize)
+
+     # Single file: simpler legend (no file markers)
+     if len(coords_list) == 1:
+         for cluster_id in all_unique_clusters:
+             mask = clusters_list[0] == cluster_id
+             if np.sum(mask) > 0:
+                 if cluster_id == -1:
+                     color = "black"
+                     label = "Noise"
+                     size = 12
+                 else:
+                     color = cluster_to_color[cluster_id]
+                     label = f"Cluster {cluster_id}"
+                     size = 7
+                 ax.scatter(
+                     coords_list[0][mask, 0],
+                     coords_list[0][mask, 1],
+                     s=size,
+                     c=[color],
+                     label=label,
+                     alpha=0.8,
+                 )
+     else:
+         # Multiple files: show both cluster colors and file markers
+         # Create handles for cluster legend (colors)
+         cluster_handles = []
+         for cluster_id in all_unique_clusters:
+             if cluster_id < 0:  # Skip noise
+                 continue
+             handle = plt.Line2D(
+                 [0],
+                 [0],
+                 marker="o",
+                 color="w",
+                 markerfacecolor=cluster_to_color[cluster_id],
+                 markersize=8,
+                 label=f"Cluster {cluster_id}",
+             )
+             cluster_handles.append(handle)
+
+         # Create handles for file legend (markers)
+         file_handles = []
+         for i, filename in enumerate(filenames):
+             marker = markers[i % len(markers)]
+             handle = plt.Line2D(
+                 [0], [0], marker=marker, color="w", markerfacecolor="gray", markersize=8, label=filename
+             )
+             file_handles.append(handle)
+
+         # Plot all data: cluster-first, then file-specific markers
+         for cluster_id in all_unique_clusters:
+             for i, (coords, clusters, filename) in enumerate(zip(coords_list, clusters_list, filenames)):
+                 mask = clusters == cluster_id
+                 if np.sum(mask) > 0:  # Only plot if this file has patches in this cluster
+                     marker = markers[i % len(markers)]
+                     ax.scatter(
+                         coords[mask, 0],
+                         coords[mask, 1],
+                         marker=marker,
+                         c=[cluster_to_color[cluster_id]],
+                         s=10,
+                         alpha=0.6,
+                     )
+
+         # Add legends for multiple files
+         legend1 = ax.legend(handles=cluster_handles, title="Clusters", loc="upper left", bbox_to_anchor=(1.02, 1))
+         ax.add_artist(legend1)
+         ax.legend(handles=file_handles, title="Sources", loc="upper left", bbox_to_anchor=(1.02, 0.5))
+
+     # Draw cluster numbers at centroids
+     all_coords_combined = np.concatenate(coords_list)
+     all_clusters_combined = np.concatenate(clusters_list)
+     for cluster_id in all_unique_clusters:
+         if cluster_id < 0:  # Skip noise cluster
+             continue
+         cluster_points = all_coords_combined[all_clusters_combined == cluster_id]
+         if len(cluster_points) < 1:
+             continue
+         centroid_x = np.mean(cluster_points[:, 0])
+         centroid_y = np.mean(cluster_points[:, 1])
+         ax.text(
+             centroid_x,
+             centroid_y,
+             str(cluster_id),
+             fontsize=12,
+             fontweight="bold",
+             ha="center",
+             va="center",
+             bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
+         )
+
+     # Single file: show legend normally
+     if len(coords_list) == 1:
+         ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
+
+     ax.set_title(title)
+     ax.set_xlabel(xlabel)
+     ax.set_ylabel(ylabel)
+     plt.tight_layout()
+
+     return fig
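
A usage sketch with random toy data (the two source file names and the label arrays are hypothetical):

    import numpy as np

    coords_a, coords_b = np.random.rand(200, 2), np.random.rand(150, 2)
    clusters_a, clusters_b = np.random.randint(0, 4, 200), np.random.randint(0, 4, 150)

    fig = plot_scatter_2d(
        [coords_a, coords_b],
        [clusters_a, clusters_b],
        filenames=["001.h5", "002.h5"],
        title="UMAP projection",
        xlabel="UMAP1",
        ylabel="UMAP2",
    )
    fig.savefig("scatter.png", dpi=150)
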
+
+
+ def plot_violin_1d(
+     values_list: list[np.ndarray],
+     clusters_list: list[np.ndarray],
+     title: str = "Distribution by Cluster",
+     ylabel: str = "Value",
+     figsize: tuple = (12, 8),
+ ):
+     """
+     Plot 1D violin plot with cluster distribution
+
+     Args:
+         values_list: List of 1D value arrays (one per file)
+         clusters_list: List of cluster arrays (one per file)
+         title: Plot title
+         ylabel: Y-axis label
+         figsize: Figure size
+
+     Returns:
+         matplotlib Figure
+     """
+
+     # Combine all data
+     all_values = np.concatenate(values_list)
+     all_clusters = np.concatenate(clusters_list)
+
+     # Show all clusters except noise (-1)
+     cluster_ids = sorted([c for c in np.unique(all_clusters) if c >= 0])
+
+     # Prepare violin plot data
+     data = []
+     labels = []
+
+     # Add "All" first
+     data.append(all_values)
+     labels.append("All")
+
+     # Then add each cluster
+     for cluster_id in cluster_ids:
+         cluster_mask = all_clusters == cluster_id
+         cluster_values = all_values[cluster_mask]
+         if len(cluster_values) > 0:
+             data.append(cluster_values)
+             labels.append(f"Cluster {cluster_id}")
+
+     if len(data) == 0:
+         raise ValueError("No data for specified clusters")
+
+     # Create plot
+     # Lazy import: seaborn is slow to load (~500ms), defer until needed
+     import seaborn as sns  # noqa: PLC0415
+
+     fig = plt.figure(figsize=figsize)
+     sns.set_style("whitegrid")
+     ax = plt.subplot(111)
+
+     # Prepare colors: gray for "All", then cluster colors
+     palette = ["gray"]  # Color for "All"
+     for cluster_id in cluster_ids:
+         color = _get_cluster_color(cluster_id)
+         palette.append(color)
+
+     sns.violinplot(data=data, ax=ax, inner="box", cut=0, zorder=1, alpha=0.5, palette=palette)
+
+     # Scatter: first is "All" with gray, then clusters
+     for i, d in enumerate(data):
+         x = np.random.normal(i, 0.05, size=len(d))
+         if i == 0:
+             color = "gray"  # All
+         else:
+             color = _get_cluster_color(cluster_ids[i - 1])
+         ax.scatter(x, d, alpha=0.8, s=5, color=color, zorder=2)
+
+     ax.set_xticks(np.arange(0, len(labels)))
+     ax.set_xticklabels(labels)
+     ax.set_ylabel(ylabel)
+     ax.set_title(title)
+     ax.grid(axis="y", linestyle="--", alpha=0.7)
+     plt.tight_layout()
+
+     return fig
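
A matching sketch for the violin plot (again with random toy data; using PCA1 as the plotted value is just an example):

    values = np.random.randn(350)
    clusters = np.random.randint(0, 4, 350)

    fig = plot_violin_1d([values], [clusters], title="PCA1 by cluster", ylabel="PCA1")
    fig.savefig("violin.png", dpi=150)
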