wsi-toolbox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,335 @@
1
+ """
2
+ Preview generation commands using Template Method Pattern
3
+ """
4
+
5
+ import h5py
6
+ import numpy as np
7
+ from PIL import Image, ImageFont
8
+ from matplotlib import pyplot as plt, colors as mcolors
9
+ from sklearn.preprocessing import MinMaxScaler
10
+ from sklearn.decomposition import PCA
11
+
12
+ from ..utils import create_frame, get_platform_font
13
+ from . import _get, _progress
14
+
15
+
16
class BasePreviewCommand:
    """
    Base class for preview commands (Template Method Pattern).

    Subclasses must implement:
    - _prepare(f, **kwargs): prepare data (frames, scores, ...)
    - _get_frame(index, data, f): return the overlay frame for one patch
    """

    def __init__(self, size: int = 64, font_size: int = 16,
                 model_name: str | None = None):
        """
        Initialize preview command.

        Args:
            size: Thumbnail patch size
            font_size: Font size for labels
            model_name: Model name (None to use global default)
        """
        self.size = size
        self.font_size = font_size
        self.model_name = _get('model_name', model_name)

    def __call__(self, hdf5_path: str, **kwargs) -> Image.Image:
        """
        Template method - shared workflow for every preview command.

        Args:
            hdf5_path: Path to HDF5 file
            **kwargs: Subclass-specific arguments

        Returns:
            PIL.Image: Thumbnail image
        """
        side = self.size

        with h5py.File(hdf5_path, 'r') as f:
            cols, rows, patch_count, patch_size = self._load_metadata(f)

            # Hook: subclass-specific preparation
            data = self._prepare(f, **kwargs)

            canvas = Image.new('RGB', (cols * side, rows * side), (0, 0, 0))

            # Shared rendering loop over all stored patches
            progress = _progress(range(patch_count))
            for idx in progress:
                coord = f['coordinates'][idx]
                raw = f['patches'][idx]

                # Hook: subclass-specific overlay (may be None)
                frame = self._get_frame(idx, data, f)

                # Map stored pixel coordinate to thumbnail grid position
                x, y = coord // patch_size * side
                tile = Image.fromarray(raw).resize((side, side))
                if frame:
                    tile.paste(frame, (0, 0), frame)
                canvas.paste(tile, (x, y, x + side, y + side))

            return canvas

    def _load_metadata(self, f: h5py.File):
        """Read the common grid metadata (cols, rows, patch_count, patch_size)."""
        return (
            f['metadata/cols'][()],
            f['metadata/rows'][()],
            f['metadata/patch_count'][()],
            f['metadata/patch_size'][()],
        )

    def _prepare(self, f: h5py.File, **kwargs):
        """
        Prepare data for rendering (implemented by subclass).

        Args:
            f: HDF5 file handle
            **kwargs: Subclass-specific arguments

        Returns:
            Any data structure needed for _get_frame()
        """
        raise NotImplementedError

    def _get_frame(self, index: int, data, f: h5py.File):
        """
        Get frame for a specific patch (implemented by subclass).

        Args:
            index: Patch index
            data: Data prepared by _prepare()
            f: HDF5 file handle

        Returns:
            PIL.Image or None: Frame overlay
        """
        raise NotImplementedError
114
+
115
+
116
class PreviewClustersCommand(BasePreviewCommand):
    """
    Generate thumbnail with cluster visualization.

    Usage:
        cmd = PreviewClustersCommand(size=64)
        image = cmd(hdf5_path='data.h5', cluster_name='test')
    """

    def _prepare(self, f: h5py.File, cluster_name: str = ''):
        """
        Load cluster labels and build one frame per label.

        Args:
            f: HDF5 file handle
            cluster_name: Cluster name suffix

        Returns:
            dict with 'clusters' and 'frames'
        """
        cluster_path = f'{self.model_name}/clusters'
        if cluster_name:
            cluster_path += f'_{cluster_name}'
        if cluster_path not in f:
            raise RuntimeError(f'{cluster_path} does not exist in HDF5 file')

        clusters = f[cluster_path][:]

        font = ImageFont.truetype(font=get_platform_font(), size=self.font_size)
        cmap = plt.get_cmap('tab20')

        # One frame per label; -1 (presumably "unassigned" — confirm with the
        # clustering writer) always gets a dark frame.
        labels = np.unique(clusters).tolist() + [-1]
        frames = {
            label: create_frame(
                self.size,
                mcolors.rgb2hex(cmap(label)[:3]) if label >= 0 else '#111',
                f'{label}',
                font,
            )
            for label in labels
        }

        return {'clusters': clusters, 'frames': frames}

    def _get_frame(self, index: int, data, f: h5py.File):
        """Return the frame matching the cluster label at *index* (None if negative)."""
        label = data['clusters'][index]
        if label < 0:
            return None
        return data['frames'][label]
160
+
161
+
162
class PreviewScoresCommand(BasePreviewCommand):
    """
    Generate thumbnail with score visualization.

    Usage:
        cmd = PreviewScoresCommand(size=64)
        image = cmd(hdf5_path='data.h5', score_name='pca')
    """

    def _prepare(self, f: h5py.File, score_name: str):
        """
        Load scores and set up the font and colormap.

        Args:
            f: HDF5 file handle
            score_name: Score dataset name

        Returns:
            dict with 'scores', 'cmap', and 'font'
        """
        scores = f[f'{self.model_name}/scores_{score_name}'][()]
        return {
            'scores': scores,
            'cmap': plt.get_cmap('viridis'),
            'font': ImageFont.truetype(font=get_platform_font(),
                                       size=self.font_size),
        }

    def _get_frame(self, index: int, data, f: h5py.File):
        """Return a frame colored and labeled by the score at *index* (None for NaN)."""
        score = data['scores'][index]
        if np.isnan(score):
            return None
        return create_frame(
            self.size,
            mcolors.rgb2hex(data['cmap'](score)[:3]),
            f'{score:.3f}',
            data['font'],
        )
201
+
202
+
203
class PreviewLatentPCACommand(BasePreviewCommand):
    """
    Generate thumbnail with latent PCA visualization

    Usage:
        cmd = PreviewLatentPCACommand(size=64)
        image = cmd(hdf5_path='data.h5', alpha=0.5)
    """

    def _prepare(self, f: h5py.File, alpha: float = 0.5):
        """
        Prepare latent PCA visualization data

        Args:
            f: HDF5 file handle
            alpha: Transparency of overlay (0.0-1.0)

        Returns:
            dict with 'overlays', 'alpha_mask' and 'latent_size'

        Raises:
            ValueError: If the latent token count is not a perfect square.
        """
        # Load latent features; assumes layout (B, L, EMB), e.g. (B, 16x16, 1024)
        # -- TODO confirm against the writer of 'latent_features'
        h = f[f'{self.model_name}/latent_features'][()]
        h = h.astype(np.float32)
        s = h.shape

        # Estimate original latent grid side: l = sqrt(L)
        latent_size = int(np.sqrt(s[1]))
        # BUG FIX: was a bare `assert`, which is stripped under `python -O`;
        # raise an explicit, descriptive error instead.
        if latent_size ** 2 != s[1]:
            raise ValueError(f'Latent length {s[1]} is not a perfect square')
        if self.size % latent_size != 0:
            print(f'WARNING: {self.size} is not divisible by {latent_size}')

        # Project every token embedding down to 3 components (used as RGB)
        pca = PCA(n_components=3)
        latent_pca = pca.fit_transform(h.reshape(s[0] * s[1], s[-1]))  # B*L, 3

        # Normalize each component to [0, 1] so it maps cleanly onto 0-255
        scaler = MinMaxScaler()
        latent_pca = scaler.fit_transform(latent_pca)

        # Reshape to per-patch grids and convert to RGB bytes
        latent_pca = latent_pca.reshape(s[0], latent_size, latent_size, 3)
        overlays = (latent_pca * 255).astype(np.uint8)  # B, l, l, 3

        # One uniform alpha mask, reused for every overlay
        alpha_mask = Image.new('L', (self.size, self.size), int(alpha * 255))

        return {'overlays': overlays, 'alpha_mask': alpha_mask, 'latent_size': latent_size}

    def _get_frame(self, index: int, data, f: h5py.File):
        """
        Get latent PCA overlay as a frame for patch at index

        Args:
            index: Patch index
            data: Data prepared by _prepare()
            f: HDF5 file handle

        Returns:
            PIL.Image: RGBA overlay image
        """
        # Upscale the l x l grid to patch size; NEAREST keeps cell edges crisp
        overlay = Image.fromarray(data['overlays'][index]).convert('RGBA')
        overlay = overlay.resize((self.size, self.size), Image.NEAREST)

        # Apply uniform alpha so the patch remains visible underneath
        overlay.putalpha(data['alpha_mask'])

        return overlay
272
+
273
+
274
class PreviewLatentClusterCommand(BasePreviewCommand):
    """
    Generate thumbnail with latent cluster visualization

    Usage:
        cmd = PreviewLatentClusterCommand(size=64)
        image = cmd(hdf5_path='data.h5', alpha=0.5)
    """

    def _prepare(self, f: h5py.File, alpha: float = 0.5):
        """
        Prepare latent cluster visualization data

        Args:
            f: HDF5 file handle
            alpha: Transparency of overlay (0.0-1.0)

        Returns:
            dict with 'overlays', 'alpha_mask' and 'latent_size'

        Raises:
            ValueError: If the latent token count is not a perfect square.
        """
        # Load per-token cluster labels; assumes layout (B, L), e.g. (B, 16x16)
        # -- TODO confirm against the writer of 'latent_clusters'
        clusters = f[f'{self.model_name}/latent_clusters'][()]
        s = clusters.shape

        # Estimate original latent grid side: l = sqrt(L)
        latent_size = int(np.sqrt(s[1]))
        # BUG FIX: was a bare `assert`, which is stripped under `python -O`;
        # raise an explicit, descriptive error instead.
        if latent_size ** 2 != s[1]:
            raise ValueError(f'Latent length {s[1]} is not a perfect square')
        if self.size % latent_size != 0:
            print(f'WARNING: {self.size} is not divisible by {latent_size}')

        # Map labels to RGBA via the categorical colormap
        cmap = plt.get_cmap('tab20')
        latent_map = cmap(clusters)
        latent_map = latent_map.reshape(s[0], latent_size, latent_size, 4)
        overlays = (latent_map * 255).astype(np.uint8)  # B, l, l, 4

        # One uniform alpha mask, reused for every overlay
        alpha_mask = Image.new('L', (self.size, self.size), int(alpha * 255))

        return {'overlays': overlays, 'alpha_mask': alpha_mask, 'latent_size': latent_size}

    def _get_frame(self, index: int, data, f: h5py.File):
        """
        Get latent cluster overlay as a frame for patch at index

        Args:
            index: Patch index
            data: Data prepared by _prepare()
            f: HDF5 file handle

        Returns:
            PIL.Image: RGBA overlay image
        """
        # Upscale the l x l grid to patch size; NEAREST keeps cell edges crisp
        overlay = Image.fromarray(data['overlays'][index]).convert('RGBA')
        overlay = overlay.resize((self.size, self.size), Image.NEAREST)

        # Apply uniform alpha so the patch remains visible underneath
        overlay.putalpha(data['alpha_mask'])

        return overlay
@@ -0,0 +1,196 @@
1
+ """
2
+ WSI to HDF5 conversion command
3
+ """
4
+
5
+ import cv2
6
+ import h5py
7
+ import numpy as np
8
+ from pydantic import BaseModel
9
+
10
+ from ..wsi_files import create_wsi_file
11
+ from ..utils.helpers import is_white_patch
12
+ from . import _config, _progress
13
+
14
+
15
class Wsi2HDF5Result(BaseModel):
    """Result of WSI to HDF5 conversion"""
    mpp: float           # effective microns-per-pixel after downscaling (original_mpp * scale)
    original_mpp: float  # microns-per-pixel reported by the source WSI
    scale: int           # integer downscale factor applied during conversion
    patch_count: int     # number of non-white patches actually stored
    patch_size: int      # side length (px) of each stored, scaled patch
    cols: int            # number of patch columns in the grid
    rows: int            # number of patch rows in the grid
    output_path: str     # path of the written HDF5 file
25
+
26
+
27
class Wsi2HDF5Command:
    """
    Convert WSI image to HDF5 format with patch extraction

    Usage:
        # Set global config once
        commands.set_default_progress('tqdm')

        # Create and run command
        cmd = Wsi2HDF5Command(patch_size=256, engine='auto')
        result = cmd(input_path='image.ndpi', output_path='output.h5')
    """

    def __init__(self,
                 patch_size: int = 256,
                 engine: str = 'auto',
                 mpp: float = 0,
                 rotate: bool = True):
        """
        Initialize WSI to HDF5 converter

        Args:
            patch_size: Size of patches to extract
            engine: WSI reader engine ('auto', 'openslide', 'tifffile', 'standard')
            mpp: Microns per pixel (for standard images)
            rotate: Whether to rotate patches 180 degrees

        Note:
            progress and verbose are controlled by global config:
            - commands.set_default_progress('tqdm')
            - commands.set_verbose(True/False)
        """
        self.patch_size = patch_size
        self.engine = engine
        self.mpp = mpp
        self.rotate = rotate

    def __call__(self, input_path: str, output_path: str) -> Wsi2HDF5Result:
        """
        Execute WSI to HDF5 conversion

        Args:
            input_path: Path to input WSI file
            output_path: Path to output HDF5 file

        Returns:
            Wsi2HDF5Result: Metadata including mpp, scale, patch_count

        Raises:
            RuntimeError: If the source mpp is outside the supported range.
        """
        wsi = create_wsi_file(input_path, engine=self.engine, mpp=self.mpp)

        # Choose an integer downscale so the effective mpp lands in a usable
        # range. BUG FIX: an mpp of exactly 0.360 previously matched neither
        # `0.360 < mpp < 0.500` nor `mpp < 0.360` and incorrectly raised; it
        # is now treated like the sub-0.360 case (scale=2).
        original_mpp = wsi.get_mpp()
        if original_mpp >= 0.500:
            raise RuntimeError(f'Invalid mpp: {original_mpp:.6f}')
        scale = 1 if original_mpp > 0.360 else 2
        mpp = original_mpp * scale

        W, H = wsi.get_original_size()
        S = self.patch_size  # patch size in the scaled (output) image
        T = S * scale        # patch size in original pixels

        x_patch_count = W // T            # grid columns
        y_patch_count = H // T            # grid rows
        width = x_patch_count * T         # source width trimmed to whole patches

        if _config.verbose and _config.progress == 'tqdm':
            print(f'Original mpp: {original_mpp:.6f}')
            print(f'Image mpp: {mpp:.6f}')
            print(f'Target resolutions: {W} x {H}')
            print(f'Obtained resolutions: {x_patch_count*S} x {y_patch_count*S}')
            print(f'Scale: {scale}')
            print(f'Patch size: {T}')
            print(f'Scaled patch size: {S}')
            print(f'Row count: {y_patch_count}')
            print(f'Col count: {x_patch_count}')

        coordinates = []

        with h5py.File(output_path, 'w') as f:
            # Write metadata
            f.create_dataset('metadata/original_mpp', data=original_mpp)
            f.create_dataset('metadata/original_width', data=W)
            f.create_dataset('metadata/original_height', data=H)
            f.create_dataset('metadata/image_level', data=0)
            f.create_dataset('metadata/mpp', data=mpp)
            f.create_dataset('metadata/scale', data=scale)
            f.create_dataset('metadata/patch_size', data=S)
            f.create_dataset('metadata/cols', data=x_patch_count)
            f.create_dataset('metadata/rows', data=y_patch_count)

            # Pre-size for the worst case (every patch kept); shrunk afterwards
            total_patches = f.create_dataset(
                'patches',
                shape=(x_patch_count * y_patch_count, S, S, 3),
                dtype=np.uint8,
                chunks=(1, S, S, 3),
                compression='gzip',
                compression_opts=9
            )

            # Extract patches row by row
            cursor = 0
            tq = _progress(range(y_patch_count))
            for row in tq:
                # Read one full-width row strip and downscale it in one pass.
                # Assumes read_region returns an H x W x 3 array -- TODO confirm.
                image = wsi.read_region((0, row * T, width, T))
                image = cv2.resize(image, (width // scale, S),
                                   interpolation=cv2.INTER_LANCZOS4)

                # Reshape the strip into per-column patches
                patches = image.reshape(1, S, x_patch_count, S, 3)  # (y, h, x, w, 3)
                patches = patches.transpose(0, 2, 1, 3, 4)          # (y, x, h, w, 3)
                patches = patches[0]

                # Filter white patches and collect valid ones
                batch = []
                for col, patch in enumerate(patches):
                    if is_white_patch(patch):
                        continue

                    if self.rotate:
                        patch = cv2.rotate(patch, cv2.ROTATE_180)
                        # Rotated slides store mirrored grid coordinates
                        coordinates.append((
                            (x_patch_count - 1 - col) * S,
                            (y_patch_count - 1 - row) * S
                        ))
                    else:
                        coordinates.append((col * S, row * S))

                    batch.append(patch)

                # BUG FIX: skip the write for all-white rows; `np.array([])`
                # has shape (0,), which does not match the dataset selection.
                if batch:
                    arr = np.array(batch)
                    total_patches[cursor:cursor + len(arr), ...] = arr
                    cursor += len(arr)

                tq.set_description(
                    f'Selected {len(batch)}/{len(patches)} patches '
                    f'(row {row}/{y_patch_count})'
                )
                tq.refresh()

            # Shrink to the actual patch count and save coordinates
            patch_count = len(coordinates)
            f.create_dataset('coordinates', data=coordinates)
            f['patches'].resize((patch_count, S, S, 3))
            f.create_dataset('metadata/patch_count', data=patch_count)

        if _config.verbose and _config.progress == 'tqdm':
            print(f'{patch_count} patches were selected.')

        return Wsi2HDF5Result(
            mpp=mpp,
            original_mpp=original_mpp,
            scale=scale,
            patch_count=patch_count,
            patch_size=S,
            cols=x_patch_count,
            rows=y_patch_count,
            output_path=output_path
        )