linkml-store 0.2.5__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of linkml-store might be problematic.
- linkml_store/api/client.py +9 -6
- linkml_store/api/collection.py +118 -5
- linkml_store/api/database.py +45 -14
- linkml_store/api/stores/duckdb/duckdb_collection.py +176 -8
- linkml_store/api/stores/duckdb/duckdb_database.py +52 -19
- linkml_store/api/stores/filesystem/__init__.py +1 -1
- linkml_store/api/stores/mongodb/mongodb_collection.py +186 -0
- linkml_store/api/stores/mongodb/mongodb_database.py +8 -3
- linkml_store/api/stores/solr/solr_collection.py +7 -1
- linkml_store/cli.py +202 -21
- linkml_store/index/implementations/llm_indexer.py +14 -6
- linkml_store/index/indexer.py +7 -4
- linkml_store/inference/implementations/llm_inference_engine.py +13 -9
- linkml_store/inference/implementations/rag_inference_engine.py +13 -10
- linkml_store/inference/implementations/sklearn_inference_engine.py +7 -1
- linkml_store/inference/inference_config.py +1 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/format_utils.py +183 -3
- linkml_store/utils/llm_utils.py +3 -1
- linkml_store/utils/pandas_utils.py +1 -1
- linkml_store/utils/sql_utils.py +7 -1
- linkml_store/utils/vector_utils.py +4 -11
- {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/METADATA +4 -3
- {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/RECORD +28 -26
- {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/WHEEL +1 -1
- {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/LICENSE +0 -0
- {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/entry_points.txt +0 -0
linkml_store/cli.py
CHANGED
@@ -3,6 +3,7 @@ import sys
 import warnings
 from collections import defaultdict
 from pathlib import Path
+from tokenize import group
 from typing import Optional, Tuple, Any
 
 import click
@@ -37,6 +38,11 @@ index_type_option = click.option(
     show_default=True,
     help="Type of index to create. Values: simple, llm",
 )
+json_select_query_option = click.option(
+    "--json-select-query",
+    "-J",
+    help="JSON SELECT query",
+)
 
 logger = logging.getLogger(__name__)
 
@@ -136,7 +142,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
     if input:
-        database = "duckdb"
+        database = "duckdb"  # default: store in duckdb
         if input.startswith("http"):
             parts = input.split("/")
             collection = parts[-1]
@@ -144,8 +150,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         else:
             stem = underscore(Path(input).stem)
             collection = stem
-        logger.info(f"Using input file: {input}, "
-                    f"default storage is {database} and collection is {collection}")
+        logger.info(f"Using input file: {input}, " f"default storage is {database} and collection is {collection}")
         config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
     if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
         config = DEFAULT_LOCAL_CONF_PATH
@@ -186,12 +191,24 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
 
 
 @cli.command()
-@click.
+@click.pass_context
+def drop(ctx):
+    """
+    Drop database and all its collections.
+    """
+    database = ctx.obj["settings"].database
+    database.drop()
+
+
+@cli.command()
+@click.argument("files", type=click.Path(), nargs=-1)
 @click.option("--replace/--no-replace", default=False, show_default=True, help="Replace existing objects")
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@click.option("--source-field", help="If provided, inject file path source as this field")
+@json_select_query_option
 @click.pass_context
-def insert(ctx, files, replace, object, format):
+def insert(ctx, files, replace, object, format, source_field, json_select_query):
     """Insert objects from files (JSON, YAML, TSV) into the specified collection.
 
     Using a configuration:
@@ -207,11 +224,17 @@ def insert(ctx, files, replace, object, format):
         raise ValueError("Collection must be specified.")
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
+        if source_field:
+            for obj in objects:
+                obj[source_field] = str(file_path)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into collection '{collection.alias}'.")
         if replace:
             collection.replace(objects)
@@ -222,6 +245,8 @@ def insert(ctx, files, replace, object, format):
         for object_str in object:
             logger.info(f"Parsing: {object_str}")
             objects = yaml.safe_load(object_str)
+            if not isinstance(objects, list):
+                objects = [objects]
             if replace:
                 collection.replace(objects)
             else:
@@ -234,21 +259,41 @@ def insert(ctx, files, replace, object, format):
 @click.argument("files", type=click.Path(exists=True), nargs=-1)
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@json_select_query_option
 @click.pass_context
-def store(ctx, files, object, format):
+def store(ctx, files, object, format, json_select_query):
     """Store objects from files (JSON, YAML, TSV) into the database.
 
-    Note: this is similar to insert, but a collection does not need to be specified
+    Note: this is similar to insert, but a collection does not need to be specified.
+
+    For example, assume that `my-collection` is a dict with multiple keys,
+    and we want one collection per key:
+
+        linkml-store -d my.ddb store my-collection.yaml
+
+    Loading JSON (e.g OBO-JSON), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J graphs cl.obo.json
+
+    Loading XML (e.g OWL-XML), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J Ontology cl.owx
+
+    Because the XML uses a top level Ontology, with multiple
+
     """
     settings = ctx.obj["settings"]
     db = settings.database
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into database '{db}'.")
         for obj in objects:
             db.store(obj)
@@ -422,15 +467,32 @@ def list_collections(ctx, **kwargs):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
-@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--facet-min-count", "-M", type=click.INT, help="Minimum count for a facet to be included")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
-@click.option("--columns", "-S", help="Columns to facet on")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
 @click.option("--wide/--no-wide", "-U/--no-U", default=False, show_default=True, help="Wide table")
 @click.pass_context
-def fq(ctx, where, limit, columns, output_type, wide, output):
+def fq(ctx, where, limit, columns, output_type, wide, output, **kwargs):
     """
-    Query
+    Query facet counts from the specified collection.
+
+    Assuming your .linkml.yaml includes an entry mapping `phenopackets` to a
+    mongodb
+
+    Facet counts (all columns)
+
+        linkml-store -d phenopackets fq
+
+    Nested columns:
+
+        linkml-store -d phenopackets fq subject.timeAtLastEncounter.age
+
+    Compound keys:
+
+        linkml-store -d phenopackets fq subject.sex+subject.timeAtLastEncounter.age
+
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
@@ -439,7 +501,7 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         columns = [col.strip() for col in columns]
         columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
     logger.info(f"Faceting on columns: {columns}")
-    results = collection.query_facets(where_clause, facet_columns=columns,
+    results = collection.query_facets(where_clause, facet_columns=columns, facet_limit=limit, **kwargs)
     logger.info(f"Facet results: {results}")
 
     def _untuple(key):
@@ -471,6 +533,56 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         click.echo(output_data)
 
 
+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
+@click.pass_context
+def groupby(ctx, where, limit, columns, output_type, output, **kwargs):
+    """
+    Group by columns in the specified collection.
+
+    Assume a simple triple model:
+
+        linkml-store -d cl.ddb -c triple insert cl.owl
+
+    This makes a flat subject/predicate/object table
+
+    This can be grouped, e.g by subject:
+
+        linkml-store -d cl.ddb -c triple groupby -s subject
+
+    Or subject and predicate:
+
+        linkml-store -d cl.ddb -c triple groupby -s '[subject,predicate]'
+
+    """
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    columns = columns.split(",") if columns else None
+    if columns:
+        columns = [col.strip() for col in columns]
+        columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
+    logger.info(f"Group by: {columns}")
+    result = collection.group_by(
+        group_by_fields=columns,
+        where_clause=where_clause,
+        agg_map={},
+        limit=limit,
+        **kwargs,
+    )
+    logger.info(f"Group by results: {result}")
+    output_data = render_output(result.rows, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Query results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 def _get_index(index_type=None, **kwargs) -> Indexer:
     if index_type is None or index_type == "simple":
         return SimpleIndexer(name="test", **kwargs)
@@ -519,10 +631,12 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
         value_key = tuple([row.get(att) for att in value_atts])
         pivoted[index_key][column_key] = value_key
     pivoted_objs = []
+
     def detuple(t: Tuple) -> Any:
         if len(t) == 1:
             return t[0]
         return str(t)
+
     for index_key, data in pivoted.items():
         obj = {att: key for att, key in zip(index_atts, index_key)}
         for column_key, value_key in data.items():
@@ -531,6 +645,57 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
     write_output(pivoted_objs, output_type, target=output)
 
 
+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--sample-field", "-I", help="Field to use as the sample identifier")
+@click.option("--classification-field", "-L", help="Field to use as for classification")
+@click.option(
+    "--p-value-threshold",
+    "-P",
+    type=click.FLOAT,
+    default=0.05,
+    show_default=True,
+    help="P-value threshold for enrichment",
+)
+@click.option(
+    "--multiple-testing-correction",
+    "-M",
+    type=click.STRING,
+    default="bh",
+    show_default=True,
+    help="Multiple test correction method",
+)
+@click.argument("samples", type=click.STRING, nargs=-1)
+@click.pass_context
+def enrichment(ctx, where, limit, output_type, output, sample_field, classification_field, samples, **kwargs):
+    from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer
+
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    column_atts = [sample_field, classification_field]
+    results = collection.find(where_clause, select_cols=column_atts, limit=-1)
+    df = results.rows_dataframe
+    ea = EnrichmentAnalyzer(df, sample_key=sample_field, classification_key=classification_field)
+    if not samples:
+        samples = df[sample_field].unique()
+    enrichment_results = []
+    for sample in samples:
+        enriched = ea.find_enriched_categories(sample, **kwargs)
+        for e in enriched:
+            obj = {"sample": sample, **e.model_dump()}
+            enrichment_results.append(obj)
+    output_data = render_output(enrichment_results, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Search results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
@@ -538,7 +703,7 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
 @click.option(
     "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
 )
-@click.option("--training-collection", type=click.STRING,help="Collection to use for training")
+@click.option("--training-collection", type=click.STRING, help="Collection to use for training")
 @click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
 @click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
 @click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
@@ -753,12 +918,28 @@ def indexes(ctx):
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option(
+    "--collection-only/--no-collection-only",
+    default=False,
+    show_default=True,
+    help="Only validate specified collection",
+)
+@click.option(
+    "--ensure-referential-integrity/--no-ensure-referential-integrity",
+    default=True,
+    show_default=True,
+    help="Ensure referential integrity",
+)
 @click.pass_context
-def validate(ctx, output_type, output):
+def validate(ctx, output_type, output, collection_only, **kwargs):
     """Validate objects in the specified collection."""
-
-
-
+    if collection_only:
+        collection = ctx.obj["settings"].collection
+        logger.info(f"Validating collection {collection.alias}")
+        validation_results = [json_dumper.to_dict(x) for x in collection.iter_validate_collection(**kwargs)]
+    else:
+        db = ctx.obj["settings"].database
+        validation_results = [json_dumper.to_dict(x) for x in db.validate_database(**kwargs)]
     output_data = render_output(validation_results, output_type)
    if output:
         with open(output, "w") as f:
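
The new `-J/--json-select-query` option on `insert` and `store` is threaded through to `load_objects` as `select_query`, so that only a named sub-element of a parsed document (for example the `graphs` list of an OBO-JSON file) is treated as the objects to load. The `format_utils.load_objects` side is not shown in this hunk; a minimal sketch of the idea, with a hypothetical `select_objects` helper and a plain key lookup, would be:

```python
import json
from typing import Any, Dict, List


def select_objects(doc: Dict[str, Any], select_query: str) -> List[Dict[str, Any]]:
    # Hypothetical helper: pick the sub-element named by select_query out of the parsed document.
    selected = doc[select_query]  # e.g. "graphs" in an OBO-JSON file, "Ontology" in converted OWL-XML
    return selected if isinstance(selected, list) else [selected]


# Usage sketch mirroring: linkml-store -d cl.ddb store -J graphs cl.obo.json
with open("cl.obo.json") as f:
    objects = select_objects(json.load(f), "graphs")
```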

linkml_store/index/implementations/llm_indexer.py
CHANGED

@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
+import openai
 
 from linkml_store.api.config import CollectionConfig
 from linkml_store.index.indexer import INDEX_ITEM, Indexer
@@ -11,6 +12,7 @@ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
 if TYPE_CHECKING:
     import llm
 
+CHUNK_SIZE = 1000
 
 logger = logging.getLogger(__name__)
 
@@ -25,7 +27,7 @@ class LLMIndexer(Indexer):
     >>> vector = indexer.text_to_vector("hello")
     """
 
-    embedding_model_name: str = "ada-002"
+    embedding_model_name: str = "text-embedding-ada-002"
     _embedding_model: "llm.EmbeddingModel" = None
     cached_embeddings_database: str = None
     cached_embeddings_collection: str = None
@@ -52,7 +54,9 @@ class LLMIndexer(Indexer):
         """
         return self.texts_to_vectors([text], cache=cache, **kwargs)[0]
 
-    def texts_to_vectors(
+    def texts_to_vectors(
+        self, texts: List[str], cache: bool = None, token_limit_penalty=0, **kwargs
+    ) -> List[INDEX_ITEM]:
         """
         Use LLM to embed.
 
@@ -60,18 +64,22 @@ class LLMIndexer(Indexer):
         >>> vectors = indexer.texts_to_vectors(["hello", "goodbye"])
 
         :param texts:
+        :param cache:
+        :param token_limit_penalty:
         :return:
         """
         from tiktoken import encoding_for_model
+
         logging.info(f"Converting {len(texts)} texts to vectors")
         model = self.embedding_model
         # TODO: make this more accurate
-        token_limit = get_token_limit(model.model_id) -
-
+        token_limit = get_token_limit(model.model_id) - token_limit_penalty
+        logging.info(f"Token limit for {model.model_id}: {token_limit}")
+        encoding = encoding_for_model(self.embedding_model_name)
 
         def truncate_text(text: str) -> str:
             # split into tokens every 1000 chars:
-            parts = [text[i : i +
+            parts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
             truncated = render_formatted_text(
                 lambda x: "".join(x),
                 parts,
@@ -140,5 +148,5 @@ class LLMIndexer(Indexer):
             embeddings_collection.commit()
         else:
             logger.info(f"Embedding {len(texts)} texts")
-            embeddings = model.embed_multi(texts)
+            embeddings = list(model.embed_multi(texts, batch_size=1))
         return [np.array(v, dtype=float) for v in embeddings]
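
The reworked `texts_to_vectors` splits each text into `CHUNK_SIZE`-character pieces and uses a tiktoken encoding plus `get_token_limit`/`render_formatted_text` to keep the embedded text under the model's token budget. A rough standalone sketch of that truncation idea (not the package's `render_formatted_text`; the 8000-token budget here is an assumption):

```python
from tiktoken import encoding_for_model

CHUNK_SIZE = 1000  # characters per chunk, matching the constant added in this release


def truncate_to_token_limit(text: str, model_name: str = "text-embedding-ada-002", token_limit: int = 8000) -> str:
    # Keep whole chunks until adding the next one would exceed the token budget.
    encoding = encoding_for_model(model_name)
    parts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
    kept, used = [], 0
    for part in parts:
        n_tokens = len(encoding.encode(part))
        if used + n_tokens > token_limit:
            break
        kept.append(part)
        used += n_tokens
    return "".join(kept)
```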
linkml_store/index/indexer.py
CHANGED
@@ -154,8 +154,11 @@ class Indexer(BaseModel):
         return str(obj)
 
     def search(
-        self,
-
+        self,
+        query: str,
+        vectors: List[Tuple[str, INDEX_ITEM]],
+        limit: Optional[int] = None,
+        mmr_relevance_factor: Optional[float] = None,
     ) -> List[Tuple[float, Any]]:
         """
         Use the indexer to search against a database of vectors.
@@ -175,8 +178,8 @@ class Indexer(BaseModel):
         vlist = [v for _, v in vectors]
         idlist = [id for id, _ in vectors]
         sorted_indices = mmr_diversified_search(
-            query_vector, vlist,
-
+            query_vector, vlist, relevance_factor=mmr_relevance_factor, top_n=limit
+        )
         results = []
         # TODO: this is inefficient when limit is high
         for i in range(limit):
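
`Indexer.search` now passes `limit` and `mmr_relevance_factor` through to `mmr_diversified_search` (defined in `linkml_store/utils/vector_utils.py`, also touched in this release). As background, a generic maximal-marginal-relevance ranking over cosine similarities looks roughly like the sketch below; this is illustrative and not the package's implementation:

```python
from typing import List

import numpy as np


def mmr_rank(query: np.ndarray, vectors: List[np.ndarray], relevance_factor: float = 0.8, top_n: int = 10) -> List[int]:
    # Greedy MMR: trade similarity to the query against similarity to items already selected.
    def cos(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

    candidates = list(range(len(vectors)))
    selected: List[int] = []
    while candidates and len(selected) < top_n:
        def mmr_score(i):
            relevance = cos(query, vectors[i])
            redundancy = max((cos(vectors[i], vectors[j]) for j in selected), default=0.0)
            return relevance_factor * relevance - (1 - relevance_factor) * redundancy

        best = max(candidates, key=mmr_score)
        selected.append(best)
        candidates.remove(best)
    return selected
```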

linkml_store/inference/implementations/llm_inference_engine.py
CHANGED

@@ -79,21 +79,24 @@ class LLMInferenceEngine(InferenceEngine):
     def _schema_str(self) -> str:
         db = self.training_data.base_collection.parent
         from linkml_runtime.dumpers import json_dumper
+
         schema_dict = json_dumper.to_dict(db.schema_view.schema)
         return yaml.dump(schema_dict)
 
-    def derive(
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[LLMInference]:
         import llm
 
         model: llm.Model = self.model
-        #model_name = self.config.llm_config.model_name
-        #feature_attributes = self.config.feature_attributes
+        # model_name = self.config.llm_config.model_name
+        # feature_attributes = self.config.feature_attributes
         target_attributes = self.config.target_attributes
         query_text = self.object_to_text(object)
 
         if not target_attributes:
             target_attributes = [k for k, v in object.items() if v is None or v == ""]
-        #if not feature_attributes:
+        # if not feature_attributes:
         #     feature_attributes = [k for k, v in object.items() if v is not None and v != ""]
 
         system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
@@ -107,7 +110,9 @@ class LLMInferenceEngine(InferenceEngine):
             "```yaml\n"
             f"{stub}\n"
             "```\n"
-            "---\nQuery:\n"
+            "---\nQuery:\n"
+            f"## INCOMPLETE OBJECT:\n{query_text}\n"
+            "## OUTPUT:\n"
         )
         logger.info(f"Prompt: {prompt}")
         response = model.prompt(prompt, system=system_prompt)
@@ -130,9 +135,8 @@ class LLMInferenceEngine(InferenceEngine):
                 "\nThis was invalid.\n",
                 "Validation errors:\n",
             ] + [self.object_to_text(e) for e in errs]
-            return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
-        return LLMInference(predicted_object=predicted_object, iterations=iteration+1, query=object)
-
+            return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return LLMInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)
 
     def export_model(
         self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
@@ -149,4 +153,4 @@ class LLMInferenceEngine(InferenceEngine):
 
     @classmethod
     def load_model(cls, file_path: Union[str, Path]) -> "LLMInferenceEngine":
-        raise NotImplementedError("Does not make sense for this engine")
+        raise NotImplementedError("Does not make sense for this engine")
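
The `derive` reformatting keeps the engine's retry pattern intact: the model's YAML answer is parsed and validated, and on validation errors the error text is appended to the prompt and `derive` recurses with `iteration + 1`. Schematically (a simplified standalone sketch, with hypothetical `prompt_model` and `validate` callables rather than the class's own methods):

```python
from typing import Callable, Dict, List, Optional

import yaml


def derive_with_retries(
    obj: Dict,
    prompt_model: Callable[[Dict, List[str]], str],
    validate: Callable[[Dict], List[str]],
    max_iterations: int = 3,
) -> Optional[Dict]:
    # Ask the model, validate the parsed YAML, and re-ask with the errors until it validates.
    extra_texts: List[str] = []
    for _ in range(max_iterations):
        predicted = yaml.safe_load(prompt_model(obj, extra_texts))
        errs = validate(predicted)
        if not errs:
            return predicted
        extra_texts = ["Your previous attempt was invalid.\n", "Validation errors:\n"] + errs
    return None
```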

linkml_store/inference/implementations/rag_inference_engine.py
CHANGED

@@ -111,7 +111,9 @@ class RAGInferenceEngine(InferenceEngine):
     def object_to_text(self, object: OBJECT) -> str:
         return yaml.dump(object)
 
-    def derive(
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[RAGInference]:
         import llm
         from tiktoken import encoding_for_model
 
@@ -131,8 +133,9 @@ class RAGInferenceEngine(InferenceEngine):
         if not self.rag_collection.indexers:
             raise ValueError("RAG collection must have an indexer attached")
         logger.info(f"Searching {self.rag_collection.alias} for examples for: {query_text}")
-        rs = self.rag_collection.search(
-
+        rs = self.rag_collection.search(
+            query_text, limit=num_examples, index_name="llm", mmr_relevance_factor=mmr_relevance_factor
+        )
         examples = rs.rows
         logger.info(f"Found {len(examples)} examples")
         if not examples:
@@ -153,11 +156,11 @@ class RAGInferenceEngine(InferenceEngine):
             input_obj_text = self.object_to_text(input_obj)
             if input_obj_text == query_text:
                 continue
-                #raise ValueError(
+                # raise ValueError(
                 # f"Query object {query_text} is the same as example object {input_obj_text}\n"
                 # "This indicates possible test data leakage\n."
                 # "TODO: allow an option that allows user to treat this as a basic lookup\n"
-                #)
+                # )
             output_obj = select_nested(example, target_attributes)
             prompt_clause = (
                 "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
@@ -176,9 +179,9 @@ class RAGInferenceEngine(InferenceEngine):
         except KeyError:
             encoding = encoding_for_model("gpt-4")
         token_limit = get_token_limit(model_name)
-        prompt = render_formatted_text(
-
-
+        prompt = render_formatted_text(
+            make_text, values=prompt_clauses, encoding=encoding, token_limit=token_limit, additional_text=system_prompt
+        )
         logger.info(f"Prompt: {prompt}")
         response = model.prompt(prompt, system=system_prompt)
         yaml_str = response.text()
@@ -199,8 +202,8 @@ class RAGInferenceEngine(InferenceEngine):
                 "\nThis was invalid.\n",
                 "Validation errors:\n",
             ] + [self.object_to_text(e) for e in errs]
-            return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
-        return RAGInference(predicted_object=predicted_object, iterations=iteration+1, query=object)
+            return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return RAGInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)
 
     def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
         if "```" in yaml_str:
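
The RAG path retrieves the nearest examples from the indexed `rag_collection` (now with `mmr_relevance_factor` forwarded to the search) and formats each one as an INPUT/OUTPUT clause ahead of the query object. A minimal sketch of that prompt assembly, assuming the retrieved rows are plain dicts and splitting each example on the target attributes (the helper name and the input/output split are illustrative, not the engine's exact logic):

```python
from typing import Dict, List

import yaml


def build_rag_prompt(query_obj: Dict, examples: List[Dict], target_attributes: List[str]) -> str:
    # Format retrieved examples as INPUT/OUTPUT pairs, then append the incomplete query object.
    clauses = []
    for example in examples:
        output_obj = {k: v for k, v in example.items() if k in target_attributes}
        input_obj = {k: v for k, v in example.items() if k not in target_attributes}
        clauses.append(
            "---\nExample:\n"
            f"## INPUT:\n{yaml.dump(input_obj)}\n"
            f"## OUTPUT:\n{yaml.dump(output_obj)}\n"
        )
    clauses.append("---\nQuery:\n" f"## INPUT:\n{yaml.dump(query_obj)}\n" "## OUTPUT:\n")
    return "".join(clauses)
```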

linkml_store/inference/implementations/sklearn_inference_engine.py
CHANGED

@@ -94,6 +94,8 @@ class SklearnInferenceEngine(InferenceEngine):
         if not feature_cols:
             feature_cols = df.columns.difference(target_cols).tolist()
             self.config.feature_attributes = feature_cols
+        if not feature_cols:
+            raise ValueError("No features found in the data")
         target_col = target_cols[0]
         logger.info(f"Feature columns: {feature_cols}")
         X = df[feature_cols].copy()
@@ -102,6 +104,8 @@ class SklearnInferenceEngine(InferenceEngine):
 
         # find list of features to skip (categorical with > N categories)
         skip_features = []
+        if not len(X.columns):
+            raise ValueError("No features to train on")
         for col in X.columns:
             unique_values = self._get_unique_values(X[col])
             if len(unique_values) > self.maximum_proportion_distinct_features * len(X[col]):
@@ -115,6 +119,8 @@ class SklearnInferenceEngine(InferenceEngine):
 
         # Encode features
         encoded_features = []
+        if not len(X.columns):
+            raise ValueError(f"No features to train on from after skipping {skip_features}")
         for col in X.columns:
             logger.info(f"Checking whether to encode: {col}")
             col_encoder = self._get_encoder(X[col])
@@ -153,7 +159,7 @@ class SklearnInferenceEngine(InferenceEngine):
         y = y_encoder.fit_transform(y.values.ravel())  # Convert to 1D numpy array
         self.transformed_targets = y_encoder.classes_
 
-        # print(f"Fitting model with features: {X.columns}")
+        # print(f"Fitting model with features: {X.columns}, y={y}, X={X}")
         clf = DecisionTreeClassifier(random_state=42)
         clf.fit(X, y)
         self.classifier = clf
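
The new guards make the sklearn engine fail fast when no usable feature columns remain (for example after high-cardinality categoricals are skipped). The surrounding training flow shown in these hunks, encoding categorical columns, label-encoding the target, and fitting a decision tree, is roughly as sketched below; the encoding choice and column handling here are illustrative, not the engine's exact logic:

```python
from typing import List

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier


def fit_decision_tree(df: pd.DataFrame, feature_cols: List[str], target_col: str) -> DecisionTreeClassifier:
    # Label-encode object-typed feature columns and the target, then fit a tree.
    if not feature_cols:
        raise ValueError("No features found in the data")
    X = df[feature_cols].copy()
    for col in X.columns:
        if X[col].dtype == object:
            X[col] = LabelEncoder().fit_transform(X[col].astype(str))
    y = LabelEncoder().fit_transform(df[target_col].astype(str))
    clf = DecisionTreeClassifier(random_state=42)
    clf.fit(X, y)
    return clf
```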

linkml_store/inference/inference_config.py
CHANGED

@@ -59,6 +59,7 @@ class Inference(BaseModel, extra="forbid"):
     """
     Result of an inference derivation.
     """
+
     query: Optional[OBJECT] = Field(default=None, description="The query object.")
     predicted_object: OBJECT = Field(..., description="The predicted object.")
     confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)