wsi_toolbox-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +122 -0
- wsi_toolbox/app.py +874 -0
- wsi_toolbox/cli.py +599 -0
- wsi_toolbox/commands/__init__.py +66 -0
- wsi_toolbox/commands/clustering.py +198 -0
- wsi_toolbox/commands/data_loader.py +219 -0
- wsi_toolbox/commands/dzi.py +160 -0
- wsi_toolbox/commands/patch_embedding.py +196 -0
- wsi_toolbox/commands/pca.py +206 -0
- wsi_toolbox/commands/preview.py +394 -0
- wsi_toolbox/commands/show.py +171 -0
- wsi_toolbox/commands/umap_embedding.py +174 -0
- wsi_toolbox/commands/wsi.py +223 -0
- wsi_toolbox/common.py +148 -0
- wsi_toolbox/models.py +30 -0
- wsi_toolbox/utils/__init__.py +109 -0
- wsi_toolbox/utils/analysis.py +174 -0
- wsi_toolbox/utils/hdf5_paths.py +232 -0
- wsi_toolbox/utils/plot.py +227 -0
- wsi_toolbox/utils/progress.py +207 -0
- wsi_toolbox/utils/seed.py +26 -0
- wsi_toolbox/utils/st.py +55 -0
- wsi_toolbox/utils/white.py +121 -0
- wsi_toolbox/watcher.py +256 -0
- wsi_toolbox/wsi_files.py +619 -0
- wsi_toolbox-0.2.0.dist-info/METADATA +253 -0
- wsi_toolbox-0.2.0.dist-info/RECORD +30 -0
- wsi_toolbox-0.2.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.2.0.dist-info/entry_points.txt +3 -0
- wsi_toolbox-0.2.0.dist-info/licenses/LICENSE +21 -0
wsi_toolbox/commands/patch_embedding.py
@@ -0,0 +1,196 @@

```python
"""
Patch embedding extraction command
"""

import gc

import h5py
import numpy as np
from pydantic import BaseModel

from ..common import create_default_model
from ..utils import safe_del
from . import _get, _progress, get_config


class PatchEmbeddingResult(BaseModel):
    """Result of patch embedding extraction"""

    feature_dim: int = 0
    patch_count: int = 0
    model: str = ""
    with_latent: bool = False
    skipped: bool = False


class PatchEmbeddingCommand:
    """
    Extract embeddings from patches using foundation models

    Usage:
        # Set global config once
        commands.set_default_model_preset('gigapath')
        commands.set_default_device('cuda')

        # Create and run command
        cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
        result = cmd(hdf5_path='data.h5')
    """

    def __init__(
        self,
        batch_size: int = 256,
        with_latent: bool = False,
        overwrite: bool = False,
        model_name: str | None = None,
        device: str | None = None,
    ):
        """
        Initialize patch embedding extractor

        Args:
            batch_size: Batch size for inference
            with_latent: Whether to extract latent features
            overwrite: Whether to overwrite existing features
            model_name: Model name (None to use global default)
            device: Device (None to use global default)

        Note:
            progress and verbose are controlled by global config
        """
        self.batch_size = batch_size
        self.with_latent = with_latent
        self.overwrite = overwrite
        self.model_name = _get("model_name", model_name)
        self.device = _get("device", device)

        # Validate model
        if self.model_name not in ["uni", "gigapath", "virchow2"]:
            raise ValueError(f"Invalid model: {self.model_name}")

        # Dataset paths
        self.feature_name = f"{self.model_name}/features"
        self.latent_feature_name = f"{self.model_name}/latent_features"

    def __call__(self, hdf5_path: str) -> PatchEmbeddingResult:
        """
        Execute embedding extraction

        Args:
            hdf5_path: Path to HDF5 file

        Returns:
            PatchEmbeddingResult: Result metadata (feature_dim, patch_count, skipped, etc.)
        """
        # Lazy import: torch is slow to load (~800ms), defer until needed
        import torch  # noqa: PLC0415

        # Load model (uses globally registered model generator)
        model = create_default_model()
        model = model.eval().to(self.device)

        # Normalization parameters
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        done = False

        try:
            with h5py.File(hdf5_path, "r+") as f:
                latent_size = model.patch_embed.proj.kernel_size[0]

                # Check if already exists
                if not self.overwrite:
                    if self.with_latent:
                        if (self.feature_name in f) and (self.latent_feature_name in f):
                            if get_config().verbose:
                                print("Already extracted. Skipped.")
                            done = True
                            return PatchEmbeddingResult(skipped=True)
                        if (self.feature_name in f) or (self.latent_feature_name in f):
                            raise RuntimeError(f"Either {self.feature_name} or {self.latent_feature_name} exists.")
                    else:
                        if self.feature_name in f:
                            if get_config().verbose:
                                print("Already extracted. Skipped.")
                            done = True
                            return PatchEmbeddingResult(skipped=True)

                # Delete if overwrite
                if self.overwrite:
                    safe_del(f, self.feature_name)
                    safe_del(f, self.latent_feature_name)

                # Get patch count
                patch_count = f["metadata/patch_count"][()]

                # Create batch indices
                batch_idx = [(i, min(i + self.batch_size, patch_count)) for i in range(0, patch_count, self.batch_size)]

                # Create datasets
                f.create_dataset(self.feature_name, shape=(patch_count, model.num_features), dtype=np.float32)
                if self.with_latent:
                    f.create_dataset(
                        self.latent_feature_name,
                        shape=(patch_count, latent_size**2, model.num_features),
                        dtype=np.float16,
                    )

                # Process batches
                tq = _progress(batch_idx)
                for i0, i1 in tq:
                    # Load batch
                    x = f["patches"][i0:i1]
                    x = (torch.from_numpy(x) / 255).permute(0, 3, 1, 2)  # BHWC->BCHW
                    x = x.to(self.device)
                    x = (x - mean) / std

                    # Forward pass
                    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
                        h_tensor = model.forward_features(x)

                    # Extract features
                    h = h_tensor.cpu().detach().numpy()  # [B, T+L, H]
                    latent_index = h.shape[1] - latent_size**2
                    cls_feature = h[:, 0, ...]
                    latent_feature = h[:, latent_index:, ...]

                    # Save features
                    f[self.feature_name][i0:i1] = cls_feature
                    if self.with_latent:
                        f[self.latent_feature_name][i0:i1] = latent_feature.astype(np.float16)

                    # Cleanup
                    del x, h_tensor
                    torch.cuda.empty_cache()

                    tq.set_description(f"Processing {i0}-{i1} (total={patch_count})")
                    tq.refresh()

                if get_config().verbose:
                    print(f"Embeddings dimension: {f[self.feature_name].shape}")

                done = True
                return PatchEmbeddingResult(
                    feature_dim=model.num_features,
                    patch_count=patch_count,
                    model=self.model_name,
                    with_latent=self.with_latent,
                )

        finally:
            if done and get_config().verbose:
                print(f"Wrote {self.feature_name}")
            elif not done:
                # Cleanup on error
                with h5py.File(hdf5_path, "a") as f:
                    safe_del(f, self.feature_name)
                    if self.with_latent:
                        safe_del(f, self.latent_feature_name)
                    if get_config().verbose:
                        print(f"ABORTED! Deleted {self.feature_name}")

            # Cleanup
            del model, mean, std
            torch.cuda.empty_cache()
            gc.collect()
```
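Below is a minimal driver sketch (not part of the package diff) showing how the command above could be wired together. It assumes `set_default_model_preset` and `set_default_device` are exposed on `wsi_toolbox.commands`, as the class docstring suggests, and that `data.h5` already holds the `patches` and `metadata/patch_count` datasets the command reads.

```python
# Hypothetical end-to-end sketch; helper functions and import paths are inferred
# from the class docstring and the package layout above.
import h5py

from wsi_toolbox import commands  # assumed public entry point
from wsi_toolbox.commands.patch_embedding import PatchEmbeddingCommand

# Global config (per the docstring): model preset and device are set once.
commands.set_default_model_preset("gigapath")
commands.set_default_device("cuda")

# Run extraction; features are written into the same HDF5 file.
cmd = PatchEmbeddingCommand(batch_size=256, with_latent=False)
result = cmd(hdf5_path="data.h5")

if not result.skipped:
    with h5py.File("data.h5", "r") as f:
        feats = f[f"{result.model}/features"][:]  # (patch_count, feature_dim) float32
        print(feats.shape, result.feature_dim, result.patch_count)
```

Features land in `<model>/features` as float32 of shape `(patch_count, feature_dim)`; when `with_latent=True`, per-token latents also go to `<model>/latent_features` as float16.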
wsi_toolbox/commands/pca.py
@@ -0,0 +1,206 @@

```python
"""
PCA scoring command for feature analysis
"""

import h5py
import numpy as np
from pydantic import BaseModel

from ..utils.hdf5_paths import build_cluster_path, build_namespace, ensure_groups
from . import _get, _progress, get_config
from .data_loader import MultipleContext


def sigmoid(x):
    """Apply sigmoid function"""
    return 1 / (1 + np.exp(-x))


class PCAResult(BaseModel):
    """Result of PCA operation"""

    n_samples: int
    n_components: int
    namespace: str
    target_path: str
    skipped: bool = False


class PCACommand:
    """
    Compute PCA scores from features

    Input:
        - features (from <model>/features)
        - namespace + filters (recursive hierarchy)
        - n_components: 1, 2, or 3
        - scaler: minmax or std

    Output:
        - PCA scores written to deepest level
        - metadata saved as HDF5 attributes

    Example hierarchy:
        uni/default/pca2
            ↑ with attributes: n_components=2, scaler="minmax"
        uni/default/filter/1+2+3/pca1
            ↑ filtered, with PCA scores

    Usage:
        # Basic PCA
        cmd = PCACommand(n_components=2)
        result = cmd('data.h5')  # → uni/default/pca2

        # Filtered PCA
        cmd = PCACommand(parent_filters=[[1,2,3]])
        result = cmd('data.h5')  # → uni/default/filter/1+2+3/pca2
    """

    def __init__(
        self,
        n_components: int = 2,
        namespace: str | None = None,
        parent_filters: list[list[int]] | None = None,
        scaler: str = "minmax",
        overwrite: bool = False,
        model_name: str | None = None,
    ):
        """
        Args:
            n_components: Number of PCA components (1, 2, or 3)
            namespace: Explicit namespace (None = auto-generate)
            parent_filters: Hierarchical filters, e.g., [[1,2,3], [4,5]]
            scaler: Scaling method ("minmax" or "std")
            overwrite: Overwrite existing PCA scores
            model_name: Model name (None = use global default)
        """
        self.n_components = n_components
        self.namespace = namespace
        self.parent_filters = parent_filters or []
        self.scaler = scaler
        self.overwrite = overwrite
        self.model_name = _get("model_name", model_name)

        # Validate
        if self.model_name not in ["uni", "gigapath", "virchow2"]:
            raise ValueError(f"Invalid model: {self.model_name}")
        if self.n_components not in [1, 2, 3]:
            raise ValueError(f"Invalid n_components: {self.n_components}")
        if self.scaler not in ["minmax", "std"]:
            raise ValueError(f"Invalid scaler: {self.scaler}")

        # Internal state
        self.hdf5_paths = []
        self.pca_scores = None

    def __call__(self, hdf5_paths: str | list[str]) -> PCAResult:
        """
        Execute PCA computation

        Args:
            hdf5_paths: Single HDF5 path or list of paths

        Returns:
            PCAResult
        """
        # Normalize to list
        if isinstance(hdf5_paths, str):
            hdf5_paths = [hdf5_paths]
        self.hdf5_paths = hdf5_paths

        # Determine namespace
        if self.namespace is None:
            self.namespace = build_namespace(hdf5_paths)
        elif "+" in self.namespace:
            raise ValueError("Namespace cannot contain '+' (reserved for multi-file auto-generated namespaces)")

        # Build target path
        target_path = build_cluster_path(
            self.model_name, self.namespace, filters=self.parent_filters, dataset=f"pca{self.n_components}"
        )

        # Check if already exists
        if not self.overwrite:
            with h5py.File(hdf5_paths[0], "r") as f:
                if target_path in f:
                    scores = f[target_path][:]
                    n_samples = np.sum(~np.isnan(scores[:, 0]) if scores.ndim > 1 else ~np.isnan(scores))
                    if get_config().verbose:
                        print(f"PCA scores already exist at {target_path}")
                    return PCAResult(
                        n_samples=n_samples,
                        n_components=self.n_components,
                        namespace=self.namespace,
                        target_path=target_path,
                        skipped=True,
                    )

        # Execute with progress tracking
        with _progress(total=3, desc="PCA") as pbar:
            # Load data
            pbar.set_description("Loading features")
            ctx = MultipleContext(hdf5_paths, self.model_name, self.namespace, self.parent_filters)
            features = ctx.load_features(source="features")
            pbar.update(1)

            # Compute PCA
            pbar.set_description("Computing PCA")
            self.pca_scores = self._compute_pca(features)
            pbar.update(1)

            # Write results
            pbar.set_description("Writing results")
            self._write_results(ctx, target_path)
            pbar.update(1)

        # Verbose output after progress bar closes
        if get_config().verbose:
            print(f"Computed PCA: {len(features)} samples → {self.n_components}D")
            print(f"Wrote {target_path} to {len(hdf5_paths)} file(s)")

        return PCAResult(
            n_samples=len(features), n_components=self.n_components, namespace=self.namespace, target_path=target_path
        )

    def _compute_pca(self, features: np.ndarray) -> np.ndarray:
        """Compute PCA and apply scaling"""
        # Lazy import: sklearn is slow to load (~600ms), defer until needed
        from sklearn.decomposition import PCA  # noqa: PLC0415
        from sklearn.preprocessing import MinMaxScaler, StandardScaler  # noqa: PLC0415

        pca = PCA(n_components=self.n_components)
        pca_values = pca.fit_transform(features)

        if self.scaler == "minmax":
            scaler = MinMaxScaler()
            pca_values = scaler.fit_transform(pca_values)
        elif self.scaler == "std":
            scaler = StandardScaler()
            pca_values = scaler.fit_transform(pca_values)
            pca_values = sigmoid(pca_values)

        return pca_values

    def _write_results(self, ctx: MultipleContext, target_path: str):
        """Write PCA results to HDF5 files"""
        for file_slice in ctx:
            file_scores = file_slice.slice(self.pca_scores)

            with h5py.File(file_slice.hdf5_path, "a") as f:
                ensure_groups(f, target_path)

                if target_path in f:
                    del f[target_path]

                # Fill with NaN for filtered patches
                if self.n_components == 1:
                    full_scores = np.full(len(file_slice.mask), np.nan)
                    full_scores[file_slice.mask] = file_scores.flatten()
                else:
                    full_scores = np.full((len(file_slice.mask), self.n_components), np.nan)
                    full_scores[file_slice.mask] = file_scores

                ds = f.create_dataset(target_path, data=full_scores)
                ds.attrs["n_components"] = self.n_components
                ds.attrs["scaler"] = self.scaler
                ds.attrs["model"] = self.model_name
```