wsi_toolbox-0.2.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,196 @@
+"""
+Patch embedding extraction command
+"""
+
+import gc
+
+import h5py
+import numpy as np
+from pydantic import BaseModel
+
+from ..common import create_default_model
+from ..utils import safe_del
+from . import _get, _progress, get_config
+
+
+class PatchEmbeddingResult(BaseModel):
+    """Result of patch embedding extraction"""
+
+    feature_dim: int = 0
+    patch_count: int = 0
+    model: str = ""
+    with_latent: bool = False
+    skipped: bool = False
+
+
+class PatchEmbeddingCommand:
+    """
+    Extract embeddings from patches using foundation models
+
+    Usage:
+        # Set global config once
+        commands.set_default_model_preset('gigapath')
+        commands.set_default_device('cuda')
+
+        # Create and run command
+        cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
+        result = cmd(hdf5_path='data.h5')
+    """
+
+    def __init__(
+        self,
+        batch_size: int = 256,
+        with_latent: bool = False,
+        overwrite: bool = False,
+        model_name: str | None = None,
+        device: str | None = None,
+    ):
+        """
+        Initialize patch embedding extractor
+
+        Args:
+            batch_size: Batch size for inference
+            with_latent: Whether to extract latent features
+            overwrite: Whether to overwrite existing features
+            model_name: Model name (None to use global default)
+            device: Device (None to use global default)
+
+        Note:
+            progress and verbose are controlled by global config
+        """
+        self.batch_size = batch_size
+        self.with_latent = with_latent
+        self.overwrite = overwrite
+        self.model_name = _get("model_name", model_name)
+        self.device = _get("device", device)
+
+        # Validate model
+        if self.model_name not in ["uni", "gigapath", "virchow2"]:
+            raise ValueError(f"Invalid model: {self.model_name}")
+
+        # Dataset paths
+        self.feature_name = f"{self.model_name}/features"
+        self.latent_feature_name = f"{self.model_name}/latent_features"
+
+    def __call__(self, hdf5_path: str) -> PatchEmbeddingResult:
+        """
+        Execute embedding extraction
+
+        Args:
+            hdf5_path: Path to HDF5 file
+
+        Returns:
+            PatchEmbeddingResult: Result metadata (feature_dim, patch_count, skipped, etc.)
+        """
+        # Lazy import: torch is slow to load (~800ms), defer until needed
+        import torch  # noqa: PLC0415
+
+        # Load model (uses globally registered model generator)
+        model = create_default_model()
+        model = model.eval().to(self.device)
+
+        # Normalization parameters
+        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
+        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)
+
+        done = False
+
+        try:
+            with h5py.File(hdf5_path, "r+") as f:
+                latent_size = model.patch_embed.proj.kernel_size[0]
+
+                # Check if already exists
+                if not self.overwrite:
+                    if self.with_latent:
+                        if (self.feature_name in f) and (self.latent_feature_name in f):
+                            if get_config().verbose:
+                                print("Already extracted. Skipped.")
+                            done = True
+                            return PatchEmbeddingResult(skipped=True)
+                        if (self.feature_name in f) or (self.latent_feature_name in f):
+                            raise RuntimeError(f"Either {self.feature_name} or {self.latent_feature_name} exists.")
+                    else:
+                        if self.feature_name in f:
+                            if get_config().verbose:
+                                print("Already extracted. Skipped.")
+                            done = True
+                            return PatchEmbeddingResult(skipped=True)
+
+                # Delete if overwrite
+                if self.overwrite:
+                    safe_del(f, self.feature_name)
+                    safe_del(f, self.latent_feature_name)
+
+                # Get patch count
+                patch_count = f["metadata/patch_count"][()]
+
+                # Create batch indices
+                batch_idx = [(i, min(i + self.batch_size, patch_count)) for i in range(0, patch_count, self.batch_size)]
+
+                # Create datasets
+                f.create_dataset(self.feature_name, shape=(patch_count, model.num_features), dtype=np.float32)
+                if self.with_latent:
+                    f.create_dataset(
+                        self.latent_feature_name,
+                        shape=(patch_count, latent_size**2, model.num_features),
+                        dtype=np.float16,
+                    )
+
+                # Process batches
+                tq = _progress(batch_idx)
+                for i0, i1 in tq:
+                    # Load batch
+                    x = f["patches"][i0:i1]
+                    x = (torch.from_numpy(x) / 255).permute(0, 3, 1, 2)  # BHWC->BCHW
+                    x = x.to(self.device)
+                    x = (x - mean) / std
+
+                    # Forward pass
+                    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
+                        h_tensor = model.forward_features(x)
+
+                    # Extract features
+                    h = h_tensor.cpu().detach().numpy()  # [B, T+L, H]
+                    latent_index = h.shape[1] - latent_size**2
+                    cls_feature = h[:, 0, ...]
+                    latent_feature = h[:, latent_index:, ...]
+
+                    # Save features
+                    f[self.feature_name][i0:i1] = cls_feature
+                    if self.with_latent:
+                        f[self.latent_feature_name][i0:i1] = latent_feature.astype(np.float16)
+
+                    # Cleanup
+                    del x, h_tensor
+                    torch.cuda.empty_cache()
+
+                    tq.set_description(f"Processing {i0}-{i1} (total={patch_count})")
+                    tq.refresh()
+
+                if get_config().verbose:
+                    print(f"Embeddings dimension: {f[self.feature_name].shape}")
+
+                done = True
+                return PatchEmbeddingResult(
+                    feature_dim=model.num_features,
+                    patch_count=patch_count,
+                    model=self.model_name,
+                    with_latent=self.with_latent,
+                )
+
+        finally:
+            if done and get_config().verbose:
+                print(f"Wrote {self.feature_name}")
+            elif not done:
+                # Cleanup on error
+                with h5py.File(hdf5_path, "a") as f:
+                    safe_del(f, self.feature_name)
+                    if self.with_latent:
+                        safe_del(f, self.latent_feature_name)
+                    if get_config().verbose:
+                        print(f"ABORTED! Deleted {self.feature_name}")
+
+            # Cleanup
+            del model, mean, std
+            torch.cuda.empty_cache()
+            gc.collect()
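
For orientation, here is a minimal usage sketch of PatchEmbeddingCommand that mirrors its class docstring. The import paths are assumptions (the wheel is not inspected beyond this diff), and it presumes data.h5 already contains the "patches" array and "metadata/patch_count" entry that the command reads; everything else follows the code above.

# Usage sketch (assumed import paths; mirrors the class docstring above).
import h5py

from wsi_toolbox import commands                         # assumed module layout
from wsi_toolbox.commands import PatchEmbeddingCommand   # assumed module layout

commands.set_default_model_preset("gigapath")  # one of: "uni", "gigapath", "virchow2"
commands.set_default_device("cuda")

cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
result = cmd(hdf5_path="data.h5")

if not result.skipped:
    # Features are written back into the same file under "<model>/features".
    with h5py.File("data.h5", "r") as f:
        features = f[f"{result.model}/features"][:]
    print(features.shape)  # (result.patch_count, result.feature_dim)

Because the embeddings land at "<model>/features" (and "<model>/latent_features" when with_latent=True) inside the same HDF5 file, downstream commands such as the PCA command in the next file can load them by path.
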
@@ -0,0 +1,206 @@
+"""
+PCA scoring command for feature analysis
+"""
+
+import h5py
+import numpy as np
+from pydantic import BaseModel
+
+from ..utils.hdf5_paths import build_cluster_path, build_namespace, ensure_groups
+from . import _get, _progress, get_config
+from .data_loader import MultipleContext
+
+
+def sigmoid(x):
+    """Apply sigmoid function"""
+    return 1 / (1 + np.exp(-x))
+
+
+class PCAResult(BaseModel):
+    """Result of PCA operation"""
+
+    n_samples: int
+    n_components: int
+    namespace: str
+    target_path: str
+    skipped: bool = False
+
+
+class PCACommand:
+    """
+    Compute PCA scores from features
+
+    Input:
+        - features (from <model>/features)
+        - namespace + filters (recursive hierarchy)
+        - n_components: 1, 2, or 3
+        - scaler: minmax or std
+
+    Output:
+        - PCA scores written to deepest level
+        - metadata saved as HDF5 attributes
+
+    Example hierarchy:
+        uni/default/pca2
+            ↑ with attributes: n_components=2, scaler="minmax"
+        uni/default/filter/1+2+3/pca1
+            ↑ filtered, with PCA scores
+
+    Usage:
+        # Basic PCA
+        cmd = PCACommand(n_components=2)
+        result = cmd('data.h5')  # → uni/default/pca2
+
+        # Filtered PCA
+        cmd = PCACommand(parent_filters=[[1,2,3]])
+        result = cmd('data.h5')  # → uni/default/filter/1+2+3/pca2
+    """
+
+    def __init__(
+        self,
+        n_components: int = 2,
+        namespace: str | None = None,
+        parent_filters: list[list[int]] | None = None,
+        scaler: str = "minmax",
+        overwrite: bool = False,
+        model_name: str | None = None,
+    ):
+        """
+        Args:
+            n_components: Number of PCA components (1, 2, or 3)
+            namespace: Explicit namespace (None = auto-generate)
+            parent_filters: Hierarchical filters, e.g., [[1,2,3], [4,5]]
+            scaler: Scaling method ("minmax" or "std")
+            overwrite: Overwrite existing PCA scores
+            model_name: Model name (None = use global default)
+        """
+        self.n_components = n_components
+        self.namespace = namespace
+        self.parent_filters = parent_filters or []
+        self.scaler = scaler
+        self.overwrite = overwrite
+        self.model_name = _get("model_name", model_name)
+
+        # Validate
+        if self.model_name not in ["uni", "gigapath", "virchow2"]:
+            raise ValueError(f"Invalid model: {self.model_name}")
+        if self.n_components not in [1, 2, 3]:
+            raise ValueError(f"Invalid n_components: {self.n_components}")
+        if self.scaler not in ["minmax", "std"]:
+            raise ValueError(f"Invalid scaler: {self.scaler}")
+
+        # Internal state
+        self.hdf5_paths = []
+        self.pca_scores = None
+
+    def __call__(self, hdf5_paths: str | list[str]) -> PCAResult:
+        """
+        Execute PCA computation
+
+        Args:
+            hdf5_paths: Single HDF5 path or list of paths
+
+        Returns:
+            PCAResult
+        """
+        # Normalize to list
+        if isinstance(hdf5_paths, str):
+            hdf5_paths = [hdf5_paths]
+        self.hdf5_paths = hdf5_paths
+
+        # Determine namespace
+        if self.namespace is None:
+            self.namespace = build_namespace(hdf5_paths)
+        elif "+" in self.namespace:
+            raise ValueError("Namespace cannot contain '+' (reserved for multi-file auto-generated namespaces)")
+
+        # Build target path
+        target_path = build_cluster_path(
+            self.model_name, self.namespace, filters=self.parent_filters, dataset=f"pca{self.n_components}"
+        )
+
+        # Check if already exists
+        if not self.overwrite:
+            with h5py.File(hdf5_paths[0], "r") as f:
+                if target_path in f:
+                    scores = f[target_path][:]
+                    n_samples = np.sum(~np.isnan(scores[:, 0]) if scores.ndim > 1 else ~np.isnan(scores))
+                    if get_config().verbose:
+                        print(f"PCA scores already exist at {target_path}")
+                    return PCAResult(
+                        n_samples=n_samples,
+                        n_components=self.n_components,
+                        namespace=self.namespace,
+                        target_path=target_path,
+                        skipped=True,
+                    )
+
+        # Execute with progress tracking
+        with _progress(total=3, desc="PCA") as pbar:
+            # Load data
+            pbar.set_description("Loading features")
+            ctx = MultipleContext(hdf5_paths, self.model_name, self.namespace, self.parent_filters)
+            features = ctx.load_features(source="features")
+            pbar.update(1)
+
+            # Compute PCA
+            pbar.set_description("Computing PCA")
+            self.pca_scores = self._compute_pca(features)
+            pbar.update(1)
+
+            # Write results
+            pbar.set_description("Writing results")
+            self._write_results(ctx, target_path)
+            pbar.update(1)
+
+        # Verbose output after progress bar closes
+        if get_config().verbose:
+            print(f"Computed PCA: {len(features)} samples → {self.n_components}D")
+            print(f"Wrote {target_path} to {len(hdf5_paths)} file(s)")
+
+        return PCAResult(
+            n_samples=len(features), n_components=self.n_components, namespace=self.namespace, target_path=target_path
+        )
+
+    def _compute_pca(self, features: np.ndarray) -> np.ndarray:
+        """Compute PCA and apply scaling"""
+        # Lazy import: sklearn is slow to load (~600ms), defer until needed
+        from sklearn.decomposition import PCA  # noqa: PLC0415
+        from sklearn.preprocessing import MinMaxScaler, StandardScaler  # noqa: PLC0415
+
+        pca = PCA(n_components=self.n_components)
+        pca_values = pca.fit_transform(features)
+
+        if self.scaler == "minmax":
+            scaler = MinMaxScaler()
+            pca_values = scaler.fit_transform(pca_values)
+        elif self.scaler == "std":
+            scaler = StandardScaler()
+            pca_values = scaler.fit_transform(pca_values)
+            pca_values = sigmoid(pca_values)
+
+        return pca_values
+
+    def _write_results(self, ctx: MultipleContext, target_path: str):
+        """Write PCA results to HDF5 files"""
+        for file_slice in ctx:
+            file_scores = file_slice.slice(self.pca_scores)
+
+            with h5py.File(file_slice.hdf5_path, "a") as f:
+                ensure_groups(f, target_path)
+
+                if target_path in f:
+                    del f[target_path]
+
+                # Fill with NaN for filtered patches
+                if self.n_components == 1:
+                    full_scores = np.full(len(file_slice.mask), np.nan)
+                    full_scores[file_slice.mask] = file_scores.flatten()
+                else:
+                    full_scores = np.full((len(file_slice.mask), self.n_components), np.nan)
+                    full_scores[file_slice.mask] = file_scores
+
+                ds = f.create_dataset(target_path, data=full_scores)
+                ds.attrs["n_components"] = self.n_components
+                ds.attrs["scaler"] = self.scaler
+                ds.attrs["model"] = self.model_name