wsi-toolbox 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,198 @@
+"""
+Clustering command for WSI features
+"""
+
+import h5py
+import numpy as np
+from pydantic import BaseModel
+
+from ..utils.analysis import leiden_cluster, reorder_clusters_by_pca
+from ..utils.hdf5_paths import build_cluster_path, build_namespace, ensure_groups
+from . import _get, _progress, get_config
+from .data_loader import MultipleContext
+
+
+class ClusteringResult(BaseModel):
+    """Result of clustering operation"""
+
+    cluster_count: int
+    feature_count: int
+    target_path: str
+    skipped: bool = False
+
+
+class ClusteringCommand:
+    """
+    Perform Leiden clustering on features or UMAP coordinates
+
+    Input:
+        - features (from <model>/features)
+        - namespace + filters (recursive hierarchy)
+        - source: "features" or "umap"
+        - resolution: clustering resolution
+
+    Output:
+        - clusters written to deepest level
+        - metadata (resolution, source) saved as HDF5 attributes
+
+    Example hierarchy:
+        uni/default/filter/1+2+3/filter/4+5/clusters
+            ↑ with attributes: resolution=1.0, source="features"
+
+    Usage:
+        # Basic clustering
+        cmd = ClusteringCommand(resolution=1.0)
+        result = cmd('data.h5')  # → uni/default/clusters
+
+        # Filtered clustering
+        cmd = ClusteringCommand(parent_filters=[[1,2,3], [4,5]])
+        result = cmd('data.h5')  # → uni/default/filter/1+2+3/filter/4+5/clusters
+
+        # UMAP-based clustering
+        cmd = ClusteringCommand(source="umap")
+        result = cmd('data.h5')  # → uses uni/default/umap
+    """
+
+    def __init__(
+        self,
+        resolution: float = 1.0,
+        namespace: str | None = None,
+        parent_filters: list[list[int]] | None = None,
+        source: str = "features",
+        sort_clusters: bool = True,
+        overwrite: bool = False,
+        model_name: str | None = None,
+    ):
+        """
+        Args:
+            resolution: Leiden clustering resolution
+            namespace: Explicit namespace (None = auto-generate)
+            parent_filters: Hierarchical filters, e.g., [[1,2,3], [4,5]]
+            source: "features" or "umap"
+            sort_clusters: Reorder cluster IDs by PCA distribution (default: True)
+            overwrite: Overwrite existing clusters
+            model_name: Model name (None = use global default)
+        """
+        self.resolution = resolution
+        self.namespace = namespace
+        self.parent_filters = parent_filters or []
+        self.source = source
+        self.sort_clusters = sort_clusters
+        self.overwrite = overwrite
+        self.model_name = _get("model_name", model_name)
+
+        # Validate
+        if self.model_name not in ["uni", "gigapath", "virchow2"]:
+            raise ValueError(f"Invalid model: {self.model_name}")
+        if self.source not in ["features", "umap"]:
+            raise ValueError(f"Invalid source: {self.source}")
+
+        # Internal state
+        self.hdf5_paths = []
+        self.clusters = None
+
+    def __call__(self, hdf5_paths: str | list[str]) -> ClusteringResult:
+        """
+        Execute clustering
+
+        Args:
+            hdf5_paths: Single HDF5 path or list of paths
+
+        Returns:
+            ClusteringResult
+        """
+        # Normalize to list
+        if isinstance(hdf5_paths, str):
+            hdf5_paths = [hdf5_paths]
+        self.hdf5_paths = hdf5_paths
+
+        # Determine namespace
+        if self.namespace is None:
+            self.namespace = build_namespace(hdf5_paths)
+        elif "+" in self.namespace:
+            raise ValueError("Namespace cannot contain '+' (reserved for multi-file auto-generated namespaces)")
+
+        # Build target path
+        target_path = build_cluster_path(
+            self.model_name, self.namespace, filters=self.parent_filters, dataset="clusters"
+        )
+
+        # Check if already exists
+        if not self.overwrite:
+            with h5py.File(hdf5_paths[0], "r") as f:
+                if target_path in f:
+                    clusters = f[target_path][:]
+                    cluster_count = len([c for c in set(clusters) if c >= 0])
+                    if get_config().verbose:
+                        print(f"Clusters already exist at {target_path}")
+                    return ClusteringResult(
+                        cluster_count=cluster_count,
+                        feature_count=np.sum(clusters >= 0),
+                        target_path=target_path,
+                        skipped=True,
+                    )
+
+        # Execute with progress tracking
+        # Total: 1 (load) + 5 (clustering steps) + 1 (write) = 7
+        with _progress(total=7, desc="Clustering") as pbar:
+            # Load data
+            pbar.set_description("Loading data")
+            ctx = MultipleContext(hdf5_paths, self.model_name, self.namespace, self.parent_filters)
+            data = ctx.load_features(source=self.source)
+            pbar.update(1)
+
+            # Perform clustering using analysis module
+            def on_progress(msg: str):
+                pbar.set_description(msg)
+                pbar.update(1)
+
+            self.clusters = leiden_cluster(
+                data,
+                resolution=self.resolution,
+                on_progress=on_progress,
+            )
+
+            # Reorder cluster IDs by PCA distribution for consistent visualization
+            if self.sort_clusters:
+                pbar.set_description("Sorting clusters")
+                features = ctx.load_features(source="features")
+                from sklearn.decomposition import PCA  # noqa: PLC0415
+
+                pca = PCA(n_components=1)
+                pca1 = pca.fit_transform(features).flatten()
+                self.clusters = reorder_clusters_by_pca(self.clusters, pca1)
+
+            cluster_count = len(set(self.clusters))
+
+            # Write results
+            pbar.set_description("Writing results")
+            self._write_results(ctx, target_path)
+            pbar.update(1)
+
+        # Verbose output after progress bar closes
+        if get_config().verbose:
+            print(f"Loaded {len(data)} samples from {self.source}")
+            print(f"Found {cluster_count} clusters")
+            print(f"Wrote {target_path} to {len(hdf5_paths)} file(s)")
+
+        return ClusteringResult(cluster_count=cluster_count, feature_count=len(data), target_path=target_path)
+
+    def _write_results(self, ctx: MultipleContext, target_path: str):
+        """Write clustering results to HDF5 files"""
+        for file_slice in ctx:
+            clusters = file_slice.slice(self.clusters)
+
+            with h5py.File(file_slice.hdf5_path, "a") as f:
+                ensure_groups(f, target_path)
+
+                if target_path in f:
+                    del f[target_path]
+
+                # Fill with -1 for filtered patches
+                full_clusters = np.full(len(file_slice.mask), -1, dtype=clusters.dtype)
+                full_clusters[file_slice.mask] = clusters
+
+                ds = f.create_dataset(target_path, data=full_clusters)
+                ds.attrs["resolution"] = self.resolution
+                ds.attrs["source"] = self.source
+                ds.attrs["model"] = self.model_name
@@ -0,0 +1,219 @@
+"""
+Multi-file context for HDF5 operations with namespace and filter support.
+"""
+
+from dataclasses import dataclass
+
+import h5py
+import numpy as np
+
+from ..utils.hdf5_paths import build_cluster_path
+
+
+@dataclass
+class FileSlice:
+    """Context for a single file in multi-file operations."""
+
+    hdf5_path: str
+    mask: np.ndarray
+    start: int
+    end: int
+
+    def slice(self, data: np.ndarray) -> np.ndarray:
+        """Extract this file's portion from concatenated data."""
+        return data[self.start : self.end]
+
+    @property
+    def count(self) -> int:
+        """Number of masked (active) samples in this file."""
+        return self.end - self.start
+
+
+class MultipleContext:
+    """
+    Multi-file context for HDF5 operations with namespace + filters.
+
+    Handles the common pattern of:
+    1. Loading existing clusters at each filter level
+    2. Building cumulative mask
+    3. Loading features/UMAP coordinates with the mask
+    4. Iterating over files for writing results
+
+    Usage:
+        ctx = MultipleContext(hdf5_paths, model_name, namespace, filters)
+        data = ctx.load_features()
+
+        results = some_computation(data)
+
+        for file_slice in ctx:
+            file_results = file_slice.slice(results)
+            with h5py.File(file_slice.hdf5_path, "a") as f:
+                # write file_results with file_slice.mask
+    """
+
+    def __init__(
+        self,
+        hdf5_paths: list[str],
+        model_name: str,
+        namespace: str,
+        parent_filters: list[list[int]] | None = None,
+    ):
+        """
+        Args:
+            hdf5_paths: List of HDF5 file paths
+            model_name: Model name (e.g., "uni")
+            namespace: Namespace (e.g., "default", "001+002")
+            parent_filters: Hierarchical filters, e.g., [[1,2,3], [4,5]]
+        """
+        self.hdf5_paths = hdf5_paths
+        self.model_name = model_name
+        self.namespace = namespace
+        self.parent_filters = parent_filters or []
+
+        # Populated after load_features()
+        self._masks: list[np.ndarray] | None = None
+        self._total_count: int = 0
+
+    def load_features(self, source: str = "features") -> np.ndarray:
+        """
+        Load features or UMAP coordinates with filtering.
+
+        Args:
+            source: "features" or "umap"
+
+        Returns:
+            Concatenated and normalized features/UMAP coordinates
+        """
+        data_list = []
+        self._masks = []
+
+        for hdf5_path in self.hdf5_paths:
+            with h5py.File(hdf5_path, "r") as f:
+                patch_count = f["metadata/patch_count"][()]
+
+                # Build cumulative mask from filters
+                mask = self._build_mask(f, patch_count)
+
+                # Validate mask length
+                if len(mask) != patch_count:
+                    raise RuntimeError(f"Mask length mismatch in {hdf5_path}: expected {patch_count}, got {len(mask)}")
+
+                self._masks.append(mask)
+
+                # Load data based on source
+                if source == "umap":
+                    umap_path = build_cluster_path(
+                        self.model_name,
+                        self.namespace,
+                        filters=self.parent_filters if self.parent_filters else None,
+                        dataset="umap",
+                    )
+                    if umap_path not in f:
+                        raise RuntimeError(f"UMAP coordinates not found at {umap_path}. Run 'wsi-toolbox umap' first.")
+                    data = f[umap_path][mask]
+                    if np.any(np.isnan(data)):
+                        raise RuntimeError(f"NaN values in UMAP coordinates at {umap_path}")
+                else:
+                    feature_path = f"{self.model_name}/features"
+                    if feature_path not in f:
+                        raise RuntimeError(f"Features not found at {feature_path} in {hdf5_path}")
+                    data = f[feature_path][mask]
+
+                data_list.append(data)
+
+        # Concatenate and normalize
+        # Lazy import: sklearn is slow to load (~600ms), defer until needed
+        from sklearn.preprocessing import StandardScaler  # noqa: PLC0415
+
+        data = np.concatenate(data_list)
+        self._total_count = len(data)
+
+        scaler = StandardScaler()
+        data = scaler.fit_transform(data)
+
+        return data
+
+    def __iter__(self):
+        """Iterate over files yielding FileSlice for each."""
+        if self._masks is None:
+            raise RuntimeError("Call load_features() before iterating")
+
+        cursor = 0
+        for hdf5_path, mask in zip(self.hdf5_paths, self._masks):
+            count = np.sum(mask)
+            yield FileSlice(
+                hdf5_path=hdf5_path,
+                mask=mask,
+                start=cursor,
+                end=cursor + count,
+            )
+            cursor += count
+
+    def __len__(self) -> int:
+        """Total number of samples across all files."""
+        return self._total_count
+
+    @property
+    def masks(self) -> list[np.ndarray]:
+        """Get masks (for backward compatibility)."""
+        if self._masks is None:
+            raise RuntimeError("Call load_features() before accessing masks")
+        return self._masks
+
+    def _build_mask(self, f: h5py.File, patch_count: int) -> np.ndarray:
+        """
+        Build the cumulative mask from hierarchical filters.
+
+        Strategy: only read the deepest existing cluster level.
+        - If filters = [[1,2,3], [4,5]], read the clusters stored at filter/1+2+3
+        - Those clusters are already restricted to [1,2,3], so we only need to filter by [4,5]
+        """
+        if not self.parent_filters:
+            # No filtering
+            return np.ones(patch_count, dtype=bool)
+
+        # Get the deepest cluster path (parent of where we'll write new clusters)
+        # If filters = [[1,2,3], [4,5]], we need clusters at filter/1+2+3/
+        parent_cluster_path = build_cluster_path(
+            self.model_name,
+            self.namespace,
+            filters=self.parent_filters[:-1] if len(self.parent_filters) > 1 else None,
+            dataset="clusters",
+        )
+
+        if parent_cluster_path not in f:
+            raise RuntimeError(
+                f"Parent clusters not found at {parent_cluster_path}. Run clustering at parent level first."
+            )
+
+        clusters = f[parent_cluster_path][:]
+
+        # Filter by the last filter only (because previous filters are already applied)
+        last_filter = self.parent_filters[-1]
+        mask = np.isin(clusters, last_filter)
+
+        return mask
+
+    def get_parent_cluster_info(self, hdf5_path: str) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Get parent clusters and mask for a single file
+
+        Returns:
+            (clusters, mask): Parent cluster values and boolean mask
+        """
+        with h5py.File(hdf5_path, "r") as f:
+            patch_count = f["metadata/patch_count"][()]
+            mask = self._build_mask(f, patch_count)
+
+            if self.parent_filters:
+                parent_cluster_path = build_cluster_path(
+                    self.model_name,
+                    self.namespace,
+                    filters=self.parent_filters[:-1] if len(self.parent_filters) > 1 else None,
+                    dataset="clusters",
+                )
+                clusters = f[parent_cluster_path][:]
+            else:
+                clusters = None
+
+        return clusters, mask
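
The Usage docstring above stops at a comment; the sketch below completes the read-compute-write pattern in the same spirit as ClusteringCommand._write_results. The import path, the per-patch "scores" computation, and the target dataset name are illustrative assumptions, and the parent filter requires clusters to already exist at uni/default/clusters:

    # Sketch only: import path, score computation, and output path are assumptions.
    import h5py
    import numpy as np

    from wsi_toolbox.commands.data_loader import MultipleContext  # hypothetical module path

    ctx = MultipleContext(["data.h5"], model_name="uni", namespace="default", parent_filters=[[1, 2]])
    data = ctx.load_features(source="features")  # standardized features of patches in root clusters 1 and 2

    scores = data.mean(axis=1)  # stand-in for some real per-patch computation

    for file_slice in ctx:
        file_scores = file_slice.slice(scores)  # this file's share of the concatenated results
        with h5py.File(file_slice.hdf5_path, "a") as f:
            # Expand back to the full patch grid; filtered-out patches get NaN
            full = np.full(len(file_slice.mask), np.nan, dtype=file_scores.dtype)
            full[file_slice.mask] = file_scores
            f.create_dataset("uni/default/filter/1+2/scores", data=full)  # hypothetical target path
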
@@ -0,0 +1,160 @@
+"""
+DZI export command for Deep Zoom Image format
+"""
+
+from pathlib import Path
+
+from PIL import Image
+from pydantic import BaseModel
+
+from ..wsi_files import PyramidalWSIFile, WSIFile, create_wsi_file
+from . import _progress, get_config
+
+
+class DziResult(BaseModel):
+    """Result of DZI export"""
+
+    dzi_path: str
+    max_level: int
+    tile_size: int
+    overlap: int
+    width: int
+    height: int
+
+
+class DziCommand:
+    """
+    Export WSI to DZI (Deep Zoom Image) format
+
+    Usage:
+        cmd = DziCommand(tile_size=256, overlap=0, jpeg_quality=90)
+        result = cmd(wsi_path='slide.svs', output_dir='output', name='slide')
+
+        # Or with existing WSIFile instance
+        wsi = create_wsi_file('slide.svs')
+        result = cmd(wsi_file=wsi, output_dir='output', name='slide')
+    """
+
+    def __init__(
+        self,
+        tile_size: int = 256,
+        overlap: int = 0,
+        jpeg_quality: int = 90,
+        format: str = "jpeg",
+    ):
+        """
+        Initialize DZI export command
+
+        Args:
+            tile_size: Tile size in pixels (default: 256)
+            overlap: Overlap in pixels (default: 0)
+            jpeg_quality: JPEG compression quality (0-100)
+            format: Image format ("jpeg" or "png")
+        """
+        self.tile_size = tile_size
+        self.overlap = overlap
+        self.jpeg_quality = jpeg_quality
+        self.format = format
+
+    def __call__(
+        self,
+        wsi_path: str | None = None,
+        wsi_file: WSIFile | None = None,
+        output_dir: str = ".",
+        name: str = "slide",
+    ) -> DziResult:
+        """
+        Export WSI to DZI format
+
+        Args:
+            wsi_path: Path to WSI file (either this or wsi_file required)
+            wsi_file: WSIFile instance (either this or wsi_path required)
+            output_dir: Output directory
+            name: Base name for DZI files
+
+        Returns:
+            DziResult: Export metadata
+        """
+        # Get or create WSIFile
+        if wsi_file is None:
+            if wsi_path is None:
+                raise ValueError("Either wsi_path or wsi_file must be provided")
+            wsi_file = create_wsi_file(wsi_path)
+
+        # Check if pyramidal (DZI supported)
+        if not isinstance(wsi_file, PyramidalWSIFile):
+            raise TypeError(
+                f"DZI export requires PyramidalWSIFile, got {type(wsi_file).__name__}. "
+                "StandardImage does not support DZI export."
+            )
+
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Get dimensions
+        width, height = wsi_file.get_original_size()
+        max_level = wsi_file.get_dzi_max_level()
+
+        if get_config().verbose:
+            print(f"Original size: {width}x{height}")
+            print(f"Tile size: {self.tile_size}, Overlap: {self.overlap}")
+            print(f"Max zoom level: {max_level}")
+
+        # Setup directories
+        dzi_path = output_dir / f"{name}.dzi"
+        files_dir = output_dir / f"{name}_files"
+        files_dir.mkdir(exist_ok=True)
+
+        # Generate all levels
+        for level in range(max_level, -1, -1):
+            self._generate_level(wsi_file, files_dir, level)
+
+        # Write DZI XML
+        dzi_xml = wsi_file.get_dzi_xml(self.tile_size, self.overlap, self.format)
+        with open(dzi_path, "w", encoding="utf-8") as f:
+            f.write(dzi_xml)
+
+        if get_config().verbose:
+            print(f"DZI export complete: {dzi_path}")
+
+        return DziResult(
+            dzi_path=str(dzi_path),
+            max_level=max_level,
+            tile_size=self.tile_size,
+            overlap=self.overlap,
+            width=width,
+            height=height,
+        )
+
+    def _generate_level(
+        self,
+        wsi_file: PyramidalWSIFile,
+        files_dir: Path,
+        level: int,
+    ):
+        """Generate all tiles for a single level."""
+        level_dir = files_dir / str(level)
+        level_dir.mkdir(exist_ok=True)
+
+        level_width, level_height, cols, rows = wsi_file.get_dzi_level_info(level, self.tile_size)
+
+        if get_config().verbose:
+            print(f"Level {level}: {level_width}x{level_height}, {cols}x{rows} tiles")
+
+        ext = "png" if self.format == "png" else "jpeg"
+
+        tq = _progress(range(rows))
+        for row in tq:
+            tq.set_description(f"Level {level}: row {row + 1}/{rows}")
+            for col in range(cols):
+                tile_path = level_dir / f"{col}_{row}.{ext}"
+
+                # Get tile from WSIFile
+                tile_array = wsi_file.get_dzi_tile(level, col, row, self.tile_size, self.overlap)
+
+                # Save tile
+                img = Image.fromarray(tile_array)
+                if self.format == "png":
+                    img.save(tile_path, "PNG")
+                else:
+                    img.save(tile_path, "JPEG", quality=self.jpeg_quality)
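
A short usage sketch based on the DziCommand docstring above; the import path and 'slide.svs' are assumptions, and the slide must be a pyramidal format readable by create_wsi_file (otherwise the command raises TypeError):

    # Sketch only: import path and input file are assumptions, not shown in this diff.
    from wsi_toolbox.commands.dzi import DziCommand  # hypothetical module path

    cmd = DziCommand(tile_size=256, overlap=0, jpeg_quality=90)
    result = cmd(wsi_path="slide.svs", output_dir="out", name="slide")

    # Expected on-disk layout, following the code above:
    #   out/slide.dzi                              <- Deep Zoom descriptor XML
    #   out/slide_files/<level>/<col>_<row>.jpeg   <- tiles for level = 0 .. result.max_level
    print(result.max_level, result.width, result.height)
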