linkml-store 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linkml_store/__init__.py +7 -0
- linkml_store/api/__init__.py +8 -0
- linkml_store/api/client.py +414 -0
- linkml_store/api/collection.py +1280 -0
- linkml_store/api/config.py +187 -0
- linkml_store/api/database.py +862 -0
- linkml_store/api/queries.py +69 -0
- linkml_store/api/stores/__init__.py +0 -0
- linkml_store/api/stores/chromadb/__init__.py +7 -0
- linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
- linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
- linkml_store/api/stores/dremio/__init__.py +10 -0
- linkml_store/api/stores/dremio/dremio_collection.py +555 -0
- linkml_store/api/stores/dremio/dremio_database.py +1052 -0
- linkml_store/api/stores/dremio/mappings.py +105 -0
- linkml_store/api/stores/dremio_rest/__init__.py +11 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
- linkml_store/api/stores/duckdb/__init__.py +16 -0
- linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
- linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
- linkml_store/api/stores/duckdb/mappings.py +8 -0
- linkml_store/api/stores/filesystem/__init__.py +15 -0
- linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
- linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
- linkml_store/api/stores/hdf5/__init__.py +7 -0
- linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
- linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
- linkml_store/api/stores/ibis/__init__.py +5 -0
- linkml_store/api/stores/ibis/ibis_collection.py +488 -0
- linkml_store/api/stores/ibis/ibis_database.py +328 -0
- linkml_store/api/stores/mongodb/__init__.py +25 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
- linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
- linkml_store/api/stores/neo4j/__init__.py +0 -0
- linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
- linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
- linkml_store/api/stores/solr/__init__.py +3 -0
- linkml_store/api/stores/solr/solr_collection.py +224 -0
- linkml_store/api/stores/solr/solr_database.py +83 -0
- linkml_store/api/stores/solr/solr_utils.py +0 -0
- linkml_store/api/types.py +4 -0
- linkml_store/cli.py +1147 -0
- linkml_store/constants.py +7 -0
- linkml_store/graphs/__init__.py +0 -0
- linkml_store/graphs/graph_map.py +24 -0
- linkml_store/index/__init__.py +53 -0
- linkml_store/index/implementations/__init__.py +0 -0
- linkml_store/index/implementations/llm_indexer.py +174 -0
- linkml_store/index/implementations/simple_indexer.py +43 -0
- linkml_store/index/indexer.py +211 -0
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/evaluation.py +195 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/llm_inference_engine.py +154 -0
- linkml_store/inference/implementations/rag_inference_engine.py +276 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
- linkml_store/inference/inference_config.py +66 -0
- linkml_store/inference/inference_engine.py +209 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/plotting/__init__.py +5 -0
- linkml_store/plotting/cli.py +826 -0
- linkml_store/plotting/dimensionality_reduction.py +453 -0
- linkml_store/plotting/embedding_plot.py +489 -0
- linkml_store/plotting/facet_chart.py +73 -0
- linkml_store/plotting/heatmap.py +383 -0
- linkml_store/utils/__init__.py +0 -0
- linkml_store/utils/change_utils.py +17 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/embedding_matcher.py +424 -0
- linkml_store/utils/embedding_utils.py +299 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/file_utils.py +37 -0
- linkml_store/utils/format_utils.py +550 -0
- linkml_store/utils/io.py +38 -0
- linkml_store/utils/llm_utils.py +122 -0
- linkml_store/utils/mongodb_utils.py +145 -0
- linkml_store/utils/neo4j_utils.py +42 -0
- linkml_store/utils/object_utils.py +190 -0
- linkml_store/utils/pandas_utils.py +93 -0
- linkml_store/utils/patch_utils.py +126 -0
- linkml_store/utils/query_utils.py +89 -0
- linkml_store/utils/schema_utils.py +23 -0
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/sql_utils.py +177 -0
- linkml_store/utils/stats_utils.py +53 -0
- linkml_store/utils/vector_utils.py +158 -0
- linkml_store/webapi/__init__.py +0 -0
- linkml_store/webapi/html/__init__.py +3 -0
- linkml_store/webapi/html/base.html.j2 +24 -0
- linkml_store/webapi/html/collection_details.html.j2 +15 -0
- linkml_store/webapi/html/database_details.html.j2 +16 -0
- linkml_store/webapi/html/databases.html.j2 +14 -0
- linkml_store/webapi/html/generic.html.j2 +43 -0
- linkml_store/webapi/main.py +855 -0
- linkml_store-0.3.0.dist-info/METADATA +226 -0
- linkml_store-0.3.0.dist-info/RECORD +101 -0
- linkml_store-0.3.0.dist-info/WHEEL +4 -0
- linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
- linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""Plotting utilities for embedding visualizations."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Dict, List, Optional, Union, Literal, Tuple
|
|
5
|
+
import numpy as np
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
import plotly.graph_objects as go
|
|
8
|
+
import plotly.express as px
|
|
9
|
+
from plotly.subplots import make_subplots
|
|
10
|
+
|
|
11
|
+
from linkml_store.utils.embedding_utils import EmbeddingData
|
|
12
|
+
from linkml_store.plotting.dimensionality_reduction import ReductionResult
|
|
13
|
+
|
|
14
|
+
# Module-level logger; the plotting functions use it to report where output files were saved.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _default_marker_symbols() -> List[str]:
    """Return a fresh list of Plotly marker symbols, cycled through for shape categories."""
    return [
        "circle", "square", "diamond", "cross", "x",
        "triangle-up", "triangle-down", "pentagon", "hexagon", "star",
    ]


@dataclass
class EmbeddingPlotConfig:
    """
    Configuration for embedding plots.

    Settings are grouped as: which fields drive the visual channels (color,
    marker shape, marker size, hover text), overall figure styling, color
    schemes, shape mapping, and display toggles.
    """

    # -- Visual encoding: field names mapped onto visual channels --
    color_field: Optional[str] = None
    # "collection" selects the collection name of each object rather than a metadata field.
    shape_field: Optional[str] = "collection"
    # When unset, every marker uses ``point_size``.
    size_field: Optional[str] = None
    # Extra metadata fields listed in each point's hover label.
    hover_fields: List[str] = field(default_factory=list)

    # -- Figure styling --
    title: str = "Embedding Visualization"
    width: int = 800
    height: int = 600
    point_size: int = 8
    opacity: float = 0.7

    # -- Color schemes --
    # Explicit category -> color assignments for categorical color values.
    color_discrete_map: Optional[Dict] = None
    # Plotly colorscale name used for numeric color values.
    color_continuous_scale: str = "Viridis"

    # -- Shape mapping --
    # Explicit category -> marker symbol; when unset, ``marker_symbols`` is cycled.
    shape_map: Optional[Dict] = None
    marker_symbols: List[str] = field(default_factory=_default_marker_symbols)

    # -- Display toggles --
    show_legend: bool = True
    show_axes: bool = True
    dark_mode: bool = False
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def plot_embeddings(
    embedding_data: EmbeddingData,
    reduction_result: ReductionResult,
    config: Optional[EmbeddingPlotConfig] = None,
    output_file: Optional[str] = None
) -> go.Figure:
    """
    Build an interactive 2-D scatter plot of reduced embeddings.

    The pipeline has three steps: assemble per-point columns, draw the scatter
    trace(s), then apply title/axis/theme styling.

    Args:
        embedding_data: Embedding data with metadata (ids, collection names,
            per-object metadata).
        reduction_result: Dimensionality reduction results; columns 0 and 1 of
            its coordinates become the x/y positions.
        config: Plot configuration; a default ``EmbeddingPlotConfig()`` is used
            when omitted.
        output_file: If given (and non-empty), the figure is also written to
            this path as HTML.

    Returns:
        The Plotly figure object.
    """
    settings = config if config is not None else EmbeddingPlotConfig()

    columns = _prepare_plot_data(embedding_data, reduction_result, settings)
    scatter = _create_scatter_plot(columns, settings)
    figure = _style_figure(scatter, settings, reduction_result)

    if not output_file:
        return figure

    figure.write_html(output_file)
    logger.info(f"Saved plot to {output_file}")
    return figure
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _prepare_plot_data(
|
|
90
|
+
embedding_data: EmbeddingData,
|
|
91
|
+
reduction_result: ReductionResult,
|
|
92
|
+
config: EmbeddingPlotConfig
|
|
93
|
+
) -> Dict:
|
|
94
|
+
"""Prepare data for plotting."""
|
|
95
|
+
data = {
|
|
96
|
+
"x": reduction_result.coordinates[:, 0],
|
|
97
|
+
"y": reduction_result.coordinates[:, 1],
|
|
98
|
+
"ids": embedding_data.object_ids,
|
|
99
|
+
"collection": embedding_data.collection_names,
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
# Add metadata fields
|
|
103
|
+
for field in config.hover_fields:
|
|
104
|
+
values = embedding_data.get_metadata_values(field)
|
|
105
|
+
if values:
|
|
106
|
+
data[field] = values
|
|
107
|
+
|
|
108
|
+
# Add color field
|
|
109
|
+
if config.color_field:
|
|
110
|
+
if config.color_field == "collection":
|
|
111
|
+
data["color"] = embedding_data.collection_names
|
|
112
|
+
else:
|
|
113
|
+
data["color"] = embedding_data.get_metadata_values(config.color_field)
|
|
114
|
+
|
|
115
|
+
# Add shape field
|
|
116
|
+
if config.shape_field:
|
|
117
|
+
if config.shape_field == "collection":
|
|
118
|
+
data["shape"] = embedding_data.collection_names
|
|
119
|
+
else:
|
|
120
|
+
data["shape"] = embedding_data.get_metadata_values(config.shape_field)
|
|
121
|
+
|
|
122
|
+
# Add size field
|
|
123
|
+
if config.size_field:
|
|
124
|
+
data["size"] = embedding_data.get_metadata_values(config.size_field)
|
|
125
|
+
else:
|
|
126
|
+
data["size"] = [config.point_size] * embedding_data.n_samples
|
|
127
|
+
|
|
128
|
+
return data
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _is_continuous(values) -> bool:
    """Return True if ``values`` holds at least one number and nothing but numbers/None.

    Booleans count as categories even though ``bool`` subclasses ``int``.
    """
    import numbers

    present = [v for v in values if v is not None]
    return bool(present) and all(
        isinstance(v, numbers.Real) and not isinstance(v, bool) for v in present
    )


def _ordered_groups(values) -> Dict:
    """Map each distinct value to the list of indices holding it, in first-seen order."""
    groups: Dict = {}
    for index, value in enumerate(values):
        groups.setdefault(value, []).append(index)
    return groups


def _create_scatter_plot(
    plot_data: Dict,
    config: EmbeddingPlotConfig
) -> go.Figure:
    """
    Create the scatter plot.

    Points are drawn as one trace per category in two cases:
    - a shape field is configured: one trace per shape value, each with its
      own marker symbol;
    - the color values are categorical: one trace per color value.
    Otherwise all points go into a single trace.

    Args:
        plot_data: Column dict as built by ``_prepare_plot_data``. It holds "x"
            and "y" plus optional "color", "shape", "size" and hover-field
            columns.
        config: Plot configuration.

    Returns:
        Figure holding the scatter trace(s).
    """
    fig = go.Figure()

    color_values = plot_data.get("color")
    # Decide continuous vs categorical once, over all points, so every trace
    # is treated the same way.
    continuous = color_values is not None and _is_continuous(color_values)
    color_min = color_max = None
    if continuous:
        numeric = [v for v in color_values if v is not None]
        color_min, color_max = min(numeric), max(numeric)

    if "shape" in plot_data and config.shape_field:
        group_key = "shape"
    elif color_values is not None and not continuous:
        # Categorical values (e.g. collection names) are not valid CSS colors,
        # and Plotly raises if they are passed as marker.color. Draw one trace
        # per category instead, so each gets its own color and legend entry.
        group_key = "color"
    else:
        group_key = None

    if group_key is None:
        # Single trace: either a continuous color scale or one uniform color.
        if continuous:
            color_kwargs = dict(
                color=color_values,
                colorscale=config.color_continuous_scale,
                showscale=True,
            )
        else:
            color_kwargs = dict(color="blue", colorscale=None, showscale=False)
        fig.add_trace(go.Scatter(
            x=plot_data["x"],
            y=plot_data["y"],
            mode="markers",
            text=_create_hover_text(plot_data, config),
            hovertemplate="%{text}<extra></extra>",
            marker=dict(
                size=plot_data.get("size", config.point_size),
                opacity=config.opacity,
                line=dict(width=0.5, color="white"),
                **color_kwargs,
            ),
        ))
        return fig

    # Grouping through a dict keeps first-seen order. Iterating a set of
    # strings would make symbol assignment and legend order change between
    # runs, because string hashing is randomized per process.
    groups = _ordered_groups(plot_data[group_key])
    if group_key == "shape":
        symbol_map = config.shape_map or {
            value: config.marker_symbols[i % len(config.marker_symbols)]
            for i, value in enumerate(groups)
        }
    else:
        symbol_map = {}
    discrete_colors = config.color_discrete_map or {}

    for group_value, indices in groups.items():
        trace_data = {
            key: [column[i] for i in indices] for key, column in plot_data.items()
        }
        marker = dict(
            symbol=symbol_map.get(group_value, "circle"),
            size=trace_data.get("size", config.point_size),
            opacity=config.opacity,
            line=dict(width=0.5, color="white"),
        )
        if continuous:
            # Share one cmin/cmax across all traces. Otherwise each trace maps
            # colors to its own range, and equal colors mean different values.
            marker.update(
                color=trace_data["color"],
                colorscale=config.color_continuous_scale,
                cmin=color_min,
                cmax=color_max,
            )
        elif group_key == "color" and group_value in discrete_colors:
            marker["color"] = discrete_colors[group_value]
        # In the remaining cases (no color, or categorical colors with traces
        # split by shape), Plotly colors each trace automatically.

        fig.add_trace(go.Scatter(
            x=trace_data["x"],
            y=trace_data["y"],
            mode="markers",
            name=str(group_value),
            text=_create_hover_text(trace_data, config),
            hovertemplate="%{text}<extra></extra>",
            marker=marker,
        ))

    return fig
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _create_hover_text(
|
|
253
|
+
data: Dict,
|
|
254
|
+
config: EmbeddingPlotConfig
|
|
255
|
+
) -> List[str]:
|
|
256
|
+
"""Create hover text for each point."""
|
|
257
|
+
hover_texts = []
|
|
258
|
+
n_points = len(data["x"])
|
|
259
|
+
|
|
260
|
+
for i in range(n_points):
|
|
261
|
+
lines = []
|
|
262
|
+
|
|
263
|
+
# Add ID
|
|
264
|
+
if "ids" in data:
|
|
265
|
+
lines.append(f"<b>ID:</b> {data['ids'][i]}")
|
|
266
|
+
|
|
267
|
+
# Add collection
|
|
268
|
+
if "collection" in data:
|
|
269
|
+
lines.append(f"<b>Collection:</b> {data['collection'][i]}")
|
|
270
|
+
|
|
271
|
+
# Add hover fields
|
|
272
|
+
for field in config.hover_fields:
|
|
273
|
+
if field in data and data[field][i] is not None:
|
|
274
|
+
lines.append(f"<b>{field}:</b> {data[field][i]}")
|
|
275
|
+
|
|
276
|
+
# Add coordinates
|
|
277
|
+
lines.append(f"<b>Coordinates:</b> ({data['x'][i]:.3f}, {data['y'][i]:.3f})")
|
|
278
|
+
|
|
279
|
+
hover_texts.append("<br>".join(lines))
|
|
280
|
+
|
|
281
|
+
return hover_texts
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _style_figure(
|
|
285
|
+
fig: go.Figure,
|
|
286
|
+
config: EmbeddingPlotConfig,
|
|
287
|
+
reduction_result: ReductionResult
|
|
288
|
+
) -> go.Figure:
|
|
289
|
+
"""Apply styling to the figure."""
|
|
290
|
+
# Create subtitle with method info
|
|
291
|
+
subtitle = f"Method: {reduction_result.method.upper()}"
|
|
292
|
+
if reduction_result.explained_variance:
|
|
293
|
+
subtitle += f" | Explained variance: {reduction_result.explained_variance:.2%}"
|
|
294
|
+
|
|
295
|
+
full_title = f"{config.title}<br><sub>{subtitle}</sub>"
|
|
296
|
+
|
|
297
|
+
# Update layout
|
|
298
|
+
layout_updates = dict(
|
|
299
|
+
title=full_title,
|
|
300
|
+
width=config.width,
|
|
301
|
+
height=config.height,
|
|
302
|
+
showlegend=config.show_legend,
|
|
303
|
+
hovermode="closest",
|
|
304
|
+
xaxis=dict(
|
|
305
|
+
title="Component 1",
|
|
306
|
+
showgrid=True,
|
|
307
|
+
gridwidth=1,
|
|
308
|
+
gridcolor="LightGray" if not config.dark_mode else "DarkGray",
|
|
309
|
+
showline=config.show_axes,
|
|
310
|
+
zeroline=True,
|
|
311
|
+
),
|
|
312
|
+
yaxis=dict(
|
|
313
|
+
title="Component 2",
|
|
314
|
+
showgrid=True,
|
|
315
|
+
gridwidth=1,
|
|
316
|
+
gridcolor="LightGray" if not config.dark_mode else "DarkGray",
|
|
317
|
+
showline=config.show_axes,
|
|
318
|
+
zeroline=True,
|
|
319
|
+
)
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
if config.dark_mode:
|
|
323
|
+
layout_updates.update(
|
|
324
|
+
template="plotly_dark",
|
|
325
|
+
paper_bgcolor="black",
|
|
326
|
+
plot_bgcolor="black",
|
|
327
|
+
)
|
|
328
|
+
else:
|
|
329
|
+
layout_updates.update(
|
|
330
|
+
template="plotly_white",
|
|
331
|
+
paper_bgcolor="white",
|
|
332
|
+
plot_bgcolor="white",
|
|
333
|
+
)
|
|
334
|
+
|
|
335
|
+
fig.update_layout(**layout_updates)
|
|
336
|
+
|
|
337
|
+
return fig
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def plot_embeddings_comparison(
    embedding_datasets: Dict[str, Tuple[EmbeddingData, ReductionResult]],
    config: Optional[EmbeddingPlotConfig] = None,
    output_file: Optional[str] = None
) -> go.Figure:
    """
    Create comparison plot of multiple embedding datasets.

    Each dataset gets its own subplot. Subplots are laid out in a grid of at
    most three columns, in the dictionary's iteration order.

    Args:
        embedding_datasets: Dictionary of (name -> (embedding_data, reduction_result));
            must not be empty.
        config: Plot configuration.
        output_file: Optional path to save HTML file.

    Returns:
        Plotly figure with subplots.

    Raises:
        ValueError: If ``embedding_datasets`` is empty.
    """
    import numbers

    n_datasets = len(embedding_datasets)
    if n_datasets == 0:
        # Guard the grid arithmetic below, which would otherwise divide by zero.
        raise ValueError("embedding_datasets must contain at least one dataset")

    if config is None:
        config = EmbeddingPlotConfig()

    n_cols = min(n_datasets, 3)
    n_rows = (n_datasets + n_cols - 1) // n_cols

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=list(embedding_datasets.keys()),
        horizontal_spacing=0.1,
        vertical_spacing=0.15
    )

    # Categorical color values (e.g. collection names) are not valid CSS
    # colors, so Plotly rejects them as marker.color. Instead, map each
    # category to a palette color. All subplots share the mapping, so a
    # category keeps the same color in every panel. Explicit
    # color_discrete_map entries take precedence.
    palette = px.colors.qualitative.Plotly
    category_colors = dict(config.color_discrete_map or {})
    n_assigned = 0

    for idx, (name, (emb_data, red_result)) in enumerate(embedding_datasets.items()):
        row = idx // n_cols + 1
        col = idx % n_cols + 1

        plot_data = _prepare_plot_data(emb_data, red_result, config)
        hover_text = _create_hover_text(plot_data, config)

        marker = dict(
            size=plot_data.get("size", config.point_size),
            opacity=config.opacity,
        )
        color_values = plot_data.get("color")
        present = [] if color_values is None else [v for v in color_values if v is not None]
        if color_values is None:
            marker["color"] = "blue"
        elif present and all(
            isinstance(v, numbers.Real) and not isinstance(v, bool) for v in present
        ):
            marker["color"] = color_values
            marker["colorscale"] = config.color_continuous_scale
        else:
            for value in color_values:
                if value not in category_colors:
                    category_colors[value] = palette[n_assigned % len(palette)]
                    n_assigned += 1
            marker["color"] = [category_colors[v] for v in color_values]

        trace = go.Scatter(
            x=plot_data["x"],
            y=plot_data["y"],
            mode="markers",
            name=name,
            text=hover_text,
            hovertemplate="%{text}<extra></extra>",
            marker=marker,
            showlegend=(idx == 0)  # Only show legend for first subplot
        )
        fig.add_trace(trace, row=row, col=col)

        fig.update_xaxes(title_text="Component 1", row=row, col=col)
        fig.update_yaxes(title_text="Component 2", row=row, col=col)

    # Each panel is half the configured single-plot width/height.
    fig.update_layout(
        title=config.title,
        width=config.width * n_cols // 2,
        height=config.height * n_rows // 2,
        showlegend=config.show_legend,
        hovermode="closest"
    )

    if output_file:
        fig.write_html(output_file)
        logger.info(f"Saved comparison plot to {output_file}")

    return fig
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def plot_embedding_clusters(
    embedding_data: EmbeddingData,
    reduction_result: ReductionResult,
    cluster_labels: np.ndarray,
    config: Optional[EmbeddingPlotConfig] = None,
    output_file: Optional[str] = None
) -> go.Figure:
    """
    Plot embeddings with cluster assignments.

    Points are colored by cluster, each labelled "Cluster <label>". When there
    are fewer than 50 distinct clusters, each cluster's centroid (the mean of
    its points' reduced coordinates) is also marked with a labelled star.

    Side effects: a "cluster" entry is written into each dict of
    ``embedding_data.metadata`` in place. The ``config`` argument is not
    modified; a copy configured for cluster coloring is used instead.

    Args:
        embedding_data: Embedding data
        reduction_result: Dimensionality reduction results
        cluster_labels: Cluster label for each point (any array-like), in the
            same order as the rows of ``reduction_result.coordinates``
        config: Plot configuration
        output_file: Optional path to save HTML file

    Returns:
        Plotly figure with clusters
    """
    from dataclasses import replace

    if config is None:
        config = EmbeddingPlotConfig()

    # Accept plain lists too: the centroid masks below need element-wise ``==``.
    cluster_labels = np.asarray(cluster_labels)
    unique_labels = np.unique(cluster_labels)

    palette = px.colors.qualitative.Plotly
    cluster_colors = {
        f"Cluster {label}": palette[i % len(palette)]
        for i, label in enumerate(unique_labels)
    }
    # Explicit user colors take precedence over the palette.
    cluster_colors.update(config.color_discrete_map or {})

    # Work on a copy so the caller's config is not changed. Color by cluster
    # and drop the shape encoding. With a shape field set, points are split
    # into one trace per shape value and colored per trace, which hides the
    # cluster colors.
    cluster_config = replace(
        config,
        color_field="cluster",
        shape_field=None,
        color_discrete_map=cluster_colors,
    )

    # Add cluster labels to metadata (only for objects that have a metadata entry)
    for entry, label in zip(embedding_data.metadata, cluster_labels):
        entry["cluster"] = f"Cluster {label}"

    fig = plot_embeddings(embedding_data, reduction_result, cluster_config)

    # Only show centroids for a reasonable number of clusters
    if len(unique_labels) < 50:
        coordinates = reduction_result.coordinates
        for label in unique_labels:
            mask = cluster_labels == label
            centroid_x = np.mean(coordinates[mask, 0])
            centroid_y = np.mean(coordinates[mask, 1])

            fig.add_trace(go.Scatter(
                x=[centroid_x],
                y=[centroid_y],
                mode="markers+text",
                marker=dict(
                    size=15,
                    symbol="star",
                    color="black",
                    line=dict(width=2, color="white")
                ),
                text=f"C{label}",
                textposition="top center",
                showlegend=False,
                hovertemplate=f"Cluster {label} centroid<extra></extra>"
            ))

    fig.update_layout(
        title=f"{cluster_config.title}<br><sub>Clusters: {len(unique_labels)}</sub>"
    )

    if output_file:
        fig.write_html(output_file)
        logger.info(f"Saved cluster plot to {output_file}")

    return fig
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import numpy as np
|
|
4
|
+
from collections import Counter
|
|
5
|
+
|
|
6
|
+
def create_faceted_horizontal_barchart(data, output_path,
                                       figsize_per_facet=(8, 4),
                                       max_items=10,
                                       max_label_length=50):
    """
    Create a horizontal bar chart for each facet and save them as one image.

    Args:
        data: Dictionary mapping each facet name to a ``{value: count}``
            mapping. Values that are not str/int/float are skipped.
        output_path: Path of the image file to write (e.g. a PNG).
        figsize_per_facet: (width, height) in inches of each facet's subplot;
            facets are stacked vertically.
        max_items: Maximum number of values drawn per facet; the values with
            the highest counts are kept.
        max_label_length: Value labels are truncated to this many characters.

    Raises:
        ValueError: If ``data`` contains no facets.
    """
    # Keep the top ``max_items`` values per facet, ranked by count. Ranking
    # comes before the cut because taking the first N in dict order could drop
    # the most frequent values when the input is not already sorted.
    facets = {}
    for facet_name, value_counts in data.items():
        kept = [(value, count) for value, count in value_counts.items()
                if isinstance(value, (str, int, float))]
        kept.sort(key=lambda item: item[1], reverse=True)
        counts = Counter()
        for value, count in kept[:max_items]:
            # Labels that collide after truncation are merged by summing.
            counts[str(value)[:max_label_length]] += count
        facets[facet_name] = counts

    n_facets = len(facets)
    if n_facets == 0:
        # plt.subplots() cannot build a grid with zero rows.
        raise ValueError("No facets to plot: data is empty")

    fig, axes = plt.subplots(
        n_facets, 1,
        figsize=(figsize_per_facet[0], figsize_per_facet[1] * n_facets),
        squeeze=False,  # always a 2-D array, even for a single facet
    )
    try:
        for ax, (facet_name, counts) in zip(axes[:, 0], facets.items()):
            # Sort by count (descending) for better visualization
            ranked = counts.most_common()
            labels = [label for label, _ in ranked]
            values = [count for _, count in ranked]

            y_pos = np.arange(len(labels))
            bars = ax.barh(y_pos, values, alpha=0.7)

            ax.set_yticks(y_pos)
            ax.set_yticklabels(labels)
            ax.set_xlabel('Count')
            ax.set_title(str(facet_name).replace("_", " ").title())
            ax.grid(axis='x', alpha=0.3)

            # Annotate each bar with its count, just past the end of the bar.
            offset = 0.01 * max(values) if values else 0
            for bar, value in zip(bars, values):
                ax.text(bar.get_width() + offset, bar.get_y() + bar.get_height() / 2,
                        str(value), ha='left', va='center', fontsize=9)

            # Invert y-axis so highest counts are at top
            ax.invert_yaxis()

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every figure open until it is closed; release it even on error.
        plt.close(fig)
|
|
73
|
+
|