linkml-store 0.2.6__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of linkml-store might be problematic.

Files changed (35)
  1. linkml_store/api/client.py +2 -3
  2. linkml_store/api/collection.py +63 -8
  3. linkml_store/api/database.py +20 -3
  4. linkml_store/api/stores/duckdb/duckdb_collection.py +168 -4
  5. linkml_store/api/stores/duckdb/duckdb_database.py +5 -5
  6. linkml_store/api/stores/filesystem/__init__.py +1 -1
  7. linkml_store/api/stores/filesystem/filesystem_database.py +1 -1
  8. linkml_store/api/stores/mongodb/mongodb_collection.py +132 -15
  9. linkml_store/api/stores/mongodb/mongodb_database.py +2 -1
  10. linkml_store/api/stores/neo4j/neo4j_database.py +1 -1
  11. linkml_store/api/stores/solr/solr_collection.py +107 -18
  12. linkml_store/cli.py +201 -21
  13. linkml_store/index/implementations/llm_indexer.py +13 -6
  14. linkml_store/index/indexer.py +9 -5
  15. linkml_store/inference/implementations/llm_inference_engine.py +15 -13
  16. linkml_store/inference/implementations/rag_inference_engine.py +13 -10
  17. linkml_store/inference/implementations/sklearn_inference_engine.py +7 -1
  18. linkml_store/inference/inference_config.py +2 -1
  19. linkml_store/inference/inference_engine.py +1 -1
  20. linkml_store/plotting/__init__.py +5 -0
  21. linkml_store/plotting/cli.py +172 -0
  22. linkml_store/plotting/heatmap.py +356 -0
  23. linkml_store/utils/dat_parser.py +95 -0
  24. linkml_store/utils/enrichment_analyzer.py +217 -0
  25. linkml_store/utils/format_utils.py +124 -3
  26. linkml_store/utils/llm_utils.py +4 -2
  27. linkml_store/utils/object_utils.py +9 -3
  28. linkml_store/utils/pandas_utils.py +1 -1
  29. linkml_store/utils/sql_utils.py +1 -1
  30. linkml_store/utils/vector_utils.py +3 -10
  31. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/METADATA +3 -1
  32. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/RECORD +35 -30
  33. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/WHEEL +1 -1
  34. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/LICENSE +0 -0
  35. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/entry_points.txt +0 -0
linkml_store/api/stores/neo4j/neo4j_database.py CHANGED
@@ -27,7 +27,7 @@ class Neo4jDatabase(Database):
         if handle is None:
             handle = "bolt://localhost:7687/neo4j"
         if handle.startswith("neo4j:"):
-            handle = handle.replace("neo4j:", "bolt:")
+            handle = handle.replace("neo4j:", "bolt:", 1)
         super().__init__(handle=handle, **kwargs)
 
     @property
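
Note on the change above: the added count argument to str.replace limits the rewrite to the scheme prefix. A generic illustration of that Python behaviour (not code from the package):

    s = "neo4j:one neo4j:two"
    s.replace("neo4j:", "bolt:")      # 'bolt:one bolt:two' (every occurrence)
    s.replace("neo4j:", "bolt:", 1)   # 'bolt:one neo4j:two' (first occurrence only)
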
linkml_store/api/stores/solr/solr_collection.py CHANGED
@@ -2,7 +2,7 @@
 
 import logging
 from copy import copy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Tuple
 
 import requests
 
@@ -56,32 +56,121 @@ class SolrCollection(Collection):
         response.raise_for_status()
 
         data = response.json()
+        logger.debug(f"Response: {data}")
         num_rows = data["response"]["numFound"]
         rows = data["response"]["docs"]
 
         return QueryResult(query=query, num_rows=num_rows, rows=rows)
 
     def query_facets(
-        self, where: Optional[Dict] = None, facet_columns: List[str] = None, facet_limit=DEFAULT_FACET_LIMIT, **kwargs
-    ) -> Dict[str, Dict[str, int]]:
+        self,
+        where: Optional[Dict] = None,
+        facet_columns: List[Union[str, Tuple[str, ...]]] = None,
+        facet_limit=DEFAULT_FACET_LIMIT,
+        facet_min_count: int = 1,
+        **kwargs,
+    ) -> Dict[Union[str, Tuple[str, ...]], List[Tuple[Any, int]]]:
+        """
+        Query facet counts for fields or field combinations.
+
+        :param where: Filter conditions
+        :param facet_columns: List of fields to facet on. Elements can be:
+            - Simple strings for single field facets
+            - Tuples of strings for field combinations (pivot facets)
+        :param facet_limit: Maximum number of facet values to return
+        :param facet_min_count: Minimum count for facet values to be included
+        :return: Dictionary mapping fields or field tuples to lists of (value, count) tuples
+        """
         solr_query = self._build_solr_query(where)
-        solr_query["facet"] = "true"
-        solr_query["facet.field"] = facet_columns
-        solr_query["facet.limit"] = facet_limit
-
-        logger.info(f"Querying Solr collection {self.alias} for facets with query: {solr_query}")
-
-        response = requests.get(f"{self._collection_base}/select", params=solr_query)
-        response.raise_for_status()
-
-        data = response.json()
-        facet_counts = data["facet_counts"]["facet_fields"]
-
+
+        # Separate single fields and tuple fields
+        single_fields = []
+        tuple_fields = []
+
+        if facet_columns:
+            for field in facet_columns:
+                if isinstance(field, str):
+                    single_fields.append(field)
+                elif isinstance(field, tuple):
+                    tuple_fields.append(field)
+
+        # Process regular facets
         results = {}
-        for facet_field, counts in facet_counts.items():
-            results[facet_field] = list(zip(counts[::2], counts[1::2]))
-
+        if single_fields:
+            solr_query["facet"] = "true"
+            solr_query["facet.field"] = single_fields
+            solr_query["facet.limit"] = facet_limit
+            solr_query["facet.mincount"] = facet_min_count
+
+            logger.info(f"Querying Solr collection {self.alias} for facets with query: {solr_query}")
+            response = requests.get(f"{self._collection_base}/select", params=solr_query)
+            response.raise_for_status()
+
+            data = response.json()
+            facet_counts = data["facet_counts"]["facet_fields"]
+
+            for facet_field, counts in facet_counts.items():
+                results[facet_field] = list(zip(counts[::2], counts[1::2]))
+
+        # Process pivot facets for tuple fields
+        if tuple_fields:
+            # TODO: Add a warning if Solr < 4.0, when this was introduced
+            for field_tuple in tuple_fields:
+                # Create a query for this specific field tuple
+                pivot_query = self._build_solr_query(where)
+                pivot_query["facet"] = "true"
+
+                # Create pivot facet
+                field_str = ','.join(field_tuple)
+                pivot_query["facet.pivot"] = field_str
+                pivot_query["facet.pivot.mincount"] = facet_min_count
+                pivot_query["facet.limit"] = facet_limit
+
+                logger.info(f"Querying Solr collection {self.alias} for pivot facets with query: {pivot_query}")
+                response = requests.get(f"{self._collection_base}/select", params=pivot_query)
+                response.raise_for_status()
+
+                data = response.json()
+                pivot_facets = data.get("facet_counts", {}).get("facet_pivot", {})
+
+                # Process pivot facets into the same format as MongoDB results
+                field_str = ','.join(field_tuple)
+                pivot_data = pivot_facets.get(field_str, [])
+
+                # Build a list of tuples (field values, count)
+                pivot_results = []
+                self._process_pivot_facets(pivot_data, [], pivot_results, field_tuple)
+
+                results[field_tuple] = pivot_results
+
         return results
+
+    def _process_pivot_facets(self, pivot_data, current_values, results, field_tuple):
+        """
+        Recursively process pivot facet results to extract combinations of field values.
+
+        :param pivot_data: The pivot facet data from Solr
+        :param current_values: The current path of values in the recursion
+        :param results: The result list to populate
+        :param field_tuple: The original field tuple for reference
+        """
+        for item in pivot_data:
+            # Add the current field value
+            value = item.get("value")
+            count = item.get("count", 0)
+
+            # Update the current path with this value
+            values = current_values + [value]
+
+            # If we have all the fields from the tuple, add a result
+            if len(values) == len(field_tuple):
+                # Create a tuple of values corresponding to the field tuple
+                results.append((tuple(values), count))
+
+            # Process child pivot fields recursively
+            pivot = item.get("pivot", [])
+            if pivot and len(values) < len(field_tuple):
+                self._process_pivot_facets(pivot, values, results, field_tuple)
 
     def _build_solr_query(
         self, query: Union[Query, Dict], search_term="*:*", extra: Optional[Dict] = None
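
To make the new query_facets contract concrete, here is a hedged usage sketch against a hypothetical SolrCollection; the where clause and field names are made up, but the return shape follows the docstring above, and the nested structure in the comments is the standard Solr facet_pivot response that _process_pivot_facets walks:

    results = collection.query_facets(
        where={"type": "gene"},
        facet_columns=["species", ("species", "chromosome")],
        facet_limit=10,
    )
    # results["species"]                 -> [("Homo sapiens", 120), ("Mus musculus", 98), ...]
    # results[("species", "chromosome")] -> [(("Homo sapiens", "1"), 14), (("Homo sapiens", "2"), 11), ...]
    #
    # A facet_pivot entry such as
    #   {"value": "Homo sapiens", "count": 120,
    #    "pivot": [{"value": "1", "count": 14}, {"value": "2", "count": 11}]}
    # is flattened by _process_pivot_facets into
    #   [(("Homo sapiens", "1"), 14), (("Homo sapiens", "2"), 11)]
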
linkml_store/cli.py CHANGED
@@ -3,7 +3,7 @@ import sys
 import warnings
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional, Tuple, Any
+from typing import Any, Optional, Tuple
 
 import click
 import yaml
@@ -37,6 +37,11 @@ index_type_option = click.option(
     show_default=True,
     help="Type of index to create. Values: simple, llm",
 )
+json_select_query_option = click.option(
+    "--json-select-query",
+    "-J",
+    help="JSON SELECT query",
+)
 
 logger = logging.getLogger(__name__)
 
@@ -136,7 +141,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
     if input:
-        database = "duckdb" # default: store in duckdb
+        database = "duckdb"  # default: store in duckdb
         if input.startswith("http"):
             parts = input.split("/")
             collection = parts[-1]
@@ -144,8 +149,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         else:
             stem = underscore(Path(input).stem)
             collection = stem
-        logger.info(f"Using input file: {input}, "
-                    f"default storage is {database} and collection is {collection}")
+        logger.info(f"Using input file: {input}, " f"default storage is {database} and collection is {collection}")
         config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
     if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
         config = DEFAULT_LOCAL_CONF_PATH
@@ -185,13 +189,25 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         settings.collection_name = collection.alias
 
 
+@cli.command()
+@click.pass_context
+def drop(ctx):
+    """
+    Drop database and all its collections.
+    """
+    database = ctx.obj["settings"].database
+    database.drop()
+
+
 @cli.command()
 @click.argument("files", type=click.Path(), nargs=-1)
 @click.option("--replace/--no-replace", default=False, show_default=True, help="Replace existing objects")
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@click.option("--source-field", help="If provided, inject file path source as this field")
+@json_select_query_option
 @click.pass_context
-def insert(ctx, files, replace, object, format):
+def insert(ctx, files, replace, object, format, source_field, json_select_query):
     """Insert objects from files (JSON, YAML, TSV) into the specified collection.
 
     Using a configuration:
@@ -207,11 +223,17 @@ def insert(ctx, files, replace, object, format):
         raise ValueError("Collection must be specified.")
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
+        if source_field:
+            for obj in objects:
+                obj[source_field] = str(file_path)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into collection '{collection.alias}'.")
         if replace:
             collection.replace(objects)
@@ -222,6 +244,8 @@ def insert(ctx, files, replace, object, format):
     for object_str in object:
         logger.info(f"Parsing: {object_str}")
         objects = yaml.safe_load(object_str)
+        if not isinstance(objects, list):
+            objects = [objects]
         if replace:
             collection.replace(objects)
         else:
@@ -234,21 +258,41 @@
 @click.argument("files", type=click.Path(exists=True), nargs=-1)
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@json_select_query_option
 @click.pass_context
-def store(ctx, files, object, format):
+def store(ctx, files, object, format, json_select_query):
     """Store objects from files (JSON, YAML, TSV) into the database.
 
-    Note: this is similar to insert, but a collection does not need to be specified
+    Note: this is similar to insert, but a collection does not need to be specified.
+
+    For example, assume that `my-collection` is a dict with multiple keys,
+    and we want one collection per key:
+
+        linkml-store -d my.ddb store my-collection.yaml
+
+    Loading JSON (e.g OBO-JSON), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J graphs cl.obo.json
+
+    Loading XML (e.g OWL-XML), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J Ontology cl.owx
+
+    Because the XML uses a top level Ontology, with multiple
+
     """
     settings = ctx.obj["settings"]
     db = settings.database
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into database '{db}'.")
         for obj in objects:
             db.store(obj)
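
The --json-select-query value is passed straight through to load_objects as select_query, so the `-J graphs` example above amounts to selecting a sub-structure of the parsed document before storing. A sketch with made-up data (the exact selection semantics live in format_utils.load_objects):

    # Minimal OBO-JSON-like wrapper; illustrative only
    doc = {"graphs": [{"id": "g1", "nodes": []}, {"id": "g2", "nodes": []}]}

    # Without -J the single wrapper dict would be stored as one object;
    # with -J graphs the graph objects under that key are stored instead,
    # roughly: objects = load_objects("cl.obo.json", select_query="graphs")
    objects = doc["graphs"]
    assert [o["id"] for o in objects] == ["g1", "g2"]
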
@@ -422,15 +466,32 @@ def list_collections(ctx, **kwargs):
 
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
-@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--facet-min-count", "-M", type=click.INT, help="Minimum count for a facet to be included")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
-@click.option("--columns", "-S", help="Columns to facet on")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
 @click.option("--wide/--no-wide", "-U/--no-U", default=False, show_default=True, help="Wide table")
 @click.pass_context
-def fq(ctx, where, limit, columns, output_type, wide, output):
+def fq(ctx, where, limit, columns, output_type, wide, output, **kwargs):
     """
-    Query facets from the specified collection.
+    Query facet counts from the specified collection.
+
+    Assuming your .linkml.yaml includes an entry mapping `phenopackets` to a
+    mongodb
+
+    Facet counts (all columns)
+
+        linkml-store -d phenopackets fq
+
+    Nested columns:
+
+        linkml-store -d phenopackets fq subject.timeAtLastEncounter.age
+
+    Compound keys:
+
+        linkml-store -d phenopackets fq subject.sex+subject.timeAtLastEncounter.age
+
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
@@ -439,7 +500,7 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         columns = [col.strip() for col in columns]
         columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
     logger.info(f"Faceting on columns: {columns}")
-    results = collection.query_facets(where_clause, facet_columns=columns, limit=limit)
+    results = collection.query_facets(where_clause, facet_columns=columns, facet_limit=limit, **kwargs)
     logger.info(f"Facet results: {results}")
 
     def _untuple(key):
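
For reference, the -S parsing above turns a comma-separated column spec into a mix of plain names and tuples, which is exactly what the reworked query_facets accepts. A small sketch, reusing field names from the phenopackets examples:

    spec = "subject.sex,subject.sex+subject.timeAtLastEncounter.age"
    cols = [c.strip() for c in spec.split(",")]
    cols = [(tuple(c.split("+")) if "+" in c else c) for c in cols]
    # cols == ["subject.sex", ("subject.sex", "subject.timeAtLastEncounter.age")]
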
@@ -471,6 +532,56 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         click.echo(output_data)
 
 
+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
+@click.pass_context
+def groupby(ctx, where, limit, columns, output_type, output, **kwargs):
+    """
+    Group by columns in the specified collection.
+
+    Assume a simple triple model:
+
+        linkml-store -d cl.ddb -c triple insert cl.owl
+
+    This makes a flat subject/predicate/object table
+
+    This can be grouped, e.g by subject:
+
+        linkml-store -d cl.ddb -c triple groupby -s subject
+
+    Or subject and predicate:
+
+        linkml-store -d cl.ddb -c triple groupby -s '[subject,predicate]'
+
+    """
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    columns = columns.split(",") if columns else None
+    if columns:
+        columns = [col.strip() for col in columns]
+        columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
+    logger.info(f"Group by: {columns}")
+    result = collection.group_by(
+        group_by_fields=columns,
+        where_clause=where_clause,
+        agg_map={},
+        limit=limit,
+        **kwargs,
+    )
+    logger.info(f"Group by results: {result}")
+    output_data = render_output(result.rows, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Query results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 def _get_index(index_type=None, **kwargs) -> Indexer:
     if index_type is None or index_type == "simple":
         return SimpleIndexer(name="test", **kwargs)
@@ -519,10 +630,12 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
             value_key = tuple([row.get(att) for att in value_atts])
             pivoted[index_key][column_key] = value_key
     pivoted_objs = []
+
     def detuple(t: Tuple) -> Any:
         if len(t) == 1:
             return t[0]
         return str(t)
+
     for index_key, data in pivoted.items():
         obj = {att: key for att, key in zip(index_atts, index_key)}
         for column_key, value_key in data.items():
@@ -531,6 +644,57 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
     write_output(pivoted_objs, output_type, target=output)
 
 
+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--sample-field", "-I", help="Field to use as the sample identifier")
+@click.option("--classification-field", "-L", help="Field to use as for classification")
+@click.option(
+    "--p-value-threshold",
+    "-P",
+    type=click.FLOAT,
+    default=0.05,
+    show_default=True,
+    help="P-value threshold for enrichment",
+)
+@click.option(
+    "--multiple-testing-correction",
+    "-M",
+    type=click.STRING,
+    default="bh",
+    show_default=True,
+    help="Multiple test correction method",
+)
+@click.argument("samples", type=click.STRING, nargs=-1)
+@click.pass_context
+def enrichment(ctx, where, limit, output_type, output, sample_field, classification_field, samples, **kwargs):
+    from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer
+
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    column_atts = [sample_field, classification_field]
+    results = collection.find(where_clause, select_cols=column_atts, limit=-1)
+    df = results.rows_dataframe
+    ea = EnrichmentAnalyzer(df, sample_key=sample_field, classification_key=classification_field)
+    if not samples:
+        samples = df[sample_field].unique()
+    enrichment_results = []
+    for sample in samples:
+        enriched = ea.find_enriched_categories(sample, **kwargs)
+        for e in enriched:
+            obj = {"sample": sample, **e.model_dump()}
+            enrichment_results.append(obj)
+    output_data = render_output(enrichment_results, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Search results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
@@ -538,7 +702,7 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
 @click.option(
     "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
 )
-@click.option("--training-collection", type=click.STRING,help="Collection to use for training")
+@click.option("--training-collection", type=click.STRING, help="Collection to use for training")
 @click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
 @click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
 @click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
@@ -753,12 +917,28 @@ def indexes(ctx):
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option(
+    "--collection-only/--no-collection-only",
+    default=False,
+    show_default=True,
+    help="Only validate specified collection",
+)
+@click.option(
+    "--ensure-referential-integrity/--no-ensure-referential-integrity",
+    default=True,
+    show_default=True,
+    help="Ensure referential integrity",
+)
 @click.pass_context
-def validate(ctx, output_type, output):
+def validate(ctx, output_type, output, collection_only, **kwargs):
     """Validate objects in the specified collection."""
-    collection = ctx.obj["settings"].collection
-    logger.info(f"Validating collection {collection.alias}")
-    validation_results = [json_dumper.to_dict(x) for x in collection.iter_validate_collection()]
+    if collection_only:
+        collection = ctx.obj["settings"].collection
+        logger.info(f"Validating collection {collection.alias}")
+        validation_results = [json_dumper.to_dict(x) for x in collection.iter_validate_collection(**kwargs)]
+    else:
+        db = ctx.obj["settings"].database
+        validation_results = [json_dumper.to_dict(x) for x in db.validate_database(**kwargs)]
     output_data = render_output(validation_results, output_type)
     if output:
         with open(output, "w") as f:
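
With these options, validation now defaults to checking the whole database (including referential-integrity checks) rather than a single collection. A hedged pair of invocations, reusing a database alias from the examples above:

    linkml-store -d phenopackets validate
    linkml-store -d phenopackets -c phenopackets validate --collection-only --no-ensure-referential-integrity
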
linkml_store/index/implementations/llm_indexer.py CHANGED
@@ -11,6 +11,7 @@ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
 if TYPE_CHECKING:
     import llm
 
+CHUNK_SIZE = 1000
 
 logger = logging.getLogger(__name__)
 
@@ -25,7 +26,7 @@ class LLMIndexer(Indexer):
     >>> vector = indexer.text_to_vector("hello")
     """
 
-    embedding_model_name: str = "ada-002"
+    embedding_model_name: str = "text-embedding-ada-002"
     _embedding_model: "llm.EmbeddingModel" = None
     cached_embeddings_database: str = None
     cached_embeddings_collection: str = None
@@ -52,7 +53,9 @@
         """
         return self.texts_to_vectors([text], cache=cache, **kwargs)[0]
 
-    def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
+    def texts_to_vectors(
+        self, texts: List[str], cache: bool = None, token_limit_penalty=0, **kwargs
+    ) -> List[INDEX_ITEM]:
         """
         Use LLM to embed.
 
@@ -60,18 +63,22 @@
         >>> vectors = indexer.texts_to_vectors(["hello", "goodbye"])
 
         :param texts:
+        :param cache:
+        :param token_limit_penalty:
         :return:
         """
         from tiktoken import encoding_for_model
+
         logging.info(f"Converting {len(texts)} texts to vectors")
         model = self.embedding_model
         # TODO: make this more accurate
-        token_limit = get_token_limit(model.model_id) - 200
-        encoding = encoding_for_model("gpt-4o")
+        token_limit = get_token_limit(model.model_id) - token_limit_penalty
+        logging.info(f"Token limit for {model.model_id}: {token_limit}")
+        encoding = encoding_for_model(self.embedding_model_name)
 
         def truncate_text(text: str) -> str:
             # split into tokens every 1000 chars:
-            parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
+            parts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
             truncated = render_formatted_text(
                 lambda x: "".join(x),
                 parts,
@@ -140,5 +147,5 @@
             embeddings_collection.commit()
         else:
             logger.info(f"Embedding {len(texts)} texts")
-            embeddings = model.embed_multi(texts)
+            embeddings = list(model.embed_multi(texts, batch_size=1))
         return [np.array(v, dtype=float) for v in embeddings]
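
The token-limit handling above leans on tiktoken. A small standalone sketch of the same idea (count tokens for the embedding model and clip to a budget), using only tiktoken calls that exist as written; the budget value is made up, and this is not the package's render_formatted_text logic:

    import tiktoken

    enc = tiktoken.encoding_for_model("text-embedding-ada-002")  # resolves to cl100k_base

    def truncate_to_tokens(text: str, token_budget: int = 8000) -> str:
        # Encode, clip to the budget, and decode back to text
        tokens = enc.encode(text)
        if len(tokens) <= token_budget:
            return text
        return enc.decode(tokens[:token_budget])
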
linkml_store/index/indexer.py CHANGED
@@ -3,9 +3,10 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
-from linkml_store.utils.vector_utils import pairwise_cosine_similarity, mmr_diversified_search
 from pydantic import BaseModel
 
+from linkml_store.utils.vector_utils import mmr_diversified_search, pairwise_cosine_similarity
+
 INDEX_ITEM = np.ndarray
 
 logger = logging.getLogger(__name__)
@@ -154,8 +155,11 @@
         return str(obj)
 
     def search(
-        self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None,
-        mmr_relevance_factor: Optional[float] = None
+        self,
+        query: str,
+        vectors: List[Tuple[str, INDEX_ITEM]],
+        limit: Optional[int] = None,
+        mmr_relevance_factor: Optional[float] = None,
     ) -> List[Tuple[float, Any]]:
         """
         Use the indexer to search against a database of vectors.
@@ -175,8 +179,8 @@
         vlist = [v for _, v in vectors]
         idlist = [id for id, _ in vectors]
         sorted_indices = mmr_diversified_search(
-            query_vector, vlist,
-            relevance_factor=mmr_relevance_factor, top_n=limit)
+            query_vector, vlist, relevance_factor=mmr_relevance_factor, top_n=limit
+        )
         results = []
         # TODO: this is inefficient when limit is high
         for i in range(limit):
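
mmr_diversified_search itself is not shown in this diff; for orientation, maximal marginal relevance (which it is named after) greedily picks the candidate that best balances relevance to the query against similarity to what has already been selected. A generic sketch of that heuristic (not the vector_utils implementation):

    import numpy as np

    def mmr_sketch(query_vec, candidate_vecs, relevance_factor=0.8, top_n=5):
        # Greedy MMR: at each step maximize
        #   relevance_factor * sim(query, c) - (1 - relevance_factor) * max sim(c, selected)
        def cos(a, b):
            return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

        remaining, selected = list(range(len(candidate_vecs))), []
        while remaining and len(selected) < top_n:
            def score(i):
                rel = cos(query_vec, candidate_vecs[i])
                div = max((cos(candidate_vecs[i], candidate_vecs[j]) for j in selected), default=0.0)
                return relevance_factor * rel - (1 - relevance_factor) * div
            best = max(remaining, key=score)
            selected.append(best)
            remaining.remove(best)
        return selected  # candidate indices, ordered relevant-but-diverse first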