wsi-toolbox 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,127 @@
1
+ import multiprocessing
2
+ import numpy as np
3
+ from sklearn.decomposition import PCA
4
+ from sklearn.neighbors import NearestNeighbors
5
+ import networkx as nx
6
+ import igraph as ig
7
+ import leidenalg as la
8
+ from joblib import Parallel, delayed
9
+
10
+ from .progress import tqdm_or_st
11
+
12
+
13
+
14
def find_optimal_components(features, threshold=0.95):
    """
    Choose a PCA component count from the cumulative explained variance.

    Args:
        features: 2-D array-like of samples (rows) x features, as accepted by
            ``PCA.fit``.
        threshold: Target cumulative explained-variance ratio (e.g. 0.95).

    Returns:
        int: The smallest number of components whose cumulative explained
        variance ratio reaches ``threshold`` (all fitted components if it is
        never reached), capped at ``len(features) - 1`` and never below 1.
    """
    pca = PCA()
    pca.fit(features)
    # Pick the first dimension count whose cumulative ratio reaches the threshold.
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
    reached = cumulative_variance >= threshold
    if reached.any():
        optimal_n = int(np.argmax(reached)) + 1
    else:
        # argmax of an all-False mask is 0, which would silently select a single
        # component; the threshold can be missed e.g. through float rounding when
        # it is close to 1.0, so keep every fitted component instead.
        optimal_n = len(cumulative_variance)
    # At least one component keeps the result a usable PCA size.
    return max(1, min(optimal_n, len(features) - 1))
22
+
23
+
24
def process_edges_batch(batch_indices, all_indices, h, use_umap_embs, pca=None):
    """
    Compute weighted kNN edges for a batch of nodes.

    Args:
        batch_indices: Node indices handled by this batch.
        all_indices: Neighbour index lists per node; row ``i`` lists the
            neighbours of node ``i`` and may contain ``i`` itself (skipped).
        h: Per-node vectors used to measure distances (``h[i]`` is node ``i``).
        use_umap_embs: If True, weight = ``exp(-distance)`` with the plain
            Euclidean distance. Otherwise each difference is scaled per
            dimension by ``sqrt(pca.explained_variance_ratio_)`` and
            weight = ``exp(-distance / mean_distance)``, where ``mean_distance``
            is the mean of node ``i``'s distances to its neighbours.
        pca: Fitted PCA object; required when ``use_umap_embs`` is False.

    Returns:
        tuple[list, list]: ``(edges, weights)`` where ``edges[k] == (i, j)`` and
        ``weights[k]`` is that edge's weight.
    """
    edges = []
    weights = []

    if not use_umap_embs:
        # Identical for every edge, so look it up once.
        variance_ratio = pca.explained_variance_ratio_

    for i in batch_indices:
        neighbors = [j for j in all_indices[i] if j != i]  # skip self loop
        if not neighbors:
            continue

        if use_umap_embs:
            node_weights = [np.exp(-np.linalg.norm(h[i] - h[j])) for j in neighbors]
        else:
            dim_scale = np.sqrt(variance_ratio[:len(h[i])])
            distances = np.array(
                [np.linalg.norm((h[i] - h[j]) * dim_scale) for j in neighbors]
            )
            # Normalise by this node's mean neighbour distance. A per-edge
            # ``distance / distance.mean()`` divides a scalar by itself, which
            # yields exp(-1) for every edge and NaN for duplicate points.
            mean_distance = distances.mean()
            if mean_distance > 0:
                node_weights = np.exp(-distances / mean_distance)
            else:
                # Every neighbour coincides with node i: maximal similarity.
                node_weights = np.ones_like(distances)

        for j, weight in zip(neighbors, node_weights):
            edges.append((i, j))
            weights.append(weight)

    return edges, weights
47
+
48
def leiden_cluster(features, umap_emb_func=None, resolution=1.0, n_jobs=-1, progress='tqdm'):
    """
    Cluster samples with the Leiden algorithm on a weighted kNN graph.

    Pipeline: optional UMAP embedding, PCA (size from find_optimal_components),
    k-nearest neighbours with k about sqrt(n_samples), parallel edge weighting
    via process_edges_batch, then leidenalg's RBConfigurationVertexPartition.

    Args:
        features: 2-D array of samples (rows) x features; ``features.shape[0]``
            is the number of samples.
        umap_emb_func: Optional zero-argument callable returning per-sample
            embeddings. When given, edge weights are computed from these
            embeddings instead of the PCA features (the kNN search still uses PCA).
        resolution: ``resolution_parameter`` for the RB configuration partition.
        n_jobs: Number of joblib workers; negative means all CPUs.
        progress: Progress backend passed to ``tqdm_or_st``.

    Returns:
        np.ndarray: Integer cluster label per sample.
    """
    if n_jobs < 0:
        n_jobs = multiprocessing.cpu_count()
    # n_jobs == 0 would make the batch-size division below fail.
    n_jobs = max(1, n_jobs)
    use_umap_embs = umap_emb_func is not None
    n_samples = features.shape[0]

    # Steps: (UMAP), PCA, KNN, edges, Leiden, finalize
    progress_count = 6 if use_umap_embs else 5
    tq = tqdm_or_st(total=progress_count, backend=progress)
    try:
        # 1. UMAP projection if needed
        if use_umap_embs:
            tq.set_description('UMAP projection...')
            umap_embeddings = umap_emb_func()
            tq.update(1)
        else:
            umap_embeddings = None

        # 2. pre-PCA
        tq.set_description('Processing PCA...')
        n_components = find_optimal_components(features)
        pca = PCA(n_components)
        target_features = pca.fit_transform(features)
        tq.update(1)

        # 3. KNN. Querying with the fitted data returns each point as its own
        # neighbour, so k=1 would produce no edges: use at least 2, but never
        # more than the number of samples.
        tq.set_description('Processing KNN...')
        k = min(n_samples, max(2, int(np.sqrt(n_samples))))
        nn = NearestNeighbors(n_neighbors=k).fit(target_features)
        _, indices = nn.kneighbors(target_features)
        tq.update(1)

        # 4. Build graph
        tq.set_description('Processing edges...')
        G = nx.Graph()
        G.add_nodes_from(range(n_samples))

        h = umap_embeddings if use_umap_embs else target_features
        batch_size = max(1, n_samples // n_jobs)
        batches = [list(range(start, min(start + batch_size, n_samples)))
                   for start in range(0, n_samples, batch_size)]
        results = Parallel(n_jobs=n_jobs)([
            delayed(process_edges_batch)(batch, indices, h, use_umap_embs, pca)
            for batch in batches
        ])

        # nx.Graph is undirected: if both (i, j) and (j, i) are produced, the
        # later add_edge overwrites the earlier weight.
        for batch_edges, batch_weights in results:
            for (i, j), weight in zip(batch_edges, batch_weights):
                G.add_edge(i, j, weight=weight)
        tq.update(1)

        # 5. Leiden clustering
        tq.set_description('Leiden clustering...')
        edges = list(G.edges())
        weights = [G[u][v]['weight'] for u, v in edges]
        ig_graph = ig.Graph(n=n_samples, edges=edges, edge_attrs={'weight': weights})

        partition = la.find_partition(
            ig_graph,
            la.RBConfigurationVertexPartition,
            weights='weight',
            resolution_parameter=resolution,
        )
        tq.update(1)

        # 6. Finalize: label every node with the index of its community
        tq.set_description('Finalize...')
        clusters = np.full(n_samples, -1)  # -1 marks any node left unassigned
        for label, community in enumerate(partition):
            clusters[community] = label
        tq.update(1)
    finally:
        # Close the progress display even if a step raises.
        tq.close()

    return clusters
127
+
@@ -0,0 +1,25 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from string import capwords
5
+ import inspect
6
+ import asyncio
7
+ from typing import Callable, Type
8
+ import argparse
9
+
10
+ from pydantic import BaseModel, Field
11
+ from pydantic_autocli import AutoCLI
12
+
13
+ from .seed import fix_global_seed, get_global_seed
14
+
15
+
16
class BaseMLArgs(BaseModel):
    """Base pydantic argument model that carries the random seed."""

    # Default taken from the seed module once, when this class is defined.
    seed: int = Field(default=get_global_seed())
18
+
19
class BaseMLCLI(AutoCLI):
    """AutoCLI subclass that fixes the global random seed before delegating."""

    class CommonArgs(BaseMLArgs):
        """Arguments shared by all commands (inherits ``seed``)."""

    def _pre_common(self, a: BaseMLArgs):
        """
        Apply ``a.seed`` via fix_global_seed, then run AutoCLI's own hook.

        NOTE(review): presumably invoked by AutoCLI before each command; confirm
        against pydantic_autocli.
        """
        fix_global_seed(a.seed)
        super()._pre_common(a)
@@ -0,0 +1,57 @@
1
+ """
2
+ Helper utility functions for WSI processing
3
+ """
4
+
5
+ import numpy as np
6
+ import h5py
7
+
8
+
9
def is_white_patch(patch, rgb_std_threshold=7.0, white_ratio=0.7):
    """
    Decide whether a patch is mostly white/blank.

    A pixel counts as "white" when the standard deviation of its channel
    values is below ``rgb_std_threshold``, i.e. the pixel is close to gray.
    Note that this also matches dark pixels, not only bright ones.

    Args:
        patch: RGB patch (H, W, 3)
        rgb_std_threshold: Per-pixel channel-std cutoff for a "white" pixel
        white_ratio: The patch is white when the fraction of such pixels
            is strictly greater than this value

    Returns:
        bool: True if patch is considered white/blank
    """
    height, width = patch.shape[:2]
    gray_mask = np.std(patch, axis=2) < rgb_std_threshold
    fraction = np.sum(gray_mask) / (height * width)
    return fraction > white_ratio
30
+
31
+
32
def cosine_distance(x, y, scale=None):
    """
    Calculate a distance between two vectors with exponential weighting.

    Despite the name, the distance is the Euclidean (L2) norm of ``x - y``,
    not a cosine distance.

    Args:
        x: First vector
        y: Second vector
        scale: Length scale for ``weight = exp(-distance / scale)``. When
            omitted, the distance itself is used (the historical behaviour),
            which makes the weight ``exp(-1)`` for every non-zero distance;
            pass e.g. a mean neighbour distance for a meaningful weight.

    Returns:
        tuple: (distance, weight). A zero scale no longer yields NaN: the
        weight is 1.0 for zero distance and 0.0 otherwise.
    """
    distance = np.linalg.norm(x - y)
    if scale is None:
        scale = np.mean(distance)
    if scale == 0:
        # exp(-d/0) is undefined; use the limits of the kernel instead.
        weight = 1.0 if distance == 0 else 0.0
    else:
        weight = np.exp(-distance / scale)
    return distance, weight
46
+
47
+
48
def safe_del(hdf_file, key_path):
    """
    Delete ``key_path`` from an HDF5 file/group, doing nothing if it is absent.

    Args:
        hdf_file: h5py.File object (anything supporting ``in`` and ``del``)
        key_path: Dataset path to delete
    """
    # Membership is checked first rather than catching KeyError, so entries
    # that ``in`` does not report are left untouched.
    if key_path not in hdf_file:
        return
    del hdf_file[key_path]
@@ -0,0 +1,206 @@
1
+ import time
2
+ from typing import Iterable, TypeVar, Optional, Union, Any
3
+
4
T = TypeVar('T')


class StreamlitProgress:
    """Streamlit progress bar exposing the same interface as tqdm."""

    def __init__(self, iterable: Optional[Iterable[T]] = None, total: Optional[int] = None,
                 desc: str = "", **kwargs):
        """
        Create the Streamlit widgets for a progress bar.

        Args:
            iterable: Optional iterable to wrap (required for iteration).
            total: Number of steps; taken from ``len(iterable)`` when omitted
                and the iterable has a length.
            desc: Initial description text.
            **kwargs: Accepted for tqdm compatibility; stored but unused.

        Raises:
            ImportError: If streamlit is not installed.
        """
        self.iterable = iterable
        if total is None and iterable is not None and hasattr(iterable, "__len__"):
            total = len(iterable)
        self.total = total
        self.desc = desc
        self.n = 0
        self.kwargs = kwargs

        # Only the import signals a missing dependency; keep the widget calls
        # outside the try so their own errors are not masked.
        try:
            import streamlit as st
        except ImportError as err:
            raise ImportError("streamlitがインストールされていません。") from err

        # Container for the description text
        self.text_container = st.empty()
        if desc:
            self.text_container.text(desc)
        # The progress bar itself
        self.progress_bar = st.progress(0)
        # Container for the postfix text
        self.postfix_container = st.empty()

    def update(self, n: int = 1) -> None:
        """Advance the progress by ``n`` steps (bar is capped at 100%)."""
        self.n += n
        if self.total:
            self.progress_bar.progress(min(self.n / self.total, 1.0))

    def set_description(self, desc: Optional[str] = None, refresh: bool = True) -> None:
        """Replace the description text; ``refresh`` exists for tqdm compatibility."""
        if desc is not None:
            self.desc = desc
            self.text_container.markdown('<p style="font-size:14px; color:gray;">' + desc +'</p>', unsafe_allow_html=True)

    def set_postfix(self, ordered_dict=None, **kwargs) -> None:
        """Show ``key=value`` pairs from ``ordered_dict`` and ``kwargs`` below the bar."""
        # Merge ordered_dict and kwargs (kwargs win on duplicate keys)
        postfix_dict = {}
        if ordered_dict:
            postfix_dict.update(ordered_dict)
        if kwargs:
            postfix_dict.update(kwargs)

        if postfix_dict:
            postfix_str = ', '.join(f'{k}={v}' for k, v in postfix_dict.items())
            self.postfix_container.text(f"状態: {postfix_str}")

    def close(self) -> None:
        """Fill the bar (when a total is known) and clear the description."""
        if self.total:
            self.progress_bar.progress(1.0)
        self.text_container.empty()

    def refresh(self):
        """No-op: Streamlit widgets update immediately."""
        pass

    def __iter__(self):
        """Yield items from the wrapped iterable, advancing the bar after each."""
        if self.iterable is None:
            raise ValueError("このプログレスバーはイテレータとして使用できません")

        for obj in self.iterable:
            yield obj
            self.update(1)

        self.close()

    def __enter__(self):
        """Support use as a context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the bar when leaving the ``with`` block."""
        self.close()
86
+
87
def tqdm_or_st(iterable: Optional[Iterable[T]] = None,
               backend: str = 'tqdm',
               **kwargs) -> Union['tqdm', StreamlitProgress]:
    """
    Return a progress bar for the requested backend.

    Falls back from tqdm to Streamlit when tqdm is missing, and to a silent
    bar (disabled tqdm, or a no-op stand-in) when Streamlit is missing too.

    Args:
        iterable: Iterable to wrap with progress reporting
        backend: Backend name, "tqdm" or "streamlit"
        **kwargs: Arguments passed to tqdm or StreamlitProgress

    Returns:
        A tqdm object, a StreamlitProgress object, or a no-op stand-in.

    Raises:
        ValueError: If ``backend`` is not a supported name.
    """
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if backend not in ('tqdm', 'streamlit'):
        raise ValueError(f"backend must be 'tqdm' or 'streamlit', got {backend!r}")

    if backend == "tqdm":
        try:
            from tqdm import tqdm
        except ImportError:
            print("tqdmが見つからないため、Streamlitバックエンドを試行します...")
        else:
            return tqdm(iterable, **kwargs)

    # Streamlit: requested directly, or reached as the tqdm fallback.
    try:
        return StreamlitProgress(iterable, **kwargs)
    except ImportError:
        print("Streamlitが見つかりません。プログレスバーなしで実行します。")

    # Fallback: a progress bar that displays nothing
    try:
        from tqdm import tqdm
    except ImportError:
        pass
    else:
        return tqdm(iterable, disable=True, **kwargs)

    # Neither library is available: a no-op object with the same methods.
    class DummyTqdm:
        def __init__(self, iterable=None, **kwargs):
            self.iterable = iterable
        def update(self, n=1): pass
        def close(self): pass
        def refresh(self): pass
        def set_description(self, desc=None, refresh=True): pass
        def set_postfix(self, ordered_dict=None, **kwargs): pass
        def __iter__(self):
            if self.iterable is None: raise ValueError("イテレータがありません")
            for x in self.iterable: yield x
        def __enter__(self): return self
        def __exit__(self, *args, **kwargs): pass

    return DummyTqdm(iterable, **kwargs)
146
+
147
+ # 基本的な使用例
148
def basic_example():
    """Iterate over ten items with a tqdm-backed progress bar, like plain tqdm."""
    for value in tqdm_or_st(list(range(10)), desc="基本的な例", backend="tqdm"):
        time.sleep(0.1)
        print(f"処理中: {value}")
156
+
157
+ # Streamlitの使用例
158
def streamlit_example():
    """Example of progress reporting in Streamlit (must run inside a Streamlit app)."""
    import streamlit as st

    st.title("処理の進捗表示")

    items = list(range(10))
    results = []

    # tqdm_or_st has no "auto" backend (it is rejected), so request
    # Streamlit explicitly.
    for item in tqdm_or_st(items, desc="処理中...", backend="streamlit"):
        time.sleep(0.2)
        results.append(item * 2)

    st.write("結果:", results)
173
+
174
+ # コンテキストマネージャとしての使用例
175
def context_manager_example():
    """Drive a progress bar manually inside a ``with`` block."""
    total_steps = 5

    with tqdm_or_st(total=total_steps, desc="手動更新", backend="tqdm") as pbar:
        for step in range(1, total_steps + 1):
            time.sleep(0.2)

            # Change the description once, on the third step
            if step == 3:
                pbar.set_description(f"ステップ {step}/{total_steps}")

            # Show the completion percentage as postfix text
            pbar.set_postfix(progress=f"{step / total_steps:.0%}")

            pbar.update(1)
193
+
194
+ # テスト用のメイン関数
195
def main():
    """Run the console examples; the Streamlit one must be run inside an app."""
    demos = (
        ("基本的な使用例:", basic_example),
        ("\nコンテキストマネージャとしての使用例:", context_manager_example),
    )
    for heading, demo in demos:
        print(heading)
        demo()

    print("\nStreamlitの例はStreamlitアプリ内で実行してください")
    # streamlit_example() can only be run from inside a Streamlit app.


if __name__ == "__main__":
    main()
@@ -0,0 +1,21 @@
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+
5
__GLOBAL_SEED = 42


def get_global_seed():
    """Return the seed last applied by fix_global_seed (42 until then)."""
    return __GLOBAL_SEED


def fix_global_seed(seed=None):
    """
    Seed Python, NumPy and PyTorch RNGs and request deterministic execution.

    Args:
        seed: Seed to apply; defaults to the current global seed. It becomes
            the new global seed returned by get_global_seed.
    """
    global __GLOBAL_SEED
    if seed is None:
        seed = __GLOBAL_SEED
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    torch.backends.cudnn.deterministic = True
    # Autotuning can choose different kernels between runs.
    torch.backends.cudnn.benchmark = False
    # This must be *called*: assigning to it would silently replace the torch
    # function with a bool. warn_only=True makes ops without a deterministic
    # implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)
    __GLOBAL_SEED = seed
@@ -0,0 +1,53 @@
1
+ import streamlit as st
2
+ from contextlib import contextmanager
3
+
4
# CSS injected by st_horizontal(). It turns any Streamlit vertical block that
# directly contains an element holding a ``.horizontal-marker`` span into a
# wrapping flex row. The ``.hide-element`` rule hides element containers that
# hold an element with that class, so the marker does not add spacing. The
# remaining rules shrink child widths and lay selectbox labels and inputs out
# side by side.
HORIZONTAL_STYLE = """
<style class="hide-element">
/* Hides the style container and removes the extra spacing */
.element-container:has(.hide-element) {
display: none;
}
/*
The selector for >.element-container is necessary to avoid selecting the whole
body of the streamlit app, which is also a stVerticalBlock.
*/
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
display: flex;
flex-direction: row !important;
flex-wrap: wrap;
gap: 0.5rem;
align-items: baseline;
}
/* Buttons and their parent container all have a width of 704px, which we need to override */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
width: max-content !important;
}
/* Selectbox container */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-testid="stSelectbox"] {
display: flex !important;
flex-direction: row !important;
align-items: center !important;
gap: 0.5rem !important;
}
/* Selectbox label */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-testid="stWidgetLabel"] {
margin-bottom: 0 !important;
padding-right: 0.5rem !important;
}
/* Selectbox input container */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-baseweb="select"] {
min-width: 120px !important;
}
/* Selectbox dropdown */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[role="listbox"] {
min-width: 120px !important;
}
</style>
"""
47
+
48
@contextmanager
def st_horizontal():
    """
    Lay out the Streamlit elements created inside the ``with`` body in a row.

    Injects HORIZONTAL_STYLE, opens a container and places a hidden marker span
    in it. The CSS targets the container holding that marker and switches it to
    a flex row.
    """
    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)
    marker_html = '<span class="hide-element horizontal-marker"></span>'
    row = st.container()
    with row:
        st.markdown(marker_html, unsafe_allow_html=True)
        yield