linkml-store 0.2.5__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (28)
  1. linkml_store/api/client.py +9 -6
  2. linkml_store/api/collection.py +118 -5
  3. linkml_store/api/database.py +45 -14
  4. linkml_store/api/stores/duckdb/duckdb_collection.py +176 -8
  5. linkml_store/api/stores/duckdb/duckdb_database.py +52 -19
  6. linkml_store/api/stores/filesystem/__init__.py +1 -1
  7. linkml_store/api/stores/mongodb/mongodb_collection.py +186 -0
  8. linkml_store/api/stores/mongodb/mongodb_database.py +8 -3
  9. linkml_store/api/stores/solr/solr_collection.py +7 -1
  10. linkml_store/cli.py +202 -21
  11. linkml_store/index/implementations/llm_indexer.py +14 -6
  12. linkml_store/index/indexer.py +7 -4
  13. linkml_store/inference/implementations/llm_inference_engine.py +13 -9
  14. linkml_store/inference/implementations/rag_inference_engine.py +13 -10
  15. linkml_store/inference/implementations/sklearn_inference_engine.py +7 -1
  16. linkml_store/inference/inference_config.py +1 -0
  17. linkml_store/utils/dat_parser.py +95 -0
  18. linkml_store/utils/enrichment_analyzer.py +217 -0
  19. linkml_store/utils/format_utils.py +183 -3
  20. linkml_store/utils/llm_utils.py +3 -1
  21. linkml_store/utils/pandas_utils.py +1 -1
  22. linkml_store/utils/sql_utils.py +7 -1
  23. linkml_store/utils/vector_utils.py +4 -11
  24. {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/METADATA +4 -3
  25. {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/RECORD +28 -26
  26. {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/WHEEL +1 -1
  27. {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/LICENSE +0 -0
  28. {linkml_store-0.2.5.dist-info → linkml_store-0.2.9.dist-info}/entry_points.txt +0 -0
linkml_store/cli.py CHANGED
@@ -3,6 +3,7 @@ import sys
 import warnings
 from collections import defaultdict
 from pathlib import Path
+from tokenize import group
 from typing import Optional, Tuple, Any

 import click
@@ -37,6 +38,11 @@ index_type_option = click.option(
     show_default=True,
     help="Type of index to create. Values: simple, llm",
 )
+json_select_query_option = click.option(
+    "--json-select-query",
+    "-J",
+    help="JSON SELECT query",
+)

 logger = logging.getLogger(__name__)

@@ -136,7 +142,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         logger.setLevel(logging.ERROR)
     ctx.ensure_object(dict)
     if input:
-        database = "duckdb" # default: store in duckdb
+        database = "duckdb"  # default: store in duckdb
         if input.startswith("http"):
             parts = input.split("/")
             collection = parts[-1]
@@ -144,8 +150,7 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
         else:
            stem = underscore(Path(input).stem)
            collection = stem
-        logger.info(f"Using input file: {input}, "
-                    f"default storage is {database} and collection is {collection}")
+        logger.info(f"Using input file: {input}, " f"default storage is {database} and collection is {collection}")
        config = ClientConfig(databases={"duckdb": {"collections": {stem: {"source": {"local_path": input}}}}})
    if config is None and DEFAULT_LOCAL_CONF_PATH.exists():
        config = DEFAULT_LOCAL_CONF_PATH
@@ -186,12 +191,24 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,


 @cli.command()
-@click.argument("files", type=click.Path(exists=True), nargs=-1)
+@click.pass_context
+def drop(ctx):
+    """
+    Drop database and all its collections.
+    """
+    database = ctx.obj["settings"].database
+    database.drop()
+
+
+@cli.command()
+@click.argument("files", type=click.Path(), nargs=-1)
 @click.option("--replace/--no-replace", default=False, show_default=True, help="Replace existing objects")
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@click.option("--source-field", help="If provided, inject file path source as this field")
+@json_select_query_option
 @click.pass_context
-def insert(ctx, files, replace, object, format):
+def insert(ctx, files, replace, object, format, source_field, json_select_query):
     """Insert objects from files (JSON, YAML, TSV) into the specified collection.

     Using a configuration:
@@ -207,11 +224,17 @@ def insert(ctx, files, replace, object, format):
         raise ValueError("Collection must be specified.")
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
+        if source_field:
+            for obj in objects:
+                obj[source_field] = str(file_path)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into collection '{collection.alias}'.")
         if replace:
             collection.replace(objects)
@@ -222,6 +245,8 @@ def insert(ctx, files, replace, object, format):
     for object_str in object:
         logger.info(f"Parsing: {object_str}")
         objects = yaml.safe_load(object_str)
+        if not isinstance(objects, list):
+            objects = [objects]
         if replace:
             collection.replace(objects)
         else:
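
For orientation, a sketch of how the new insert options compose on the command line; the database, collection, and file names are illustrative, and only the flags themselves come from the hunks above:

    # hypothetical paths; --source-field tags each loaded object with its file of origin,
    # and -J/--json-select-query selects part of the parsed document (see the store examples below)
    linkml-store -d my.ddb -c terms insert --source-field source_file -J graphs cl.obo.json
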
@@ -234,21 +259,41 @@ def insert(ctx, files, replace, object, format):
 @click.argument("files", type=click.Path(exists=True), nargs=-1)
 @click.option("--format", "-f", type=format_choice, help="Input format")
 @click.option("--object", "-i", multiple=True, help="Input object as YAML")
+@json_select_query_option
 @click.pass_context
-def store(ctx, files, object, format):
+def store(ctx, files, object, format, json_select_query):
     """Store objects from files (JSON, YAML, TSV) into the database.

-    Note: this is similar to insert, but a collection does not need to be specified
+    Note: this is similar to insert, but a collection does not need to be specified.
+
+    For example, assume that `my-collection` is a dict with multiple keys,
+    and we want one collection per key:
+
+        linkml-store -d my.ddb store my-collection.yaml
+
+    Loading JSON (e.g OBO-JSON), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J graphs cl.obo.json
+
+    Loading XML (e.g OWL-XML), with a --json-select-query:
+
+        linkml-store -d cl.ddb store -J Ontology cl.owx
+
+    Because the XML uses a top level Ontology, with multiple
+
     """
     settings = ctx.obj["settings"]
     db = settings.database
     if not files and not object:
         files = ["-"]
+    load_objects_args = {}
+    if json_select_query:
+        load_objects_args["select_query"] = json_select_query
     for file_path in files:
         if format:
-            objects = load_objects(file_path, format=format)
+            objects = load_objects(file_path, format=format, **load_objects_args)
         else:
-            objects = load_objects(file_path)
+            objects = load_objects(file_path, **load_objects_args)
         logger.info(f"Inserting {len(objects)} objects from {file_path} into database '{db}'.")
         for obj in objects:
             db.store(obj)
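
The same selection can be driven from Python. A minimal sketch, assuming load_objects accepts the select_query keyword directly, as the CLI wiring above suggests; the file name is illustrative:

    from linkml_store.utils.format_utils import load_objects

    # keep only the `graphs` sub-tree of an OBO-JSON document (file name illustrative)
    objects = load_objects("cl.obo.json", select_query="graphs")
    print(f"loaded {len(objects)} objects")
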
@@ -422,15 +467,32 @@ def list_collections(ctx, **kwargs):

 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
-@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--facet-min-count", "-M", type=click.INT, help="Minimum count for a facet to be included")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
-@click.option("--columns", "-S", help="Columns to facet on")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
 @click.option("--wide/--no-wide", "-U/--no-U", default=False, show_default=True, help="Wide table")
 @click.pass_context
-def fq(ctx, where, limit, columns, output_type, wide, output):
+def fq(ctx, where, limit, columns, output_type, wide, output, **kwargs):
     """
-    Query facets from the specified collection.
+    Query facet counts from the specified collection.
+
+    Assuming your .linkml.yaml includes an entry mapping `phenopackets` to a
+    mongodb
+
+    Facet counts (all columns)
+
+        linkml-store -d phenopackets fq
+
+    Nested columns:
+
+        linkml-store -d phenopackets fq subject.timeAtLastEncounter.age
+
+    Compound keys:
+
+        linkml-store -d phenopackets fq subject.sex+subject.timeAtLastEncounter.age
+
     """
     collection = ctx.obj["settings"].collection
     where_clause = yaml.safe_load(where) if where else None
@@ -439,7 +501,7 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         columns = [col.strip() for col in columns]
         columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
     logger.info(f"Faceting on columns: {columns}")
-    results = collection.query_facets(where_clause, facet_columns=columns, limit=limit)
+    results = collection.query_facets(where_clause, facet_columns=columns, facet_limit=limit, **kwargs)
     logger.info(f"Facet results: {results}")

     def _untuple(key):
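
A minimal Python sketch of the equivalent facet query, assuming a collection attached via the Client as elsewhere in linkml-store; the handle and column names are illustrative:

    from linkml_store import Client

    client = Client()
    db = client.attach_database("duckdb:///phenopackets.ddb")
    collection = db.get_collection("phenopackets")
    # facet_limit caps the number of values reported per facet, mirroring -l above
    facets = collection.query_facets(None, facet_columns=["subject.sex"], facet_limit=10)
    print(facets)
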
@@ -471,6 +533,56 @@ def fq(ctx, where, limit, columns, output_type, wide, output):
         click.echo(output_data)


+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return per facet")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--columns", "-S", help="Columns to facet on. Comma-separated, join combined facets with +")
+@click.pass_context
+def groupby(ctx, where, limit, columns, output_type, output, **kwargs):
+    """
+    Group by columns in the specified collection.
+
+    Assume a simple triple model:
+
+        linkml-store -d cl.ddb -c triple insert cl.owl
+
+    This makes a flat subject/predicate/object table
+
+    This can be grouped, e.g by subject:
+
+        linkml-store -d cl.ddb -c triple groupby -s subject
+
+    Or subject and predicate:
+
+        linkml-store -d cl.ddb -c triple groupby -s '[subject,predicate]'
+
+    """
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    columns = columns.split(",") if columns else None
+    if columns:
+        columns = [col.strip() for col in columns]
+        columns = [(tuple(col.split("+")) if "+" in col else col) for col in columns]
+    logger.info(f"Group by: {columns}")
+    result = collection.group_by(
+        group_by_fields=columns,
+        where_clause=where_clause,
+        agg_map={},
+        limit=limit,
+        **kwargs,
+    )
+    logger.info(f"Group by results: {result}")
+    output_data = render_output(result.rows, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Query results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 def _get_index(index_type=None, **kwargs) -> Indexer:
     if index_type is None or index_type == "simple":
         return SimpleIndexer(name="test", **kwargs)
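
The new group_by call can also be used programmatically. A minimal sketch, assuming a collection attached via the Client as above; the database path and fields are illustrative:

    from linkml_store import Client

    client = Client()
    db = client.attach_database("duckdb:///cl.ddb")
    collection = db.get_collection("triple")
    # group the flat subject/predicate/object table by subject, as in the docstring example
    result = collection.group_by(group_by_fields=["subject"], agg_map={}, where_clause=None)
    for row in result.rows:
        print(row)
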
@@ -519,10 +631,12 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
             value_key = tuple([row.get(att) for att in value_atts])
             pivoted[index_key][column_key] = value_key
     pivoted_objs = []
+
     def detuple(t: Tuple) -> Any:
         if len(t) == 1:
             return t[0]
         return str(t)
+
     for index_key, data in pivoted.items():
         obj = {att: key for att, key in zip(index_atts, index_key)}
         for column_key, value_key in data.items():
@@ -531,6 +645,57 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
     write_output(pivoted_objs, output_type, target=output)


+@cli.command()
+@click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
+@click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option("--sample-field", "-I", help="Field to use as the sample identifier")
+@click.option("--classification-field", "-L", help="Field to use as for classification")
+@click.option(
+    "--p-value-threshold",
+    "-P",
+    type=click.FLOAT,
+    default=0.05,
+    show_default=True,
+    help="P-value threshold for enrichment",
+)
+@click.option(
+    "--multiple-testing-correction",
+    "-M",
+    type=click.STRING,
+    default="bh",
+    show_default=True,
+    help="Multiple test correction method",
+)
+@click.argument("samples", type=click.STRING, nargs=-1)
+@click.pass_context
+def enrichment(ctx, where, limit, output_type, output, sample_field, classification_field, samples, **kwargs):
+    from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer
+
+    collection = ctx.obj["settings"].collection
+    where_clause = yaml.safe_load(where) if where else None
+    column_atts = [sample_field, classification_field]
+    results = collection.find(where_clause, select_cols=column_atts, limit=-1)
+    df = results.rows_dataframe
+    ea = EnrichmentAnalyzer(df, sample_key=sample_field, classification_key=classification_field)
+    if not samples:
+        samples = df[sample_field].unique()
+    enrichment_results = []
+    for sample in samples:
+        enriched = ea.find_enriched_categories(sample, **kwargs)
+        for e in enriched:
+            obj = {"sample": sample, **e.model_dump()}
+            enrichment_results.append(obj)
+    output_data = render_output(enrichment_results, output_type)
+    if output:
+        with open(output, "w") as f:
+            f.write(output_data)
+        click.echo(f"Search results saved to {output}")
+    else:
+        click.echo(output_data)
+
+
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
@@ -538,7 +703,7 @@ def pivot(ctx, where, limit, index, columns, values, output_type, output):
 @click.option(
     "--feature-attributes", "-F", type=click.STRING, help="Feature attributes for inference (comma separated)"
 )
-@click.option("--training-collection", type=click.STRING,help="Collection to use for training")
+@click.option("--training-collection", type=click.STRING, help="Collection to use for training")
 @click.option("--inference-config-file", "-Y", type=click.Path(), help="Path to inference configuration file")
 @click.option("--export-model", "-E", type=click.Path(), help="Export model to file")
 @click.option("--load-model", "-L", type=click.Path(), help="Load model from file")
@@ -753,12 +918,28 @@ def indexes(ctx):
 @cli.command()
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option(
+    "--collection-only/--no-collection-only",
+    default=False,
+    show_default=True,
+    help="Only validate specified collection",
+)
+@click.option(
+    "--ensure-referential-integrity/--no-ensure-referential-integrity",
+    default=True,
+    show_default=True,
+    help="Ensure referential integrity",
+)
 @click.pass_context
-def validate(ctx, output_type, output):
+def validate(ctx, output_type, output, collection_only, **kwargs):
     """Validate objects in the specified collection."""
-    collection = ctx.obj["settings"].collection
-    logger.info(f"Validating collection {collection.alias}")
-    validation_results = [json_dumper.to_dict(x) for x in collection.iter_validate_collection()]
+    if collection_only:
+        collection = ctx.obj["settings"].collection
+        logger.info(f"Validating collection {collection.alias}")
+        validation_results = [json_dumper.to_dict(x) for x in collection.iter_validate_collection(**kwargs)]
+    else:
+        db = ctx.obj["settings"].database
+        validation_results = [json_dumper.to_dict(x) for x in db.validate_database(**kwargs)]
     output_data = render_output(validation_results, output_type)
     if output:
         with open(output, "w") as f:
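
A sketch of the reworked validate command; the database path and collection name are illustrative, and the flags are the ones added above:

    # default: validate the whole database, including referential-integrity checks
    linkml-store -d my.ddb validate -O yaml
    # restrict to a single collection and skip referential-integrity checks
    linkml-store -d my.ddb -c persons validate --collection-only --no-ensure-referential-integrity
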
linkml_store/index/implementations/llm_indexer.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, List, Optional

 import numpy as np
+import openai

 from linkml_store.api.config import CollectionConfig
 from linkml_store.index.indexer import INDEX_ITEM, Indexer
@@ -11,6 +12,7 @@ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
 if TYPE_CHECKING:
     import llm

+CHUNK_SIZE = 1000

 logger = logging.getLogger(__name__)

@@ -25,7 +27,7 @@ class LLMIndexer(Indexer):
     >>> vector = indexer.text_to_vector("hello")
     """

-    embedding_model_name: str = "ada-002"
+    embedding_model_name: str = "text-embedding-ada-002"
     _embedding_model: "llm.EmbeddingModel" = None
     cached_embeddings_database: str = None
     cached_embeddings_collection: str = None
@@ -52,7 +54,9 @@ class LLMIndexer(Indexer):
         """
         return self.texts_to_vectors([text], cache=cache, **kwargs)[0]

-    def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
+    def texts_to_vectors(
+        self, texts: List[str], cache: bool = None, token_limit_penalty=0, **kwargs
+    ) -> List[INDEX_ITEM]:
         """
         Use LLM to embed.

@@ -60,18 +64,22 @@ class LLMIndexer(Indexer):

         >>> vectors = indexer.texts_to_vectors(["hello", "goodbye"])

         :param texts:
+        :param cache:
+        :param token_limit_penalty:
         :return:
         """
         from tiktoken import encoding_for_model
+
         logging.info(f"Converting {len(texts)} texts to vectors")
         model = self.embedding_model
         # TODO: make this more accurate
-        token_limit = get_token_limit(model.model_id) - 200
-        encoding = encoding_for_model("gpt-4o")
+        token_limit = get_token_limit(model.model_id) - token_limit_penalty
+        logging.info(f"Token limit for {model.model_id}: {token_limit}")
+        encoding = encoding_for_model(self.embedding_model_name)

         def truncate_text(text: str) -> str:
             # split into tokens every 1000 chars:
-            parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
+            parts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
             truncated = render_formatted_text(
                 lambda x: "".join(x),
                 parts,
@@ -140,5 +148,5 @@ class LLMIndexer(Indexer):
                 embeddings_collection.commit()
         else:
             logger.info(f"Embedding {len(texts)} texts")
-            embeddings = model.embed_multi(texts)
+            embeddings = list(model.embed_multi(texts, batch_size=1))
         return [np.array(v, dtype=float) for v in embeddings]
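
A minimal sketch of the indexer changes in use, assuming the default model name set above and an environment where the llm embedding backend is configured; token_limit_penalty is the new knob threaded through texts_to_vectors:

    from linkml_store.index.implementations.llm_indexer import LLMIndexer

    indexer = LLMIndexer(name="llm", embedding_model_name="text-embedding-ada-002")
    # reserve some headroom below the model token limit when truncating long texts
    vectors = indexer.texts_to_vectors(["hello", "goodbye"], token_limit_penalty=200)
    print(len(vectors), vectors[0].shape)
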
linkml_store/index/indexer.py CHANGED
@@ -154,8 +154,11 @@ class Indexer(BaseModel):
         return str(obj)

     def search(
-        self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None,
-        mmr_relevance_factor: Optional[float] = None
+        self,
+        query: str,
+        vectors: List[Tuple[str, INDEX_ITEM]],
+        limit: Optional[int] = None,
+        mmr_relevance_factor: Optional[float] = None,
     ) -> List[Tuple[float, Any]]:
         """
         Use the indexer to search against a database of vectors.
@@ -175,8 +178,8 @@ class Indexer(BaseModel):
             vlist = [v for _, v in vectors]
             idlist = [id for id, _ in vectors]
             sorted_indices = mmr_diversified_search(
-                query_vector, vlist,
-                relevance_factor=mmr_relevance_factor, top_n=limit)
+                query_vector, vlist, relevance_factor=mmr_relevance_factor, top_n=limit
+            )
         results = []
         # TODO: this is inefficient when limit is high
         for i in range(limit):
linkml_store/inference/implementations/llm_inference_engine.py CHANGED
@@ -79,21 +79,24 @@ class LLMInferenceEngine(InferenceEngine):
     def _schema_str(self) -> str:
         db = self.training_data.base_collection.parent
         from linkml_runtime.dumpers import json_dumper
+
         schema_dict = json_dumper.to_dict(db.schema_view.schema)
         return yaml.dump(schema_dict)

-    def derive(self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None) -> Optional[LLMInference]:
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[LLMInference]:
         import llm

         model: llm.Model = self.model
-        #model_name = self.config.llm_config.model_name
-        #feature_attributes = self.config.feature_attributes
+        # model_name = self.config.llm_config.model_name
+        # feature_attributes = self.config.feature_attributes
         target_attributes = self.config.target_attributes
         query_text = self.object_to_text(object)

         if not target_attributes:
             target_attributes = [k for k, v in object.items() if v is None or v == ""]
-        #if not feature_attributes:
+        # if not feature_attributes:
         #    feature_attributes = [k for k, v in object.items() if v is not None and v != ""]

         system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
@@ -107,7 +110,9 @@ class LLMInferenceEngine(InferenceEngine):
            "```yaml\n"
            f"{stub}\n"
            "```\n"
-           "---\nQuery:\n" f"## INCOMPLETE OBJECT:\n{query_text}\n" "## OUTPUT:\n"
+           "---\nQuery:\n"
+           f"## INCOMPLETE OBJECT:\n{query_text}\n"
+           "## OUTPUT:\n"
         )
         logger.info(f"Prompt: {prompt}")
         response = model.prompt(prompt, system=system_prompt)
@@ -130,9 +135,8 @@ class LLMInferenceEngine(InferenceEngine):
                 "\nThis was invalid.\n",
                 "Validation errors:\n",
             ] + [self.object_to_text(e) for e in errs]
-            return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
-        return LLMInference(predicted_object=predicted_object, iterations=iteration+1, query=object)
-
+            return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return LLMInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)

     def export_model(
         self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
@@ -149,4 +153,4 @@ class LLMInferenceEngine(InferenceEngine):

     @classmethod
     def load_model(cls, file_path: Union[str, Path]) -> "LLMInferenceEngine":
-        raise NotImplementedError("Does not make sense for this engine")
+        raise NotImplementedError("Does not make sense for this engine")
linkml_store/inference/implementations/rag_inference_engine.py CHANGED
@@ -111,7 +111,9 @@ class RAGInferenceEngine(InferenceEngine):
     def object_to_text(self, object: OBJECT) -> str:
         return yaml.dump(object)

-    def derive(self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None) -> Optional[RAGInference]:
+    def derive(
+        self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None
+    ) -> Optional[RAGInference]:
         import llm
         from tiktoken import encoding_for_model

@@ -131,8 +133,9 @@ class RAGInferenceEngine(InferenceEngine):
         if not self.rag_collection.indexers:
             raise ValueError("RAG collection must have an indexer attached")
         logger.info(f"Searching {self.rag_collection.alias} for examples for: {query_text}")
-        rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm",
-                                        mmr_relevance_factor=mmr_relevance_factor)
+        rs = self.rag_collection.search(
+            query_text, limit=num_examples, index_name="llm", mmr_relevance_factor=mmr_relevance_factor
+        )
         examples = rs.rows
         logger.info(f"Found {len(examples)} examples")
         if not examples:
@@ -153,11 +156,11 @@ class RAGInferenceEngine(InferenceEngine):
             input_obj_text = self.object_to_text(input_obj)
             if input_obj_text == query_text:
                 continue
-                #raise ValueError(
+                # raise ValueError(
                 #    f"Query object {query_text} is the same as example object {input_obj_text}\n"
                 #    "This indicates possible test data leakage\n."
                 #    "TODO: allow an option that allows user to treat this as a basic lookup\n"
-                #)
+                # )
             output_obj = select_nested(example, target_attributes)
             prompt_clause = (
                 "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
@@ -176,9 +179,9 @@ class RAGInferenceEngine(InferenceEngine):
         except KeyError:
             encoding = encoding_for_model("gpt-4")
         token_limit = get_token_limit(model_name)
-        prompt = render_formatted_text(make_text, values=prompt_clauses,
-                                       encoding=encoding, token_limit=token_limit,
-                                       additional_text=system_prompt)
+        prompt = render_formatted_text(
+            make_text, values=prompt_clauses, encoding=encoding, token_limit=token_limit, additional_text=system_prompt
+        )
         logger.info(f"Prompt: {prompt}")
         response = model.prompt(prompt, system=system_prompt)
         yaml_str = response.text()
@@ -199,8 +202,8 @@ class RAGInferenceEngine(InferenceEngine):
                 "\nThis was invalid.\n",
                 "Validation errors:\n",
             ] + [self.object_to_text(e) for e in errs]
-            return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
-        return RAGInference(predicted_object=predicted_object, iterations=iteration+1, query=object)
+            return self.derive(object, iteration=iteration + 1, additional_prompt_texts=extra_texts)
+        return RAGInference(predicted_object=predicted_object, iterations=iteration + 1, query=object)

     def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
         if "```" in yaml_str:
linkml_store/inference/implementations/sklearn_inference_engine.py CHANGED
@@ -94,6 +94,8 @@ class SklearnInferenceEngine(InferenceEngine):
         if not feature_cols:
             feature_cols = df.columns.difference(target_cols).tolist()
             self.config.feature_attributes = feature_cols
+        if not feature_cols:
+            raise ValueError("No features found in the data")
         target_col = target_cols[0]
         logger.info(f"Feature columns: {feature_cols}")
         X = df[feature_cols].copy()
@@ -102,6 +104,8 @@ class SklearnInferenceEngine(InferenceEngine):

         # find list of features to skip (categorical with > N categories)
         skip_features = []
+        if not len(X.columns):
+            raise ValueError("No features to train on")
         for col in X.columns:
             unique_values = self._get_unique_values(X[col])
             if len(unique_values) > self.maximum_proportion_distinct_features * len(X[col]):
@@ -115,6 +119,8 @@ class SklearnInferenceEngine(InferenceEngine):

         # Encode features
         encoded_features = []
+        if not len(X.columns):
+            raise ValueError(f"No features to train on from after skipping {skip_features}")
         for col in X.columns:
             logger.info(f"Checking whether to encode: {col}")
             col_encoder = self._get_encoder(X[col])
@@ -153,7 +159,7 @@ class SklearnInferenceEngine(InferenceEngine):
         y = y_encoder.fit_transform(y.values.ravel())  # Convert to 1D numpy array
         self.transformed_targets = y_encoder.classes_

-        # print(f"Fitting model with features: {X.columns}")
+        # print(f"Fitting model with features: {X.columns}, y={y}, X={X}")
         clf = DecisionTreeClassifier(random_state=42)
         clf.fit(X, y)
         self.classifier = clf
linkml_store/inference/inference_config.py CHANGED
@@ -59,6 +59,7 @@ class Inference(BaseModel, extra="forbid"):
     """
     Result of an inference derivation.
     """
+
     query: Optional[OBJECT] = Field(default=None, description="The query object.")
     predicted_object: OBJECT = Field(..., description="The predicted object.")
     confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)