wsi-toolbox 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wsi_toolbox/__init__.py +119 -0
- wsi_toolbox/app.py +753 -0
- wsi_toolbox/cli.py +485 -0
- wsi_toolbox/commands/__init__.py +92 -0
- wsi_toolbox/commands/clustering.py +214 -0
- wsi_toolbox/commands/dzi_export.py +202 -0
- wsi_toolbox/commands/patch_embedding.py +199 -0
- wsi_toolbox/commands/preview.py +335 -0
- wsi_toolbox/commands/wsi.py +196 -0
- wsi_toolbox/exp.py +466 -0
- wsi_toolbox/models.py +38 -0
- wsi_toolbox/utils/__init__.py +153 -0
- wsi_toolbox/utils/analysis.py +127 -0
- wsi_toolbox/utils/cli.py +25 -0
- wsi_toolbox/utils/helpers.py +57 -0
- wsi_toolbox/utils/progress.py +206 -0
- wsi_toolbox/utils/seed.py +21 -0
- wsi_toolbox/utils/st.py +53 -0
- wsi_toolbox/watcher.py +261 -0
- wsi_toolbox/wsi_files.py +187 -0
- wsi_toolbox-0.1.0.dist-info/METADATA +269 -0
- wsi_toolbox-0.1.0.dist-info/RECORD +25 -0
- wsi_toolbox-0.1.0.dist-info/WHEEL +4 -0
- wsi_toolbox-0.1.0.dist-info/entry_points.txt +2 -0
- wsi_toolbox-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import multiprocessing
|
|
2
|
+
import numpy as np
|
|
3
|
+
from sklearn.decomposition import PCA
|
|
4
|
+
from sklearn.neighbors import NearestNeighbors
|
|
5
|
+
import networkx as nx
|
|
6
|
+
import igraph as ig
|
|
7
|
+
import leidenalg as la
|
|
8
|
+
from joblib import Parallel, delayed
|
|
9
|
+
|
|
10
|
+
from .progress import tqdm_or_st
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def find_optimal_components(features, threshold=0.95):
    """Return the number of PCA components needed to reach ``threshold``.

    A full PCA is fitted on ``features`` and the smallest number of leading
    components whose cumulative explained-variance ratio is at least
    ``threshold`` is selected.

    Args:
        features: 2-D array-like of shape (n_samples, n_features).
        threshold: Target cumulative explained-variance ratio (0.95 keeps
            95% of the variance).

    Returns:
        int: Number of components, capped at ``len(features) - 1`` and never
        less than 1.
    """
    pca = PCA()
    pca.fit(features)
    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)

    reached = cumulative_variance >= threshold
    if reached.any():
        optimal_n = int(np.argmax(reached)) + 1
    else:
        # Threshold never met (e.g. threshold >= 1.0 with floating-point
        # round-off keeping the cumulative sum just below it): keep every
        # component instead of np.argmax's misleading answer of index 0.
        optimal_n = len(cumulative_variance)

    # Keep the original cap below the sample count, but never return 0:
    # a zero-component projection would leave nothing to cluster.
    return max(1, min(optimal_n, len(features) - 1))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def process_edges_batch(batch_indices, all_indices, h, use_umap_embs, pca=None,
                        distance_scale=None):
    """Build weighted kNN edges for a batch of nodes.

    For every node ``i`` in ``batch_indices``, an edge ``(i, j)`` is emitted
    for each neighbour ``j`` in ``all_indices[i]`` except ``i`` itself.

    Weights:
      * ``use_umap_embs=True``: ``exp(-d)``, with ``d`` the Euclidean distance
        between rows of ``h``.
      * otherwise: each coordinate difference is scaled by
        ``sqrt(pca.explained_variance_ratio_)`` before taking the norm. The
        distance is then normalised by ``distance_scale`` (default: the mean
        distance over this batch's edges), giving ``exp(-d / scale)``.

    Args:
        batch_indices: Node indices to process.
        all_indices: Per-node neighbour index lists (e.g. the index array
            returned by ``NearestNeighbors.kneighbors``).
        h: 2-D array of node coordinates, one row per node.
        use_umap_embs: Selects the weighting scheme above.
        pca: Fitted PCA object; required when ``use_umap_embs`` is False.
        distance_scale: Optional positive normaliser for the PCA branch.

    Returns:
        tuple: (edges, weights), where edges is a list of ``(i, j)`` pairs
        and weights is a parallel list of float weights.

    Raises:
        ValueError: If ``use_umap_embs`` is False and ``pca`` is None.
    """
    h = np.asarray(h)
    if not use_umap_embs:
        if pca is None:
            raise ValueError("pca is required when use_umap_embs is False")
        # Per-dimension weights are identical for every edge; compute once.
        dim_weights = np.sqrt(np.asarray(pca.explained_variance_ratio_)[:h.shape[1]])

    edges = []
    distances = []
    for i in batch_indices:
        for j in all_indices[i]:
            if i == j:  # skip self loop
                continue
            diff = h[i] - h[j]
            if not use_umap_embs:
                diff = diff * dim_weights
            edges.append((i, j))
            distances.append(np.linalg.norm(diff))

    if not edges:
        return edges, []

    distances = np.asarray(distances, dtype=float)
    if use_umap_embs:
        weights = np.exp(-distances)
    else:
        # Dividing each distance by itself would give a constant exp(-1)
        # (and NaN for duplicate points), so normalise by one shared scale.
        scale = distances.mean() if distance_scale is None else distance_scale
        if scale > 0:
            weights = np.exp(-distances / scale)
        else:
            # All distances are zero: identical points get the maximum weight.
            weights = np.ones_like(distances)
    return edges, weights.tolist()
|
|
47
|
+
|
|
48
|
+
def leiden_cluster(features, umap_emb_func=None, resolution=1.0, n_jobs=-1,
                   progress='tqdm', seed=None):
    """Cluster feature vectors with the Leiden algorithm on a kNN graph.

    Pipeline: (optional UMAP) -> PCA -> k-nearest-neighbour search ->
    weighted undirected graph -> Leiden (RBConfiguration) partition.

    Args:
        features: Array of shape (n_samples, n_features).
        umap_emb_func: Optional zero-argument callable returning an embedding
            with one row per sample. When given, edge weights are computed
            from that embedding. The kNN graph itself is always built on the
            PCA features.
        resolution: Leiden resolution parameter; larger values give more,
            smaller clusters.
        n_jobs: Number of joblib workers used for edge construction. Values
            below 1 mean "use all CPUs".
        progress: Progress backend passed to ``tqdm_or_st``.
        seed: Optional random seed forwarded to ``leidenalg.find_partition``
            for reproducible partitions. None keeps the unseeded behaviour.

    Returns:
        np.ndarray: Integer cluster label for each sample.
    """
    if n_jobs < 1:
        # joblib-style "-1 = all cores"; 0 would otherwise divide by zero
        # when sizing the batches below.
        n_jobs = multiprocessing.cpu_count()
    use_umap_embs = umap_emb_func is not None
    n_samples = features.shape[0]

    # Steps: (UMAP), PCA, KNN, edges, Leiden, finalize
    progress_count = 6 if use_umap_embs else 5
    tq = tqdm_or_st(total=progress_count, backend=progress)
    try:
        # 1. UMAP embedding, if requested
        umap_embeddings = None
        if use_umap_embs:
            tq.set_description('UMAP projection...')
            umap_embeddings = umap_emb_func()
            tq.update(1)

        # 2. PCA pre-reduction
        tq.set_description('Processing PCA...')
        n_components = find_optimal_components(features)
        pca = PCA(n_components)
        target_features = pca.fit_transform(features)
        tq.update(1)

        # 3. kNN on the PCA features with k ~ sqrt(n). A point's own index
        # may appear in its neighbour list; process_edges_batch skips it.
        tq.set_description('Processing KNN...')
        k = int(np.sqrt(len(target_features)))
        nn = NearestNeighbors(n_neighbors=k).fit(target_features)
        _, indices = nn.kneighbors(target_features)
        tq.update(1)

        # 4. Edge weights, computed in parallel batches of contiguous nodes
        tq.set_description('Processing edges...')
        h = umap_embeddings if use_umap_embs else target_features
        batch_size = max(1, n_samples // n_jobs)
        batches = [list(range(i, min(i + batch_size, n_samples)))
                   for i in range(0, n_samples, batch_size)]
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_edges_batch)(batch, indices, h, use_umap_embs, pca)
            for batch in batches
        )

        # nx.Graph is undirected, so the (i, j) and (j, i) entries from the
        # kNN lists collapse into one edge (the last weight written wins).
        G = nx.Graph()
        G.add_nodes_from(range(n_samples))
        for batch_edges, batch_weights in results:
            for (i, j), weight in zip(batch_edges, batch_weights):
                G.add_edge(i, j, weight=weight)
        tq.update(1)

        # 5. Leiden clustering on an igraph copy of the graph
        tq.set_description('Leiden clustering...')
        edges = list(G.edges())
        weights = [G[u][v]['weight'] for u, v in edges]
        ig_graph = ig.Graph(n=n_samples, edges=edges, edge_attrs={'weight': weights})
        partition = la.find_partition(
            ig_graph,
            la.RBConfigurationVertexPartition,
            weights='weight',
            resolution_parameter=resolution,  # 1.0 adaptive; 0.5 coarser
            seed=seed,
        )
        tq.update(1)

        # 6. Write each community's index onto its member nodes
        tq.set_description('Finalize...')
        clusters = np.full(n_samples, -1)
        for label, community in enumerate(partition):
            clusters[community] = label
        tq.update(1)
    finally:
        # Release the progress display even if a step raises.
        tq.close()

    return clusters
|
|
127
|
+
|
wsi_toolbox/utils/cli.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import re
|
|
4
|
+
from string import capwords
|
|
5
|
+
import inspect
|
|
6
|
+
import asyncio
|
|
7
|
+
from typing import Callable, Type
|
|
8
|
+
import argparse
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
from pydantic_autocli import AutoCLI
|
|
12
|
+
|
|
13
|
+
from .seed import fix_global_seed, get_global_seed
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Shared argument model for ML commands: exposes a ``seed`` option whose
# default is read from get_global_seed() once, when this class is defined.
class BaseMLArgs(BaseModel):
    seed: int = Field(default=get_global_seed())
|
|
18
|
+
|
|
19
|
+
# AutoCLI base class whose common arguments include ``seed`` (via BaseMLArgs)
# and which applies that seed in its _pre_common hook.
class BaseMLCLI(AutoCLI):
    class CommonArgs(BaseMLArgs):
        # No shared arguments beyond those declared on BaseMLArgs.
        pass

    def _pre_common(self, a: BaseMLArgs):
        # NOTE(review): presumably invoked by AutoCLI before a command runs -
        # confirm. Seed the RNGs first, then defer to AutoCLI's own hook.
        requested_seed = a.seed
        fix_global_seed(requested_seed)
        super()._pre_common(a)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helper utility functions for WSI processing
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import h5py
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def is_white_patch(patch, rgb_std_threshold=7.0, white_ratio=0.7):
    """
    Check whether a patch is mostly blank background.

    A pixel counts as "white" when the standard deviation across its RGB
    channels is below ``rgb_std_threshold``, i.e. R, G and B are nearly
    equal (a neutral gray). The test ignores brightness, so uniform dark
    pixels count as well.

    Args:
        patch: Image array of shape (H, W, C). Only the first three channels
            (RGB) are used, so RGBA patches are accepted; an alpha channel
            would otherwise inflate the per-pixel std.
        rgb_std_threshold: Per-pixel RGB std below which a pixel is "white".
        white_ratio: Fraction of "white" pixels above which the patch is
            considered blank.

    Returns:
        bool: True if the patch is considered white/blank. An empty patch
        (zero pixels) returns False.
    """
    patch = np.asarray(patch)
    if patch.shape[0] * patch.shape[1] == 0:
        # Avoid a 0/0 ratio; an empty patch is not treated as blank.
        return False
    low_std_pixels = np.std(patch[..., :3], axis=2) < rgb_std_threshold
    # The mean of a boolean mask is the fraction of True pixels.
    return bool(np.mean(low_std_pixels) > white_ratio)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def cosine_distance(x, y, scale=None):
    """
    Compute the distance between ``x`` and ``y`` and an exponentially
    decaying similarity weight.

    Despite the name, the distance is the Euclidean (L2) norm of ``x - y``,
    not a cosine distance; the name is kept for backward compatibility.

    Args:
        x: First vector (array-like).
        y: Second vector (array-like, same shape as ``x``).
        scale: Optional positive length scale for the weight. When omitted,
            the mean of ``distance`` is used. For a single pair that is the
            distance itself, so any distinct pair gets weight exp(-1).

    Returns:
        tuple: (distance, weight) with ``weight = exp(-distance / scale)``.
        Identical inputs give weight 1.0 instead of the NaN from 0/0.

    Raises:
        ValueError: If ``scale`` is given and is not positive.
    """
    # Convert to float first: subtracting unsigned-integer arrays (e.g. uint8
    # image data) would otherwise wrap around instead of going negative.
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    distance = np.linalg.norm(diff)
    if scale is None:
        scale = np.mean(distance)
        if scale == 0:
            # x == y: use a unit scale so the weight is exp(0) = 1, not NaN.
            scale = 1.0
    elif scale <= 0:
        raise ValueError("scale must be positive")
    weight = np.exp(-distance / scale)
    return distance, weight
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def safe_del(hdf_file, key_path):
    """
    Remove ``key_path`` from ``hdf_file``, doing nothing if it is absent.

    Args:
        hdf_file: h5py.File object (any container supporting ``in`` and
            ``del`` works).
        key_path: Path of the entry to remove.
    """
    if key_path not in hdf_file:
        return
    del hdf_file[key_path]
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import Iterable, TypeVar, Optional, Union, Any
|
|
3
|
+
|
|
4
|
+
# Element type used in the ``iterable`` annotations of the progress helpers.
T = TypeVar('T')
|
|
5
|
+
|
|
6
|
+
class StreamlitProgress:
    """Streamlit progress bar with the same interface as tqdm.

    Renders three Streamlit elements: a description line, a progress bar and
    a postfix line. Supports ``update`` / ``set_description`` /
    ``set_postfix`` / ``close``, iteration over a wrapped iterable, and use as
    a context manager.
    """

    def __init__(self, iterable: Optional[Iterable[Any]] = None, total: Optional[int] = None,
                 desc: str = "", **kwargs):
        self.iterable = iterable
        # Like tqdm, fall back to len(iterable) when no explicit total is given.
        if total is None and iterable is not None and hasattr(iterable, "__len__"):
            total = len(iterable)
        self.total = total
        self.desc = desc
        self.n = 0
        self.kwargs = kwargs  # accepted for tqdm compatibility; not read here

        try:
            import streamlit as st
        except ImportError as e:
            raise ImportError("streamlitがインストールされていません。") from e

        # Container for the description text
        self.text_container = st.empty()
        if desc:
            self.text_container.text(desc)
        # The progress bar itself
        self.progress_bar = st.progress(0)
        # Container for the postfix text
        self.postfix_container = st.empty()

    def update(self, n: int = 1) -> None:
        """Advance the counter by ``n`` and redraw the bar (capped at 100%)."""
        self.n += n
        if self.total:
            self.progress_bar.progress(min(self.n / self.total, 1.0))

    def set_description(self, desc: Optional[str] = None, refresh: bool = True) -> None:
        """Replace the description text; ``refresh`` exists only for tqdm compatibility."""
        import html

        if desc is not None:
            self.desc = desc
            # The text is rendered as raw HTML, so escape it: descriptions may
            # contain file names or other text with '<', '>' or '&'.
            self.text_container.markdown(
                '<p style="font-size:14px; color:gray;">' + html.escape(desc) + '</p>',
                unsafe_allow_html=True,
            )

    def set_postfix(self, ordered_dict=None, **kwargs) -> None:
        """Show ``key=value`` pairs from ``ordered_dict`` and ``kwargs`` after the bar."""
        # Merge ordered_dict and kwargs (kwargs win on duplicate keys)
        postfix_dict = {}
        if ordered_dict:
            postfix_dict.update(ordered_dict)
        if kwargs:
            postfix_dict.update(kwargs)

        if postfix_dict:
            postfix_str = ', '.join(f'{k}={v}' for k, v in postfix_dict.items())
            self.postfix_container.text(f"状態: {postfix_str}")

    def close(self) -> None:
        """Fill the bar (when a total is known) and clear the description."""
        if self.total:
            self.progress_bar.progress(1.0)
        self.text_container.empty()

    def refresh(self):
        """No-op: Streamlit elements redraw themselves on every update."""
        pass

    def __iter__(self):
        """Yield items from the wrapped iterable, advancing the bar after each one."""
        if self.iterable is None:
            raise ValueError("このプログレスバーはイテレータとして使用できません")

        for obj in self.iterable:
            yield obj
            self.update(1)

        self.close()

    def __enter__(self):
        """Support ``with StreamlitProgress(...) as bar:``."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the bar when the ``with`` block exits."""
        self.close()
|
|
86
|
+
|
|
87
|
+
def tqdm_or_st(iterable: Optional[Iterable[T]] = None,
               backend: str = 'tqdm',
               **kwargs) -> Union['tqdm', StreamlitProgress]:
    """
    Return a progress bar for the requested backend.

    Args:
        iterable: Optional iterable to wrap.
        backend: "tqdm", "streamlit", or "auto". "auto" picks Streamlit when
            a Streamlit runtime is active and tqdm otherwise.
        **kwargs: Forwarded to tqdm / StreamlitProgress (e.g. ``total``, ``desc``).

    Returns:
        A tqdm object or a StreamlitProgress. If the chosen library is
        missing, this falls back to a disabled tqdm, or to a no-op stand-in
        when tqdm is missing too.

    Raises:
        ValueError: If ``backend`` is not one of the values above.
    """
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if backend not in ('tqdm', 'streamlit', 'auto'):
        raise ValueError(f"Unknown progress backend: {backend!r}")

    if backend == 'auto':
        try:
            from streamlit import runtime
            backend = 'streamlit' if runtime.exists() else 'tqdm'
        except (ImportError, AttributeError):
            backend = 'tqdm'

    if backend == 'tqdm':
        try:
            from tqdm import tqdm
        except ImportError:
            print("tqdmが見つからないため、Streamlitバックエンドを試行します...")
            backend = 'streamlit'
        else:
            return tqdm(iterable, **kwargs)

    # At this point backend == 'streamlit'.
    try:
        return StreamlitProgress(iterable, **kwargs)
    except ImportError:
        print("Streamlitが見つかりません。プログレスバーなしで実行します。")

    # Fallback: a disabled tqdm, or a do-nothing stand-in if tqdm is absent too.
    try:
        from tqdm import tqdm
        return tqdm(iterable, disable=True, **kwargs)
    except ImportError:
        class DummyTqdm:
            """No-op progress bar exposing the methods used by callers."""

            def __init__(self, iterable=None, **kwargs):
                self.iterable = iterable

            def update(self, n=1):
                pass

            def close(self):
                pass

            def refresh(self):
                pass

            def set_description(self, desc=None, refresh=True):
                pass

            def set_postfix(self, ordered_dict=None, **kwargs):
                pass

            def __iter__(self):
                if self.iterable is None:
                    raise ValueError("イテレータがありません")
                for x in self.iterable:
                    yield x

            def __enter__(self):
                return self

            def __exit__(self, *args, **kwargs):
                pass

        return DummyTqdm(iterable, **kwargs)
|
|
146
|
+
|
|
147
|
+
# Basic usage example
def basic_example():
    """Iterate over ten items with a tqdm-backed bar, just like plain tqdm."""
    for value in tqdm_or_st(list(range(10)), desc="基本的な例", backend="tqdm"):
        time.sleep(0.1)
        print(f"処理中: {value}")
|
|
156
|
+
|
|
157
|
+
# Streamlit usage example
def streamlit_example():
    """Show a progress bar inside a Streamlit app (must be run under Streamlit)."""
    import streamlit as st

    st.title("処理の進捗表示")

    items = list(range(10))
    results = []

    # This example only makes sense inside Streamlit, so request that backend
    # explicitly rather than relying on auto-detection.
    for item in tqdm_or_st(items, desc="処理中...", backend="streamlit"):
        time.sleep(0.2)
        results.append(item * 2)

    st.write("結果:", results)
|
|
173
|
+
|
|
174
|
+
# Context-manager usage example
def context_manager_example():
    """Drive a progress bar manually inside a ``with`` block."""
    total_steps = 5
    with tqdm_or_st(total=total_steps, desc="手動更新", backend="tqdm") as bar:
        for step in range(1, total_steps + 1):
            time.sleep(0.2)
            if step == 3:
                # Change the description part-way through
                bar.set_description(f"ステップ {step}/{total_steps}")
            # Extra status shown after the bar
            bar.set_postfix(progress=f"{step / total_steps:.0%}")
            bar.update(1)
|
|
193
|
+
|
|
194
|
+
# Entry point for trying the examples by hand
def main():
    """Run the console examples; the Streamlit one needs a Streamlit app."""
    demos = (
        ("基本的な使用例:", basic_example),
        ("\nコンテキストマネージャとしての使用例:", context_manager_example),
    )
    for heading, demo in demos:
        print(heading)
        demo()

    print("\nStreamlitの例はStreamlitアプリ内で実行してください")
    # streamlit_example() can only run inside a Streamlit app


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
# Module-wide seed value: returned by get_global_seed() and overwritten by
# fix_global_seed().
__GLOBAL_SEED = 42
|
|
6
|
+
|
|
7
|
+
def get_global_seed():
    """Return the module-level seed (the last value stored by fix_global_seed)."""
    current = __GLOBAL_SEED
    return current
|
|
9
|
+
|
|
10
|
+
def fix_global_seed(seed=None):
    """Seed the Python, NumPy and PyTorch RNGs and record the seed globally.

    Args:
        seed: Seed to apply. ``None`` re-applies the current value from
            ``get_global_seed()``.
    """
    global __GLOBAL_SEED
    if seed is None:
        seed = get_global_seed()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This must be a call, not an assignment: assigning would replace torch's
    # function with a bool. warn_only=True reports ops that lack a
    # deterministic implementation instead of raising RuntimeError.
    torch.use_deterministic_algorithms(True, warn_only=True)
    __GLOBAL_SEED = seed
|
wsi_toolbox/utils/st.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import streamlit as st
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
|
|
4
|
+
HORIZONTAL_STYLE = """
|
|
5
|
+
<style class="hide-element">
|
|
6
|
+
/* Hides the style container and removes the extra spacing */
|
|
7
|
+
.element-container:has(.hide-element) {
|
|
8
|
+
display: none;
|
|
9
|
+
}
|
|
10
|
+
/*
|
|
11
|
+
The selector for >.element-container is necessary to avoid selecting the whole
|
|
12
|
+
body of the streamlit app, which is also a stVerticalBlock.
|
|
13
|
+
*/
|
|
14
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
|
|
15
|
+
display: flex;
|
|
16
|
+
flex-direction: row !important;
|
|
17
|
+
flex-wrap: wrap;
|
|
18
|
+
gap: 0.5rem;
|
|
19
|
+
align-items: baseline;
|
|
20
|
+
}
|
|
21
|
+
/* Buttons and their parent container all have a width of 704px, which we need to override */
|
|
22
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
|
|
23
|
+
width: max-content !important;
|
|
24
|
+
}
|
|
25
|
+
/* Selectbox container */
|
|
26
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-testid="stSelectbox"] {
|
|
27
|
+
display: flex !important;
|
|
28
|
+
flex-direction: row !important;
|
|
29
|
+
align-items: center !important;
|
|
30
|
+
gap: 0.5rem !important;
|
|
31
|
+
}
|
|
32
|
+
/* Selectbox label */
|
|
33
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-testid="stWidgetLabel"] {
|
|
34
|
+
margin-bottom: 0 !important;
|
|
35
|
+
padding-right: 0.5rem !important;
|
|
36
|
+
}
|
|
37
|
+
/* Selectbox input container */
|
|
38
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[data-baseweb="select"] {
|
|
39
|
+
min-width: 120px !important;
|
|
40
|
+
}
|
|
41
|
+
/* Selectbox dropdown */
|
|
42
|
+
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div[role="listbox"] {
|
|
43
|
+
min-width: 120px !important;
|
|
44
|
+
}
|
|
45
|
+
</style>
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
@contextmanager
def st_horizontal():
    """Lay out the Streamlit elements created inside the ``with`` block in a row.

    Injects HORIZONTAL_STYLE, opens a container and places a hidden
    ``horizontal-marker`` span in it; the CSS uses that marker to turn the
    container into a flex row.
    """
    marker_html = '<span class="hide-element horizontal-marker"></span>'
    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)
    row = st.container()
    with row:
        st.markdown(marker_html, unsafe_allow_html=True)
        yield
|