linkml-store 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linkml_store/__init__.py +7 -0
- linkml_store/api/__init__.py +8 -0
- linkml_store/api/client.py +414 -0
- linkml_store/api/collection.py +1280 -0
- linkml_store/api/config.py +187 -0
- linkml_store/api/database.py +862 -0
- linkml_store/api/queries.py +69 -0
- linkml_store/api/stores/__init__.py +0 -0
- linkml_store/api/stores/chromadb/__init__.py +7 -0
- linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
- linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
- linkml_store/api/stores/dremio/__init__.py +10 -0
- linkml_store/api/stores/dremio/dremio_collection.py +555 -0
- linkml_store/api/stores/dremio/dremio_database.py +1052 -0
- linkml_store/api/stores/dremio/mappings.py +105 -0
- linkml_store/api/stores/dremio_rest/__init__.py +11 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
- linkml_store/api/stores/duckdb/__init__.py +16 -0
- linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
- linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
- linkml_store/api/stores/duckdb/mappings.py +8 -0
- linkml_store/api/stores/filesystem/__init__.py +15 -0
- linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
- linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
- linkml_store/api/stores/hdf5/__init__.py +7 -0
- linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
- linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
- linkml_store/api/stores/ibis/__init__.py +5 -0
- linkml_store/api/stores/ibis/ibis_collection.py +488 -0
- linkml_store/api/stores/ibis/ibis_database.py +328 -0
- linkml_store/api/stores/mongodb/__init__.py +25 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
- linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
- linkml_store/api/stores/neo4j/__init__.py +0 -0
- linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
- linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
- linkml_store/api/stores/solr/__init__.py +3 -0
- linkml_store/api/stores/solr/solr_collection.py +224 -0
- linkml_store/api/stores/solr/solr_database.py +83 -0
- linkml_store/api/stores/solr/solr_utils.py +0 -0
- linkml_store/api/types.py +4 -0
- linkml_store/cli.py +1147 -0
- linkml_store/constants.py +7 -0
- linkml_store/graphs/__init__.py +0 -0
- linkml_store/graphs/graph_map.py +24 -0
- linkml_store/index/__init__.py +53 -0
- linkml_store/index/implementations/__init__.py +0 -0
- linkml_store/index/implementations/llm_indexer.py +174 -0
- linkml_store/index/implementations/simple_indexer.py +43 -0
- linkml_store/index/indexer.py +211 -0
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/evaluation.py +195 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/llm_inference_engine.py +154 -0
- linkml_store/inference/implementations/rag_inference_engine.py +276 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
- linkml_store/inference/inference_config.py +66 -0
- linkml_store/inference/inference_engine.py +209 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/plotting/__init__.py +5 -0
- linkml_store/plotting/cli.py +826 -0
- linkml_store/plotting/dimensionality_reduction.py +453 -0
- linkml_store/plotting/embedding_plot.py +489 -0
- linkml_store/plotting/facet_chart.py +73 -0
- linkml_store/plotting/heatmap.py +383 -0
- linkml_store/utils/__init__.py +0 -0
- linkml_store/utils/change_utils.py +17 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/embedding_matcher.py +424 -0
- linkml_store/utils/embedding_utils.py +299 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/file_utils.py +37 -0
- linkml_store/utils/format_utils.py +550 -0
- linkml_store/utils/io.py +38 -0
- linkml_store/utils/llm_utils.py +122 -0
- linkml_store/utils/mongodb_utils.py +145 -0
- linkml_store/utils/neo4j_utils.py +42 -0
- linkml_store/utils/object_utils.py +190 -0
- linkml_store/utils/pandas_utils.py +93 -0
- linkml_store/utils/patch_utils.py +126 -0
- linkml_store/utils/query_utils.py +89 -0
- linkml_store/utils/schema_utils.py +23 -0
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/sql_utils.py +177 -0
- linkml_store/utils/stats_utils.py +53 -0
- linkml_store/utils/vector_utils.py +158 -0
- linkml_store/webapi/__init__.py +0 -0
- linkml_store/webapi/html/__init__.py +3 -0
- linkml_store/webapi/html/base.html.j2 +24 -0
- linkml_store/webapi/html/collection_details.html.j2 +15 -0
- linkml_store/webapi/html/database_details.html.j2 +16 -0
- linkml_store/webapi/html/databases.html.j2 +14 -0
- linkml_store/webapi/html/generic.html.j2 +43 -0
- linkml_store/webapi/main.py +855 -0
- linkml_store-0.3.0.dist-info/METADATA +226 -0
- linkml_store-0.3.0.dist-info/RECORD +101 -0
- linkml_store-0.3.0.dist-info/WHEEL +4 -0
- linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
- linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,826 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface for the plotting package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional, Union
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
|
|
11
|
+
from linkml_store.plotting.heatmap import heatmap_from_file, export_heatmap_data
|
|
12
|
+
from linkml_store.utils.format_utils import Format, load_objects
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@click.group()
def plot_cli():
    """Plotting utilities for LinkML data."""
    # Root click group: every plotting subcommand below registers itself via
    # @plot_cli.command(); the group itself does no work.
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _heatmap_frame_from_axes(ax, y_column: str) -> Optional[pd.DataFrame]:
    """Rebuild the plotted matrix from the first QuadMesh drawn on ``ax``.

    Reading the values back from the axes (instead of reloading the input file)
    is what makes export work when the input came from stdin, which can only be
    consumed once.

    :param ax: axes returned by ``heatmap_from_file``
    :param y_column: name for the column that holds the y-axis (row) labels
    :return: one row per y tick label and one column per x tick label, or None
        when no QuadMesh is present or the tick labels do not cover every cell
    """
    import numpy as np
    from matplotlib.collections import QuadMesh

    mesh = next((child for child in ax.get_children() if isinstance(child, QuadMesh)), None)
    if mesh is None or mesh.get_array() is None:
        return None

    # get_array() may be 1-D or 2-D depending on the matplotlib version, and may be
    # masked; normalise to a flat float array with NaN for masked cells.
    values = np.ma.asarray(mesh.get_array(), dtype=float).filled(np.nan).ravel()
    row_labels = [label.get_text() for label in ax.get_yticklabels()]
    col_labels = [label.get_text() for label in ax.get_xticklabels()]

    # If tick labels were thinned out (e.g. automatic labelling of a large matrix),
    # they no longer identify every cell, so refuse to guess.
    if not row_labels or not col_labels or values.size != len(row_labels) * len(col_labels):
        return None

    frame = pd.DataFrame(
        values.reshape(len(row_labels), len(col_labels)),
        index=row_labels,
        columns=col_labels,
    )
    # Drop rows/columns whose tick label is empty.
    frame = frame.iloc[[bool(label) for label in row_labels], [bool(label) for label in col_labels]]

    # Turn the row labels into a regular column named after the y axis.
    result = frame.reset_index()
    return result.rename(columns={"index": y_column})


@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--value-column", "-v", help="Column containing values (if not provided, counts will be used)")
@click.option("--minimum-value", "-m", type=float, help="Minimum value to include in the heatmap")
@click.option("--title", "-t", help="Title for the heatmap")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--cmap", "-c", default="YlGnBu", show_default=True, help="Colormap to use")
@click.option("--output", "-o", required=True, help="Output file path")
@click.option("--format", "-f", help="Input file format")
@click.option("--dpi", type=int, default=300, show_default=True, help="DPI for output image")
@click.option("--square/--no-square", default=False, show_default=True, help="Make cells square")
@click.option("--annotate/--no-annotate", default=True, show_default=True, help="Annotate cells with values")
@click.option("--font-size", type=int, default=10, show_default=True, help="Font size for annotations and labels")
@click.option("--robust/--no-robust", default=False, show_default=True, help="Use robust quantiles for colormap scaling")
@click.option(
    "--remove-duplicates/--no-remove-duplicates",
    default=False,
    show_default=True,
    help="Remove duplicate x,y combinations or keep all occurrences (default)",
)
@click.option(
    "--cluster",
    type=click.Choice(["none", "both", "x", "y"]),
    default="none",
    show_default=True,
    help="Cluster axes: none (default), both, x-axis only, or y-axis only",
)
@click.option(
    "--cluster-method",
    type=click.Choice(["complete", "average", "single", "ward"]),
    default="complete",
    show_default=True,
    help="Linkage method for hierarchical clustering",
)
@click.option(
    "--cluster-metric",
    type=click.Choice(["euclidean", "correlation", "cosine", "cityblock"]),
    default="euclidean",
    show_default=True,
    help="Distance metric for clustering",
)
@click.option("--export-data", "-e", help="Export the heatmap data to this file")
@click.option(
    "--export-format",
    "-E",
    type=click.Choice([f.value for f in Format]),
    default="csv",
    show_default=True,
    help="Format for exported data",
)
def heatmap(
    input_file: Optional[str],
    x_column: str,
    y_column: str,
    value_column: Optional[str],
    minimum_value: Optional[float],
    title: Optional[str],
    width: int,
    height: int,
    cmap: str,
    output: str,
    format: Optional[str],
    dpi: int,
    square: bool,
    annotate: bool,
    font_size: int,
    robust: bool,
    remove_duplicates: bool,
    cluster: str,
    cluster_method: str,
    cluster_metric: str,
    export_data: Optional[str],
    export_format: Union[str, Format],
):
    """
    Create a heatmap from a tabular data file.

    Examples:
        # From a file
        linkml-store plot heatmap data.csv -x species -y country -o heatmap.png

        # From stdin
        cat data.csv | linkml-store plot heatmap -x species -y country -o heatmap.png

    This will create a heatmap showing the frequency counts of species by country.
    If you want to use a specific value column instead of counts:

        linkml-store plot heatmap data.csv -x species -y country -v population -o heatmap.png

    With --export-data, the plotted matrix (one row per y value, one column per
    x value) is also written to a separate file.
    """
    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    # The CLI exposes "none" as a choice; heatmap_from_file receives False for "no clustering".
    use_cluster = False if cluster == "none" else cluster

    _, ax = heatmap_from_file(
        file_path=input_file,
        x_column=x_column,
        y_column=y_column,
        value_column=value_column,
        minimum_value=minimum_value,
        title=title,
        figsize=(width, height),
        cmap=cmap,
        output_file=output,
        format=format,
        dpi=dpi,
        square=square,
        annot=annotate,
        font_size=font_size,
        robust=robust,
        remove_duplicates=remove_duplicates,
        cluster=use_cluster,
        cluster_method=cluster_method,
        cluster_metric=cluster_metric,
    )

    if export_data:
        # Reuse the data already drawn rather than loading input_file again; this
        # avoids the "I/O operation on closed file" error when the input is stdin.
        export_df = _heatmap_frame_from_axes(ax, y_column) if hasattr(ax, "get_children") else None
        if export_df is None:
            click.echo("Warning: Could not export data from the plot")
        else:
            from linkml_store.utils.format_utils import write_output

            write_output(export_df.to_dict(orient="records"), format=export_format, target=export_data)
            click.echo(f"Heatmap data exported to {export_data}")

    click.echo(f"Heatmap created at {output}")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option(
    "--bins",
    "-b",
    type=click.IntRange(min=0),
    default=10,
    show_default=True,
    help="Number of bins for the histogram (0 = one bin per integer value)",
)
@click.option("--value-column", "-v", help="Not used by this command (accepted for compatibility)")
@click.option("--x-log-scale/--no-x-log-scale", default=False, show_default=True, help="Use log scale for the x-axis")
@click.option("--y-log-scale/--no-y-log-scale", default=False, show_default=True, help="Use log scale for the y-axis")
@click.option("--title", "-t", help="Title for the histogram")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--output", "-o", required=True, help="Output file path")
def histogram(
    input_file: Optional[str],
    x_column: str,
    bins: int,
    value_column: Optional[str],
    x_log_scale: bool,
    y_log_scale: bool,
    title: Optional[str],
    width: int,
    height: int,
    output: str,
):
    """
    Create a histogram from a tabular data file.
    """
    import matplotlib.pyplot as plt

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    df = pd.DataFrame(load_objects(input_file))
    if x_column not in df.columns:
        raise click.UsageError(f"Column {x_column!r} not found in input")
    if value_column:
        logger.warning("--value-column is not used by the histogram command; ignoring %r", value_column)

    # List-valued cells are binned by their length. Every cell is checked (not just
    # the first row), so columns mixing lists and scalars are handled as well.
    column = df[x_column].map(lambda cell: len(cell) if isinstance(cell, (list, tuple)) else cell)
    values = column.dropna()
    if values.empty:
        raise click.UsageError(f"Column {x_column!r} has no values to plot")

    # Diagnostics go to the log (visible with debug logging) instead of stdout.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("DataFrame shape: %s, columns: %s", df.shape, df.columns.tolist())
        logger.debug("Data types:\n%s", df.dtypes)
        logger.debug(
            "%s: %d unique values, %d nulls, sample: %s",
            x_column,
            column.nunique(),
            column.isnull().sum(),
            column.unique()[:10],
        )

    plt.figure(figsize=(width, height))
    if bins == 0:
        # One unit-wide bin per integer between min and max; the extra edge at
        # max + 1 closes the last bin so the maximum value is counted.
        bin_spec = range(int(values.min()), int(values.max()) + 2)
    else:
        bin_spec = bins
    plt.hist(values, bins=bin_spec, alpha=0.7, edgecolor="black", linewidth=0.5)
    plt.xlabel(x_column.replace("_", " ").title())
    plt.ylabel("Frequency")
    plt.title(title or f"Distribution of {x_column}")

    if x_log_scale:
        plt.xscale("log")
    if y_log_scale:
        plt.yscale("log")

    # Mark the mean and median of the plotted values
    mean_val = values.mean()
    median_val = values.median()
    plt.axvline(mean_val, color="red", linestyle="--", alpha=0.7, label=f"Mean: {mean_val:.1f}")
    plt.axvline(median_val, color="orange", linestyle="--", alpha=0.7, label=f"Median: {median_val:.1f}")
    plt.legend()

    # Rotate x-axis labels when there are many distinct values
    if values.nunique() > 10:
        plt.xticks(rotation=45)
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--x-log-scale/--no-x-log-scale", default=False, show_default=True, help="Use log scale for the x-axis")
@click.option("--y-log-scale/--no-y-log-scale", default=False, show_default=True, help="Use log scale for the y-axis")
@click.option("--value-column", "-v", help="Not used by this command (accepted for compatibility)")
@click.option("--title", "-t", help="Title for the boxplot")
@click.option("--output", "-o", required=True, help="Output file path")
def boxplot_old(
    input_file: Optional[str],
    x_column: str,
    y_column: str,
    x_log_scale: bool,
    y_log_scale: bool,
    value_column: Optional[str],
    title: Optional[str],
    output: str,
):
    """
    Create a boxplot from a tabular data file.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    df = pd.DataFrame(load_objects(input_file))
    missing = [column for column in (x_column, y_column) if column not in df.columns]
    if missing:
        raise click.UsageError(f"Column(s) not found in input: {', '.join(missing)}")
    if value_column:
        logger.warning("--value-column is not used by this command; ignoring %r", value_column)

    def _first_item(value):
        """Reduce a list/tuple cell to its first element (None when it is empty)."""
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    # A multivalued cell cannot sit at a single axis position, so only the first value
    # of each list is kept (the remaining values are discarded, not exploded into rows).
    # Every row is checked, not only the first one.
    for column in (y_column, x_column):
        if df[column].map(lambda cell: isinstance(cell, (list, tuple))).any():
            df[column] = df[column].map(_first_item)
            logger.debug("Reduced list values in %s to their first element:\n%s", column, df[column].head())

    plt.figure(figsize=(10, 6))
    sns.boxplot(
        data=df,
        x=x_column,
        y=y_column,
        # Outliers: semi-transparent red circles
        flierprops={"marker": "o", "markerfacecolor": "red", "markersize": 5, "alpha": 0.7},
    )

    if x_log_scale:
        plt.xscale("log")
    if y_log_scale:
        plt.yscale("log")

    plt.xticks(rotation=45)
    plt.title(title)
    plt.ylabel(y_column.replace("_", " ").title())
    plt.tight_layout()
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=False, help="Column to use for y-axis. If not specified, will count")
@click.option("--title", "-t", help="Title for the barchart")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--output", "-o", required=True, help="Output file path")
def barchart(
    input_file: Optional[str],
    x_column: str,
    y_column: Optional[str],
    title: Optional[str],
    width: int,
    height: int,
    output: str,
):
    """
    Create a barchart from a tabular data file.
    """
    import matplotlib.pyplot as plt

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    df = pd.DataFrame(load_objects(input_file))

    if y_column:
        # One bar group per x value, with one bar per distinct y value inside it
        counts = df.groupby(x_column)[y_column].value_counts().unstack()
    else:
        # One bar per distinct x value
        counts = df[x_column].value_counts()
    counts.plot(kind="bar", figsize=(width, height))

    plt.title(title)
    # Categories run along the x axis; bar heights are occurrence counts.
    plt.xlabel(x_column.replace("_", " ").title())
    plt.ylabel("Count")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--title", "-t", help="Title for the barchart")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--output", "-o", required=True, help="Output file path")
def diverging_barchart(
    input_file: Optional[str],
    x_column: str,
    y_column: str,
    title: Optional[str],
    width: int,
    height: int,
    output: str,
):
    """
    Create a diverging barchart from a tabular data file.

    Each row becomes one horizontal bar. The x_column holds the (signed) score
    and the y_column holds the bar label. Bars are red when the score is
    negative and green otherwise, and each bar is annotated with its score.
    Bars appear in input order; sort the input first to get a ranked chart.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    df = pd.DataFrame(load_objects(input_file))
    num_rows = len(df)
    if num_rows == 0:
        raise click.ClickException("No rows to plot")

    # Grow the figure so every row gets at least 0.4 inches
    plt.figure(figsize=(width, max(height, num_rows * 0.4)))

    scores = df[x_column]
    # One colour per row, in row order, matching the explicit bar order below
    colors = ["#d62728" if score < 0 else "#2ca02c" for score in scores]
    ax = sns.barplot(data=df, y=y_column, x=x_column, palette=colors, order=df[y_column])

    # Zero reference line drawn above the bars
    plt.axvline(x=0, color="black", linestyle="-", linewidth=2, alpha=0.8, zorder=10)

    plt.xlabel("Score", fontsize=12, fontweight="bold")
    plt.ylabel("Tasks", fontsize=12, fontweight="bold")
    plt.title(title if title else "Task Scores Distribution", fontsize=14, fontweight="bold", pad=20)

    # x limits follow the data range plus a 10% (at least 0.1) margin
    x_min, x_max = scores.min(), scores.max()
    margin = max(0.1, (x_max - x_min) * 0.1)
    plt.xlim(x_min - margin, x_max + margin)

    # Loop-invariant: labels sit 2% of the data range beyond the bar end, and the
    # font shrinks (down to 8pt) as the number of rows grows.
    x_offset = 0.02 * (x_max - x_min)
    label_size = max(8, min(10, 120 / num_rows))
    for bar, score in zip(ax.patches, scores):
        y_pos = bar.get_y() + bar.get_height() / 2
        x_pos = score + (x_offset if score >= 0 else -x_offset)
        plt.text(
            x_pos,
            y_pos,
            f"{score:.2f}",
            va="center",
            ha="left" if score >= 0 else "right",
            fontsize=label_size,
            zorder=11,  # keep labels above the zero line
        )

    ax.tick_params(axis="y", labelsize=10)

    plt.tight_layout()
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
# lineplot
|
|
445
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--group-by", "-g", required=True, help="Column to group by")
@click.option("--period", "-p", help="Period to group by (e.g. M, Y, Q, W, D)")
@click.option("--exclude", "-E", help="Exclude group-by values (comma-separated)")
@click.option("--title", "-t", help="Title for the lineplot")
@click.option("--minimum-entries", "-m", type=int, default=1, help="Exclude groups with fewer than this number of entries")
@click.option("--output", "-o", required=True, help="Output file path")
def lineplot(
    input_file: Optional[str],
    x_column: str,
    group_by: str,
    period: Optional[str],
    exclude: Optional[str],
    minimum_entries: int,
    title: Optional[str],
    output: str,
):
    """
    Create a lineplot from a tabular data file.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    df = pd.DataFrame(load_objects(input_file))

    if exclude:
        # Accept "a, b" as well as "a,b"; ignore empty entries from stray commas
        exclude_values = [value.strip() for value in exclude.split(",") if value.strip()]
        df = df[~df[group_by].isin(exclude_values)]

    if minimum_entries and minimum_entries > 1:
        df = df.groupby(group_by).filter(lambda rows: len(rows) >= minimum_entries)

    # The filters above may yield a slice of the original frame; copy so the column
    # assignments below modify an independent frame (no SettingWithCopyWarning).
    df = df.copy()

    if period:
        # x_column is treated as dates; unparseable values become NaT and are dropped
        df[x_column] = pd.to_datetime(df[x_column], errors="coerce")
        if df[x_column].dt.tz is not None:
            df[x_column] = df[x_column].dt.tz_localize(None)
        df = df.dropna(subset=[x_column])
        df["period"] = df[x_column].dt.to_period(period)
        grouped_data = df.groupby(["period", group_by]).size().reset_index(name="count")
        # Convert periods back to timestamps for plotting
        grouped_data["period"] = grouped_data["period"].dt.to_timestamp()
        x_label_column = x_column
    else:
        # Without a period each group collapses to a single (group, count) point,
        # so the x axis shows the group values themselves.
        grouped_data = df.groupby([group_by]).size().reset_index(name="count")
        grouped_data["period"] = grouped_data[group_by]
        x_label_column = group_by

    plt.figure(figsize=(12, 8))

    unique_groups = grouped_data[group_by].unique()
    # Colour, line style and marker all vary per group so lines stay distinguishable
    # for colour-blind readers and in greyscale.
    colors = sns.color_palette("colorblind", n_colors=len(unique_groups))
    line_styles = ["-", "--", "-.", ":", "-", "--", "-.", ":"]
    markers = ["o", "s", "^", "D", "v", "<", ">", "p"]

    for i, group in enumerate(unique_groups):
        group_data = grouped_data[grouped_data[group_by] == group].sort_values("period")
        color = colors[i % len(colors)]
        plt.plot(
            group_data["period"],
            group_data["count"],
            label=group,
            color=color,
            linestyle=line_styles[i % len(line_styles)],
            marker=markers[i % len(markers)],
            markersize=6,
            linewidth=2.5,
            markevery=max(1, len(group_data) // 10),  # roughly ten markers per line
        )
        # Label each line directly at its last point
        if len(group_data) > 0:
            last_point = group_data.iloc[-1]
            plt.annotate(
                group,
                xy=(last_point["period"], last_point["count"]),
                xytext=(10, 0),
                textcoords="offset points",
                fontsize=10,
                fontweight="bold",
                va="center",
                color=color,
            )

    plt.title(title, fontsize=14, fontweight="bold")
    plt.ylabel("Count", fontsize=12)
    plt.xlabel(x_label_column.replace("_", " ").title(), fontsize=12)

    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=10)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.subplots_adjust(right=0.8)  # Make room for end labels

    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
544
|
+
|
|
545
|
+
def calculate_correlation(df: pd.DataFrame, x_column: str, y_column: str) -> float:
    """Return the correlation coefficient between two columns of ``df``.

    The computation is delegated to pandas ``Series.corr`` with its default method.
    """
    left = df[x_column]
    right = df[y_column]
    return left.corr(right)
|
|
548
|
+
|
|
549
|
+
# scatterplot
|
|
550
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--include-correlation", "-c", is_flag=True, help="Include correlation coefficient in the plot, and add a line of best fit")
@click.option("--title", "-t", help="Title for the plot")
@click.option("--output", "-o", required=True, help="Output file path")
def scatterplot(input_file: Optional[str], x_column: str, y_column: str, include_correlation: bool, title: Optional[str], output: str):
    """
    Create a scatterplot from a tabular data file.

    With --include-correlation, a line of best fit is overlaid and its legend
    entry shows the correlation coefficient between the two columns.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    objs = load_objects(input_file)
    df = pd.DataFrame(objs)

    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=df, x=x_column, y=y_column)

    if include_correlation:
        # Only compute/draw the fit when requested. scatter=False avoids drawing the
        # points a second time and makes the label attach to the regression line.
        correlation = calculate_correlation(df, x_column, y_column)
        sns.regplot(data=df, x=x_column, y=y_column, scatter=False, label=f"Correlation: {correlation:.2f}")
        # Without an explicit legend the label would never be rendered.
        plt.legend()

    plt.title(title)
    plt.ylabel(y_column.replace('_', ' ').title())
    plt.xlabel(x_column.replace('_', ' ').title())
    plt.tight_layout()
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
584
|
+
|
|
585
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--group-by", "-g", required=False, help="Column to group by")
@click.option("--title", "-t", help="Title for the plot")
@click.option("--output", "-o", required=True, help="Output file path")
def barplot(input_file: Optional[str], x_column: str, y_column: str, group_by: Optional[str], title: Optional[str], output: str):
    """
    Create a barplot from a tabular data file.

    If --group-by is given, bars are split and colored by that column.
    """
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    objs = load_objects(input_file)
    df = pd.DataFrame(objs)

    plt.figure(figsize=(10, 6))
    # hue=None (no --group-by) yields the same ungrouped plot as before; previously
    # the --group-by option was accepted but silently ignored.
    sns.barplot(data=df, x=x_column, y=y_column, hue=group_by)
    plt.title(title)
    # save the plot
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--x-column", "-x", required=True, help="Column to use for x-axis")
@click.option("--y-column", "-y", required=True, help="Column to use for y-axis")
@click.option("--y-explode-lists", "-Y", is_flag=True, help="Explode list values in y-column into separate rows")
@click.option("--group-by", "-g", required=False, help="Column to group by")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--title", "-t", help="Title for the plot")
@click.option("--output", "-o", required=True, help="Output file path")
def boxplot(input_file: Optional[str], x_column: str, y_column: str, y_explode_lists: bool, group_by: Optional[str], width: int, height: int, title: Optional[str], output: str):
    """
    Create a boxplot from a tabular data file.

    List/tuple cells in the y-column are either exploded into one row per element
    (--y-explode-lists) or joined into a comma-separated string; list/tuple cells
    in the x-column are always joined.
    """
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    def _join(value):
        # Join a list/tuple cell into one comma-separated label; leave scalars as-is.
        # str() each element so lists of numbers do not make ",".join raise.
        if isinstance(value, (list, tuple)):
            return ",".join(str(v) for v in value)
        return value

    def _has_lists(series) -> bool:
        # Inspect every cell (not just the first) so a leading scalar/None does not
        # hide list values further down, and an empty column does not raise.
        return bool(series.apply(lambda v: isinstance(v, (list, tuple))).any())

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    objs = load_objects(input_file)
    df = pd.DataFrame(objs)

    if _has_lists(df[y_column]):
        if y_explode_lists:
            # Explode the list into separate rows
            df = df.explode(y_column).reset_index(drop=True)
        else:
            df[y_column] = df[y_column].apply(_join)
    if _has_lists(df[x_column]):
        df[x_column] = df[x_column].apply(_join)

    # sort the dataframe by the x_column
    df = df.sort_values(by=x_column, ascending=False)

    # catplot is figure-level and creates its own figure, so no plt.figure() call
    # here (it would leave an empty figure that plt.close() never closes); the size
    # is controlled through height/aspect instead.
    sns.catplot(data=df, x=x_column, y=y_column, hue=group_by, kind="box",
                height=height, aspect=width / height)
    plt.title(title)
    # save the plot (the current figure is the one catplot created)
    plt.savefig(output, bbox_inches="tight", dpi=150)
    plt.close()
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
@plot_cli.command()
@click.argument("input_file", required=False)
@click.option("--title", "-t", help="Title for the plot")
@click.option("--width", "-w", type=int, default=10, show_default=True, help="Width of the figure in inches")
@click.option("--height", "-h", type=int, default=8, show_default=True, help="Height of the figure in inches")
@click.option("--output", "-o", required=True, help="Output file path")
def facet_chart(
    input_file: Optional[str],
    title: Optional[str],
    width: int,
    height: int,
    output: str,
):
    """
    Create a facet chart from a tabular data file.

    The input must contain exactly one object, which is rendered by
    create_faceted_horizontal_barchart and written to the output path.
    """
    # NOTE(review): --title, --width and --height are accepted but not forwarded to
    # create_faceted_horizontal_barchart; confirm whether that function supports them.

    # Handle file path - if None, use stdin
    if input_file is None:
        input_file = "-"  # format_utils treats "-" as stdin

    objs = load_objects(input_file)
    if len(objs) != 1:
        # Report the actual count so the user can see what was loaded.
        raise ValueError(f"Facet chart requires exactly one object, got {len(objs)}")

    from linkml_store.plotting.facet_chart import create_faceted_horizontal_barchart
    create_faceted_horizontal_barchart(objs[0], output)
    click.echo(f"Facet chart saved to {output}")
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
@plot_cli.command()
@click.pass_context
@click.option("--collections", "-c", help="Comma-separated list of collection names", required=True)
@click.option("--method", "-m", type=click.Choice(["umap", "tsne", "pca"]), default="tsne", help="Reduction method")
@click.option("--index-name", "-i", help="Name of index to use (defaults to first available)")
@click.option("--color-field", help="Field to use for coloring points")
@click.option("--shape-field", default="collection", help="Field to use for point shapes")
@click.option("--size-field", help="Field to use for point sizes")
@click.option("--hover-fields", help="Comma-separated list of fields to show on hover")
@click.option("--limit-per-collection", "-l", type=int, help="Max embeddings per collection")
@click.option("--n-neighbors", type=int, default=15, help="UMAP n_neighbors parameter")
@click.option("--min-dist", type=float, default=0.1, help="UMAP min_dist parameter")
@click.option("--perplexity", type=float, default=30.0, help="t-SNE perplexity parameter")
@click.option("--random-state", type=int, default=42, help="Random seed for reproducibility")
@click.option("--width", type=int, default=800, help="Plot width in pixels")
@click.option("--height", type=int, default=600, help="Plot height in pixels")
@click.option("--dark-mode/--no-dark-mode", default=False, help="Use dark mode theme")
@click.option("--output", "-o", type=click.Path(), help="Output HTML file path")
def multi_collection_embeddings(ctx, collections, method, index_name, color_field, shape_field,
                                size_field, hover_fields, limit_per_collection, n_neighbors,
                                min_dist, perplexity, random_state, width, height, dark_mode, output):
    """
    Create an interactive plot of embeddings from indexed collections.

    Example:
        linkml-store -d mydb.ddb plot multi-collection-embeddings --collections coll1,coll2 --method umap -o plot.html
    """
    from linkml_store.utils.embedding_utils import extract_embeddings_from_multiple_collections
    from linkml_store.plotting.dimensionality_reduction import reduce_dimensions, validate_embeddings
    from linkml_store.plotting.embedding_plot import plot_embeddings as create_plot, EmbeddingPlotConfig

    # Parse comma-separated lists, dropping empty entries (e.g. from a trailing comma)
    collection_names = [c.strip() for c in collections.split(",") if c.strip()]
    hover_field_list = [f.strip() for f in hover_fields.split(",") if f.strip()] if hover_fields else []

    # Extract embeddings
    db = ctx.obj["settings"].database
    click.echo(f"Extracting embeddings from collections: {collection_names}")

    embedding_data = extract_embeddings_from_multiple_collections(
        database=db,
        collection_names=collection_names,
        index_name=index_name,
        limit_per_collection=limit_per_collection,
        include_metadata=True,
        normalize=True,
    )

    click.echo(f"Extracted {embedding_data.n_samples} embeddings with {embedding_data.n_dimensions} dimensions")

    # A 2-D reduction needs at least two points; with fewer, n_neighbors below would
    # become 0 or negative and the reducers fail with opaque errors.
    if embedding_data.n_samples < 2:
        click.echo("At least 2 embeddings are required to create a plot.", err=True)
        ctx.exit(1)

    # Validate embeddings before reduction
    validation_results = validate_embeddings(embedding_data.vectors)

    if validation_results["warnings"]:
        click.echo("Embedding validation warnings:", err=True)
        for warning in validation_results["warnings"]:
            click.echo(f" - {warning}", err=True)

    # Log detailed stats for debugging (lazy formatting: only rendered if INFO is enabled)
    logger.info("Embedding validation results: %s", validation_results)

    # Perform dimensionality reduction
    click.echo(f"Performing {method.upper()} dimensionality reduction...")

    # Set method-specific parameters, clamped to what the sample count allows
    reduction_params = {
        "n_components": 2,
        "random_state": random_state,
    }

    if method == "umap":
        reduction_params.update({
            "n_neighbors": min(n_neighbors, embedding_data.n_samples - 1),
            "min_dist": min_dist,
        })
    elif method == "tsne":
        reduction_params["perplexity"] = min(perplexity, embedding_data.n_samples / 4)

    try:
        reduction_result = reduce_dimensions(
            embedding_data.vectors,
            method=method,
            **reduction_params,
        )
    except ImportError as e:
        click.echo(f"Missing dependency: {e}", err=True)
        click.echo("Install with: pip install umap-learn scikit-learn", err=True)
        # Exit non-zero so scripts can detect the failure.
        ctx.exit(1)
    except Exception as e:
        logger.debug("Dimensionality reduction failed", exc_info=True)
        click.echo(f"Error during dimensionality reduction: {e}", err=True)
        ctx.exit(1)

    logger.info("Reduction result: %s", reduction_result)
    # Create plot configuration
    plot_config = EmbeddingPlotConfig(
        color_field=color_field,
        shape_field=shape_field,
        size_field=size_field,
        hover_fields=hover_field_list,
        title=f"Embedding Visualization ({', '.join(collection_names)})",
        width=width,
        height=height,
        dark_mode=dark_mode,
    )

    # Create plot
    click.echo("Creating interactive plot...")
    fig = create_plot(
        embedding_data=embedding_data,
        reduction_result=reduction_result,
        config=plot_config,
        output_file=output,
    )

    if output:
        click.echo(f"Plot saved to {output}")
    else:
        # If no output file, try to show in browser
        fig.show()
        click.echo("Plot opened in browser")
|
|
824
|
+
|
|
825
|
+
if __name__ == "__main__":
    # Allow running this module directly as a script: dispatch to the plot command group.
    plot_cli()
|