wsi-toolbox 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,174 @@
+ """
+ UMAP embedding command for dimensionality reduction
+ """
+
+ import h5py
+ import numpy as np
+ from pydantic import BaseModel
+
+ from ..utils.hdf5_paths import build_cluster_path, build_namespace, ensure_groups
+ from . import _get, _progress, get_config
+ from .data_loader import MultipleContext
+
+
+ class UmapResult(BaseModel):
+     """Result of UMAP embedding operation"""
+
+     n_samples: int
+     n_components: int
+     namespace: str
+     target_path: str
+     skipped: bool = False
+
+
+ class UmapCommand:
+     """
+     Compute UMAP embeddings from features
+
+     Usage:
+         # Basic UMAP
+         cmd = UmapCommand()
+         result = cmd('data.h5')  # → uni/default/umap
+
+         # Multi-file UMAP
+         cmd = UmapCommand()
+         result = cmd(['001.h5', '002.h5'])  # → uni/001+002/umap
+
+         # UMAP for filtered data
+         cmd = UmapCommand(parent_filters=[[1,2,3]])
+         result = cmd('data.h5')  # → uni/default/filter/1+2+3/umap
+     """
+
+     def __init__(
+         self,
+         namespace: str | None = None,
+         parent_filters: list[list[int]] | None = None,
+         n_components: int = 2,
+         n_neighbors: int = 15,
+         min_dist: float = 0.1,
+         metric: str = "euclidean",
+         overwrite: bool = False,
+         model_name: str | None = None,
+     ):
+         """
+         Initialize UMAP command
+
+         Args:
+             namespace: Explicit namespace (None = auto-generate from input paths)
+             parent_filters: Hierarchical filters, e.g., [[1,2,3]]
+             n_components: Number of UMAP dimensions (default: 2)
+             n_neighbors: UMAP n_neighbors parameter (default: 15)
+             min_dist: UMAP min_dist parameter (default: 0.1)
+             metric: UMAP metric (default: "euclidean")
+             overwrite: Whether to overwrite existing UMAP coordinates
+             model_name: Model name (None to use global default)
+         """
+         self.namespace = namespace
+         self.parent_filters = parent_filters or []
+         self.n_components = n_components
+         self.n_neighbors = n_neighbors
+         self.min_dist = min_dist
+         self.metric = metric
+         self.overwrite = overwrite
+         self.model_name = _get("model_name", model_name)
+
+         # Validate model
+         if self.model_name not in ["uni", "gigapath", "virchow2"]:
+             raise ValueError(f"Invalid model: {self.model_name}")
+
+         # Internal state
+         self.hdf5_paths = []
+         self.umap_embeddings = None
+
+     def __call__(self, hdf5_paths: str | list[str]) -> UmapResult:
+         """Execute UMAP embedding"""
+         import umap  # noqa: PLC0415 - lazy load, umap is slow to import
+
+         # Normalize to list
+         if isinstance(hdf5_paths, str):
+             hdf5_paths = [hdf5_paths]
+         self.hdf5_paths = hdf5_paths
+
+         # Determine namespace
+         if self.namespace is None:
+             self.namespace = build_namespace(hdf5_paths)
+         elif "+" in self.namespace:
+             raise ValueError("Namespace cannot contain '+' (reserved for multi-file auto-generated namespaces)")
+
+         # Build target path
+         target_path = build_cluster_path(self.model_name, self.namespace, filters=self.parent_filters, dataset="umap")
+
+         # Check if already exists
+         if not self.overwrite:
+             with h5py.File(hdf5_paths[0], "r") as f:
+                 if target_path in f:
+                     umap_coords = f[target_path][:]
+                     n_samples = np.sum(~np.isnan(umap_coords[:, 0]))
+                     if get_config().verbose:
+                         print(f"UMAP already exists at {target_path}")
+                     return UmapResult(
+                         n_samples=n_samples,
+                         n_components=self.n_components,
+                         namespace=self.namespace,
+                         target_path=target_path,
+                         skipped=True,
+                     )
+
+         # Execute with progress tracking
+         with _progress(total=3, desc="UMAP") as pbar:
+             # Load features
+             pbar.set_description("Loading features")
+             ctx = MultipleContext(hdf5_paths, self.model_name, self.namespace, self.parent_filters)
+             features = ctx.load_features(source="features")
+             pbar.update(1)
+
+             # Compute UMAP
+             pbar.set_description("Computing UMAP")
+             reducer = umap.UMAP(
+                 n_components=self.n_components,
+                 n_neighbors=self.n_neighbors,
+                 min_dist=self.min_dist,
+                 metric=self.metric,
+             )
+             self.umap_embeddings = reducer.fit_transform(features)
+             pbar.update(1)
+
+             # Write results
+             pbar.set_description("Writing results")
+             self._write_results(ctx, target_path)
+             pbar.update(1)
+
+         # Verbose output after progress bar closes
+         if get_config().verbose:
+             print(f"Computing UMAP: {len(features)} samples → {self.n_components}D")
+             print(f"Wrote {target_path} to {len(hdf5_paths)} file(s)")
+
+         return UmapResult(
+             n_samples=len(features), n_components=self.n_components, namespace=self.namespace, target_path=target_path
+         )
+
+     def _write_results(self, ctx: MultipleContext, target_path: str):
+         """Write UMAP coordinates to HDF5 files"""
+         for file_slice in ctx:
+             umap_coords = file_slice.slice(self.umap_embeddings)
+
+             with h5py.File(file_slice.hdf5_path, "a") as f:
+                 ensure_groups(f, target_path)
+
+                 if target_path in f:
+                     del f[target_path]
+
+                 # Fill with NaN for filtered patches
+                 full_umap = np.full((len(file_slice.mask), self.n_components), np.nan, dtype=umap_coords.dtype)
+                 full_umap[file_slice.mask] = umap_coords
+
+                 ds = f.create_dataset(target_path, data=full_umap)
+                 ds.attrs["n_components"] = self.n_components
+                 ds.attrs["n_neighbors"] = self.n_neighbors
+                 ds.attrs["min_dist"] = self.min_dist
+                 ds.attrs["metric"] = self.metric
+                 ds.attrs["model"] = self.model_name
+
+     def get_embeddings(self):
+         """Get computed UMAP embeddings"""
+         return self.umap_embeddings
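In practice UmapCommand is driven through the package's global config and writes its coordinates back into the same HDF5 file under target_path. A minimal sketch, assuming the file already contains patch features for the default model and that UmapCommand is importable from the package's command module (the exact module path is not shown for this hunk); 'data.h5' is a placeholder path:

    cmd = UmapCommand(n_components=2, n_neighbors=15)
    result = cmd("data.h5")                      # e.g. target_path = "uni/default/umap"
    print(result.n_samples, result.skipped)

    import h5py
    with h5py.File("data.h5", "r") as f:
        coords = f[result.target_path][:]        # (n_patches, n_components); NaN rows are filtered patches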
@@ -0,0 +1,223 @@
+ """
+ WSI to HDF5 conversion command
+ """
+
+ import os
+ from typing import Callable
+
+ import cv2
+ import h5py
+ import numpy as np
+ from pydantic import BaseModel
+
+ from ..utils.white import create_white_detector
+ from ..wsi_files import create_wsi_file
+ from . import _progress, get_config
+
+
+ class Wsi2HDF5Result(BaseModel):
+     """Result of WSI to HDF5 conversion"""
+
+     mpp: float
+     original_mpp: float
+     scale: int
+     patch_count: int
+     patch_size: int
+     cols: int
+     rows: int
+     output_path: str
+
+
+ class Wsi2HDF5Command:
+     """
+     Convert WSI image to HDF5 format with patch extraction
+
+     Usage:
+         # Set global config once
+         commands.set_default_progress('tqdm')
+
+         # Create and run command
+         cmd = Wsi2HDF5Command(patch_size=256, engine='auto')
+         result = cmd(input_path='image.ndpi', output_path='output.h5')
+     """
+
+     def __init__(
+         self,
+         patch_size: int = 256,
+         engine: str = "auto",
+         mpp: float = 0.5,
+         rotate: bool = False,
+         white_detector: Callable[[np.ndarray], bool] | None = None,
+     ):
+         """
+         Initialize WSI to HDF5 converter
+
+         Args:
+             patch_size: Size of patches to extract
+             engine: WSI reader engine ('auto', 'openslide', 'tifffile', 'standard')
+             mpp: Microns per pixel (for standard images)
+             rotate: Whether to rotate patches 180 degrees
+             white_detector: Function that takes (H, W, 3) array and returns bool.
+                 If None, uses legacy is_white_patch with default params.
+
+         Note:
+             progress and verbose are controlled by global config:
+             - commands.set_default_progress('tqdm')
+             - commands.set_verbose(True/False)
+         """
+         self.patch_size = patch_size
+         self.engine = engine
+         self.mpp = mpp
+         self.rotate = rotate
+
+         # Set white detection function
+         if white_detector is None:
+             # Default: use ptp method with default threshold
+             self._is_white_patch = create_white_detector("ptp")
+         else:
+             self._is_white_patch = white_detector
+
+     def __call__(self, input_path: str, output_path: str) -> Wsi2HDF5Result:
+         """
+         Execute WSI to HDF5 conversion
+
+         Args:
+             input_path: Path to input WSI file
+             output_path: Path to output HDF5 file
+
+         Returns:
+             Wsi2HDF5Result: Metadata including mpp, scale, patch_count
+         """
+         # Create WSI reader
+         wsi = create_wsi_file(input_path, engine=self.engine, mpp=self.mpp)
+
+         # Calculate scale based on mpp
+         original_mpp = wsi.get_mpp()
+
+         if 0.360 < original_mpp < 0.660:
+             # mpp ≃ 0.5 mpp
+             scale = 1
+         elif original_mpp < 0.360:
+             scale = 2
+         else:
+             raise RuntimeError(f"Invalid mpp: {original_mpp:.6f}")
+
+         mpp = original_mpp * scale
+
+         # Get image dimensions
+         W, H = wsi.get_original_size()
+         S = self.patch_size  # Scaled patch size
+         T = S * scale  # Original patch size
+
+         x_patch_count = W // T
+         y_patch_count = H // T
+         width = (W // T) * T
+         row_count = H // T
+         coordinates = []
+
+         try:
+             # Create HDF5 file
+             with h5py.File(output_path, "w") as f:
+                 # Write metadata (both as datasets and attrs for migration)
+                 f.create_dataset("metadata/original_mpp", data=original_mpp)
+                 f.create_dataset("metadata/original_width", data=W)
+                 f.create_dataset("metadata/original_height", data=H)
+                 f.create_dataset("metadata/image_level", data=0)
+                 f.create_dataset("metadata/mpp", data=mpp)
+                 f.create_dataset("metadata/scale", data=scale)
+                 f.create_dataset("metadata/patch_size", data=S)
+                 f.create_dataset("metadata/cols", data=x_patch_count)
+                 f.create_dataset("metadata/rows", data=y_patch_count)
+
+                 # Also save as attrs for future migration
+                 f.attrs["original_mpp"] = original_mpp
+                 f.attrs["original_width"] = W
+                 f.attrs["original_height"] = H
+                 f.attrs["image_level"] = 0
+                 f.attrs["mpp"] = mpp
+                 f.attrs["scale"] = scale
+                 f.attrs["patch_size"] = S
+                 f.attrs["cols"] = x_patch_count
+                 f.attrs["rows"] = y_patch_count
+
+                 # Create patches dataset
+                 total_patches = f.create_dataset(
+                     "patches",
+                     shape=(x_patch_count * y_patch_count, S, S, 3),
+                     dtype=np.uint8,
+                     chunks=(1, S, S, 3),
+                     compression="gzip",
+                     compression_opts=9,
+                 )
+
+                 # Extract patches row by row
+                 cursor = 0
+                 tq = _progress(range(row_count))
+                 for row in tq:
+                     # Read one row
+                     image = wsi.read_region((0, row * T, width, T))
+                     image = cv2.resize(image, (width // scale, S), interpolation=cv2.INTER_LANCZOS4)
+
+                     # Reshape into patches
+                     patches = image.reshape(1, S, x_patch_count, S, 3)  # (y, h, x, w, 3)
+                     patches = patches.transpose(0, 2, 1, 3, 4)  # (y, x, h, w, 3)
+                     patches = patches[0]
+
+                     # Filter white patches and collect valid ones
+                     batch = []
+                     for col, patch in enumerate(patches):
+                         if self._is_white_patch(patch):
+                             continue
+
+                         if self.rotate:
+                             patch = cv2.rotate(patch, cv2.ROTATE_180)
+                             coordinates.append(((x_patch_count - 1 - col) * S, (y_patch_count - 1 - row) * S))
+                         else:
+                             coordinates.append((col * S, row * S))
+
+                         batch.append(patch)
+
+                     # Write batch
+                     batch = np.array(batch)
+                     total_patches[cursor : cursor + len(batch), ...] = batch
+                     cursor += len(batch)
+
+                     tq.set_description(f"Selected {len(batch)}/{len(patches)} patches (row {row}/{y_patch_count})")
+                     tq.refresh()
+
+                 # Resize to actual patch count and save coordinates
+                 patch_count = len(coordinates)
+                 f.create_dataset("coordinates", data=coordinates)
+                 f["patches"].resize((patch_count, S, S, 3))
+                 f.create_dataset("metadata/patch_count", data=patch_count)
+                 f.attrs["patch_count"] = patch_count
+
+         except BaseException:
+             # Clean up incomplete file on error (including Ctrl-C)
+             if os.path.exists(output_path):
+                 os.remove(output_path)
+             raise
+
+         # Verbose output after progress bar closes
+         if get_config().verbose:
+             print(f"Original mpp: {original_mpp:.6f}")
+             print(f"Image mpp: {mpp:.6f}")
+             print(f"Target resolutions: {W} x {H}")
+             print(f"Obtained resolutions: {x_patch_count * S} x {y_patch_count * S}")
+             print(f"Scale: {scale}")
+             print(f"Patch size: {T}")
+             print(f"Scaled patch size: {S}")
+             print(f"Row count: {y_patch_count}")
+             print(f"Col count: {x_patch_count}")
+             print(f"{patch_count} patches were selected.")
+
+         return Wsi2HDF5Result(
+             mpp=mpp,
+             original_mpp=original_mpp,
+             scale=scale,
+             patch_count=patch_count,
+             patch_size=S,
+             cols=x_patch_count,
+             rows=y_patch_count,
+             output_path=output_path,
+         )
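The mpp/scale logic above normalizes slides toward roughly 0.5 µm/pixel: a slide scanned at original_mpp = 0.25 gets scale = 2, so each stored 256 px patch is read as a 512 px region and downscaled, while slides between 0.360 and 0.660 µm/pixel are used at scale 1, and anything coarser raises a RuntimeError. A minimal end-to-end sketch following the class docstring ('image.ndpi' and 'output.h5' are placeholder paths; the `commands` namespace is assumed here to re-export the global-config helpers defined in wsi_toolbox/common.py):

    commands.set_default_progress("tqdm")
    commands.set_verbose(True)

    cmd = Wsi2HDF5Command(patch_size=256, engine="auto")
    result = cmd(input_path="image.ndpi", output_path="output.h5")
    print(f"{result.patch_count} patches, grid {result.cols} x {result.rows}, mpp {result.mpp:.3f}")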
wsi_toolbox/common.py ADDED
@@ -0,0 +1,148 @@
+ """
+ Global configuration and settings for WSI-toolbox
+ """
+
+ from functools import partial
+ from typing import Callable
+
+ from matplotlib import pyplot as plt
+ from pydantic import BaseModel, Field
+
+ from .models import MODEL_NAMES, create_foundation_model
+ from .utils.progress import Progress
+
+
+ # === Global Configuration (Pydantic) ===
+ class Config(BaseModel):
+     """Global configuration for commands"""
+
+     progress: str = Field(default="tqdm", description="Progress bar backend")
+     model_name: str = Field(default="uni", description="Default model name")
+     model_generator: Callable | None = Field(default=None, description="Model generator function")
+     verbose: bool = Field(default=True, description="Verbose output")
+     device: str = Field(default="cuda", description="Device for computation")
+     cluster_cmap: str = Field(default="tab20", description="Cluster colormap name")
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ # Global config instance
+ _config = Config()
+
+
+ def get_config() -> Config:
+     """Get global configuration instance"""
+     return _config
+
+
+ def set_default_progress(backend: str):
+     """Set global default progress backend ('tqdm', 'streamlit', etc.)"""
+     _config.progress = backend
+
+
+ def set_default_model(name: str, generator: Callable, label: str | None = None):
+     """Set custom model generator as default
+
+     Args:
+         name: Model name (used for file paths, etc.)
+         generator: Callable that returns a model instance (e.g., lambda: MyModel())
+         label: Display label (defaults to name if not provided)
+
+     Example:
+         >>> set_default_model('resnet', lambda: torchvision.models.resnet50())
+         >>> set_default_model('custom', create_my_model, label='My Custom Model')
+     """
+     _config.model_name = name
+     _config.model_generator = generator
+
+
+ def set_default_model_preset(preset_name: str):
+     """Set default model from preset ('uni', 'gigapath', 'virchow2')
+
+     Args:
+         preset_name: One of 'uni', 'gigapath', 'virchow2'
+     """
+     if preset_name not in MODEL_NAMES:
+         raise ValueError(f"Invalid preset: {preset_name}. Must be one of {MODEL_NAMES}")
+
+     _config.model_name = preset_name
+     _config.model_generator = partial(create_foundation_model, preset_name)
+
+
+ def create_default_model():
+     """Create a new model instance using the registered generator.
+
+     Returns:
+         torch.nn.Module: Fresh model instance
+
+     Raises:
+         RuntimeError: If no model generator is registered
+
+     Example:
+         >>> set_default_model_preset('uni')
+         >>> model = create_default_model()  # Creates new UNI model instance
+     """
+     if _config.model_generator is None:
+         raise RuntimeError(
+             "No model generator registered. Call set_default_model() or set_default_model_preset() first."
+         )
+     return _config.model_generator()
+
+
+ def set_default_device(device: str):
+     """Set global default device ('cuda', 'cpu')"""
+     _config.device = device
+
+
+ def set_verbose(verbose: bool):
+     """Set global verbosity"""
+     _config.verbose = verbose
+
+
+ def set_default_cluster_cmap(cmap_name: str):
+     """Set global cluster colormap ('tab20', 'tab10', 'Set1', etc.)"""
+     _config.cluster_cmap = cmap_name
+
+
+ def _get_cluster_color(cluster_id: int):
+     """
+     Get color for cluster ID using global colormap
+
+     Args:
+         cluster_id: Cluster ID
+
+     Returns:
+         Color in matplotlib format (array or string)
+     """
+
+     cmap = plt.get_cmap(_config.cluster_cmap)
+     return cmap(cluster_id % 20)  # Modulo to handle colormaps with limited colors
+
+
+ def _get(key: str, value):
+     """Get value or fall back to global default"""
+     if value is not None:
+         return value
+     return getattr(_config, key)
+
+
+ def _progress(iterable=None, total=None, desc="", **kwargs):
+     """Create a progress bar using global config backend"""
+     return Progress(iterable=iterable, backend=_config.progress, total=total, desc=desc, **kwargs)
+
+
+ __all__ = [
+     "Config",
+     "get_config",
+     "set_default_progress",
+     "set_default_model",
+     "set_default_model_preset",
+     "create_default_model",
+     "set_default_device",
+     "set_verbose",
+     "set_default_cluster_cmap",
+     "_get_cluster_color",
+     "_get",
+     "_progress",
+ ]
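Commands read their defaults from this single module-level Config instance instead of taking progress/verbose/device flags per call, which is why the command classes above call _get(), _progress(), and get_config(). A minimal configuration sketch (the import path wsi_toolbox.common follows the filename above; whether these helpers are also re-exported elsewhere is not shown in this diff):

    from wsi_toolbox import common

    common.set_default_model_preset("uni")       # registers partial(create_foundation_model, "uni")
    common.set_default_device("cuda")
    common.set_default_progress("streamlit")
    common.set_verbose(False)

    model = common.create_default_model()        # fresh instance from the registered generator
    rgba = common._get_cluster_color(3)          # color 3 of the default 'tab20' colormap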
wsi_toolbox/models.py ADDED
@@ -0,0 +1,30 @@
+ MODEL_NAMES = ["uni", "gigapath", "virchow2"]
+
+
+ def create_foundation_model(model_name: str):
+     """
+     Create a foundation model instance by preset name.
+
+     Args:
+         model_name: One of 'uni', 'gigapath', 'virchow2'
+
+     Returns:
+         torch.nn.Module: Model instance (not moved to device, not in eval mode)
+     """
+     # Lazy import: timm/torch are slow to load (~2s), defer until model creation
+     import timm  # noqa: PLC0415
+     import torch  # noqa: PLC0415
+     from timm.layers import SwiGLUPacked  # noqa: PLC0415
+
+     if model_name == "uni":
+         return timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, dynamic_img_size=True, init_values=1e-5)
+
+     if model_name == "gigapath":
+         return timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True, dynamic_img_size=True)
+
+     if model_name == "virchow2":
+         return timm.create_model(
+             "hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU
+         )
+
+     raise ValueError(f"Invalid model_name: {model_name}. Must be one of {MODEL_NAMES}")
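A sketch of how a preset model might be used for feature extraction. This is an illustration, not code from the package: the weights are pulled from the Hugging Face Hub via timm and may require accepting the model's access terms, and the caller handles eval mode and device placement since create_foundation_model does neither. The input size and output shape below are assumptions that depend on each model's timm hub config.

    import torch
    from wsi_toolbox.models import create_foundation_model

    model = create_foundation_model("uni")            # downloads hf-hub:MahmoodLab/uni via timm
    model = model.eval().to("cuda")
    with torch.inference_mode():
        x = torch.zeros(1, 3, 224, 224, device="cuda")  # dummy patch batch
        feats = model(x)                                 # feature tensor; shape depends on the model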
@@ -0,0 +1,109 @@
+ import sys
+
+ import numpy as np
+ from matplotlib import colors as mcolors
+ from matplotlib import pyplot as plt
+ from matplotlib.offsetbox import AnnotationBbox, OffsetImage
+ from PIL import Image, ImageDraw
+ from PIL.Image import Image as ImageType
+
+
+ def yes_no_prompt(question):
+     print(f"{question} [Y/n]: ", end="")
+     response = input().lower()
+     return response == "" or response.startswith("y")
+
+
+ def get_platform_font():
+     if sys.platform == "win32":
+         # Windows
+ font_path = "C:\\Windows\\Fonts\\msgothic.ttc" # MSゴシック
+     elif sys.platform == "darwin":
+         # macOS
+         font_path = "/System/Library/Fonts/Supplemental/Arial.ttf"
+     else:
+         # Linux
+         # font_path = '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'  # TODO: propagation
+         font_path = "/usr/share/fonts/TTF/DejaVuSans.ttf"
+     return font_path
+
+
+ def create_frame(size, color, text, font):
+     frame = Image.new("RGBA", (size, size), (0, 0, 0, 0))
+     draw = ImageDraw.Draw(frame)
+     draw.rectangle((0, 0, size, size), outline=color, width=4)
+     text_color = "white" if mcolors.rgb_to_hsv(mcolors.hex2color(color))[2] < 0.9 else "black"
+     bbox = np.array(draw.textbbox((0, 0), text, font=font))
+     draw.rectangle((4, 4, bbox[2] + 4, bbox[3] + 4), fill=color)
+     draw.text((1, 1), text, font=font, fill=text_color)
+     return frame
+
+
+ def safe_del(hdf_file, key_path):
+     """
+     Safely delete a dataset from HDF5 file if it exists
+
+     Args:
+         hdf_file: h5py.File object
+         key_path: Dataset path to delete
+     """
+     if key_path in hdf_file:
+         del hdf_file[key_path]
+
+
+ def hover_images_on_scatters(scatters, imagess, ax=None, offset=(150, 30)):
+     if ax is None:
+         ax = plt.gca()
+     fig = ax.figure
+
+     def as_image(image_or_path):
+         if isinstance(image_or_path, np.ndarray):
+             return image_or_path
+         if isinstance(image_or_path, ImageType):
+             return image_or_path
+         if isinstance(image_or_path, str):
+             return Image.open(image_or_path)
+         raise RuntimeError("Invalid param", image_or_path)
+
+     imagebox = OffsetImage(as_image(imagess[0][0]), zoom=0.5)
+     imagebox.image.axes = ax
+     annot = AnnotationBbox(
+         imagebox,
+         xy=(0, 0),
+         # xybox=(256, 256),
+         # xycoords='data',
+         boxcoords="offset points",
+         # boxcoords=('axes fraction', 'data'),
+         pad=0.1,
+         arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"),
+         zorder=100,
+     )
+     annot.set_visible(False)
+     ax.add_artist(annot)
+
+     def hover(event):
+         vis = annot.get_visible()
+         if event.inaxes != ax:
+             return
+         for n, (sc, ii) in enumerate(zip(scatters, imagess)):
+             cont, index = sc.contains(event)
+             if cont:
+                 i = index["ind"][0]
+                 pos = sc.get_offsets()[i]
+                 annot.xy = pos
+                 annot.xybox = pos + np.array(offset)
+                 image = as_image(ii[i])
+                 # text = unique_code[n]
+                 # annot.set_text(text)
+                 # annot.get_bbox_patch().set_facecolor(cmap(int(text)/10))
+                 imagebox.set_data(image)
+                 annot.set_visible(True)
+                 fig.canvas.draw_idle()
+                 return
+
+         if vis:
+             annot.set_visible(False)
+             fig.canvas.draw_idle()
+             return
+
+     fig.canvas.mpl_connect("motion_notify_event", hover)
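hover_images_on_scatters takes parallel lists: scatters[i] is a matplotlib scatter collection and imagess[i] holds one thumbnail per point of that collection, so hovering a point swaps its image into the shared AnnotationBbox. A minimal, self-contained sketch with random placeholder data:

    import numpy as np
    from matplotlib import pyplot as plt

    xy = np.random.rand(50, 2)
    thumbs = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(50)]

    fig, ax = plt.subplots()
    sc = ax.scatter(xy[:, 0], xy[:, 1])
    hover_images_on_scatters([sc], [thumbs], ax=ax)   # one image list per scatter, same point order
    plt.show()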