linkml-store 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of linkml-store might be problematic. Click here for more details.

@@ -226,6 +226,18 @@ class Collection(Generic[DatabaseType]):
226
226
  self._initialized = True
227
227
  patches = [{"op": "add", "path": "/0", "value": obj} for obj in objs]
228
228
  self._broadcast(patches, **kwargs)
229
+ self._post_modification_hook(**kwargs)
230
+
231
+ def _post_delete_hook(self, **kwargs):
232
+ self._post_modification_hook(**kwargs)
233
+
234
+ def _post_modification_hook(self, **kwargs):
235
+ for indexer in self.indexers.values():
236
+ ix_collection_name = self.get_index_collection_name(indexer)
237
+ ix_collection = self.parent.get_collection(ix_collection_name)
238
+ # Currently updating the source triggers complete reindexing
239
+ # TODO: make this more efficient by only deleting modified
240
+ ix_collection.delete_where({})
229
241
 
230
242
  def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
231
243
  """
@@ -458,6 +470,7 @@ class Collection(Generic[DatabaseType]):
458
470
  where: Optional[Any] = None,
459
471
  index_name: Optional[str] = None,
460
472
  limit: Optional[int] = None,
473
+ mmr_relevance_factor: Optional[float] = None,
461
474
  **kwargs,
462
475
  ) -> QueryResult:
463
476
  """
@@ -476,7 +489,7 @@ class Collection(Generic[DatabaseType]):
476
489
  Now let's index, using the simple trigram-based index
477
490
 
478
491
  >>> index = get_indexer("simple")
479
- >>> collection.attach_indexer(index)
492
+ >>> _ = collection.attach_indexer(index)
480
493
 
481
494
  Now let's find all objects:
482
495
 
@@ -514,12 +527,15 @@ class Collection(Generic[DatabaseType]):
514
527
  if ix_coll.size() == 0:
515
528
  logger.info(f"Index {index_name} is empty; indexing all objects")
516
529
  all_objs = self.find(limit=-1).rows
517
- self.index_objects(all_objs, index_name, replace=True, **kwargs)
530
+ if all_objs:
531
+ # print(f"Index {index_name} is empty; indexing all objects {len(all_objs)}")
532
+ self.index_objects(all_objs, index_name, replace=True, **kwargs)
533
+ assert ix_coll.size() > 0
518
534
  qr = ix_coll.find(where=where, limit=-1, **kwargs)
519
535
  index_col = ix.index_field
520
536
  # TODO: optimize this for large indexes
521
537
  vector_pairs = [(row, np.array(row[index_col], dtype=float)) for row in qr.rows]
522
- results = ix.search(query, vector_pairs, limit=limit)
538
+ results = ix.search(query, vector_pairs, limit=limit, mmr_relevance_factor=mmr_relevance_factor, **kwargs)
523
539
  for r in results:
524
540
  del r[1][index_col]
525
541
  new_qr = QueryResult(num_rows=len(results))
@@ -648,7 +664,31 @@ class Collection(Generic[DatabaseType]):
648
664
  """
649
665
  return self.find({}, limit=1).num_rows
650
666
 
651
- def attach_indexer(self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs):
667
+ def rows_iter(self) -> Iterable[OBJECT]:
668
+ """
669
+ Return an iterator over the objects in the collection.
670
+
671
+ :return:
672
+ """
673
+ yield from self.find({}, limit=-1).rows
674
+
675
+ def rows(self) -> List[OBJECT]:
676
+ """
677
+ Return a list of objects in the collection.
678
+
679
+ :return:
680
+ """
681
+ return list(self.rows_iter())
682
+
683
+ def ranked_rows(self) -> List[Tuple[float, OBJECT]]:
684
+ """
685
+ Return a list of objects in the collection, with scores.
686
+ """
687
+ return [(n, obj) for n, obj in enumerate(self.rows_iter())]
688
+
689
+ def attach_indexer(
690
+ self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs
691
+ ) -> Indexer:
652
692
  """
653
693
  Attach an index to the collection.
654
694
 
@@ -669,8 +709,8 @@ class Collection(Generic[DatabaseType]):
669
709
  >>> full_index.name = "full"
670
710
  >>> name_index = get_indexer("simple", text_template="{name}")
671
711
  >>> name_index.name = "name"
672
- >>> collection.attach_indexer(full_index)
673
- >>> collection.attach_indexer(name_index)
712
+ >>> _ = collection.attach_indexer(full_index)
713
+ >>> _ = collection.attach_indexer(name_index)
674
714
 
675
715
  Now let's find objects using the full index, using the string "France".
676
716
  We expect the country France to be the top hit, but the score will
@@ -713,6 +753,10 @@ class Collection(Generic[DatabaseType]):
713
753
  all_objs = self.find(limit=-1).rows
714
754
  logger.info(f"Auto-indexing {len(all_objs)} objects")
715
755
  self.index_objects(all_objs, index_name, replace=True, **kwargs)
756
+ return index
757
+
758
+ def get_index_collection_name(self, indexer: Indexer) -> str:
759
+ return self._index_collection_name(indexer.name)
716
760
 
717
761
  def _index_collection_name(self, index_name: str) -> str:
718
762
  """
@@ -268,7 +268,7 @@ class Database(ABC, Generic[CollectionType]):
268
268
  metadata: Optional[CollectionConfig] = None,
269
269
  recreate_if_exists=False,
270
270
  **kwargs,
271
- ) -> CollectionType:
271
+ ) -> Collection:
272
272
  """
273
273
  Create a new collection in the current database.
274
274
 
@@ -760,6 +760,12 @@ class Database(ABC, Generic[CollectionType]):
760
760
  """
761
761
  Export a database to a file or location.
762
762
 
763
+ >>> from linkml_store.api.client import Client
764
+ >>> client = Client()
765
+ >>> db = client.attach_database("duckdb", alias="test")
766
+ >>> db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
767
+ >>> db.export_database("/tmp/iris.yaml", Format.YAML)
768
+
763
769
  :param location: location of the file
764
770
  :param target_format: target format
765
771
  :param kwargs: additional arguments
@@ -40,7 +40,9 @@ class FacetCountResult(BaseModel):
40
40
 
41
41
  class QueryResult(BaseModel):
42
42
  """
43
- A query result
43
+ A query result.
44
+
45
+ TODO: make this a subclass of Collection
44
46
  """
45
47
 
46
48
  query: Optional[Query] = None
@@ -50,8 +50,9 @@ class DuckDBCollection(Collection):
50
50
  if not isinstance(objs, list):
51
51
  objs = [objs]
52
52
  cd = self.class_definition()
53
- if not cd:
53
+ if not cd or not cd.attributes:
54
54
  cd = self.induce_class_definition_from_objects(objs)
55
+ assert cd.attributes
55
56
  table = self._sqla_table(cd)
56
57
  engine = self.parent.engine
57
58
  with engine.connect() as conn:
@@ -61,7 +62,8 @@ class DuckDBCollection(Collection):
61
62
  stmt = stmt.compile(engine)
62
63
  conn.execute(stmt)
63
64
  conn.commit()
64
- return
65
+ self._post_delete_hook()
66
+ return None
65
67
 
66
68
  def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
67
69
  logger.info(f"Deleting from {self.target_class_name} where: {where}")
@@ -87,6 +89,7 @@ class DuckDBCollection(Collection):
87
89
  if deleted_rows_count == 0 and not missing_ok:
88
90
  raise ValueError(f"No rows found for {where}")
89
91
  conn.commit()
92
+ self._post_delete_hook()
90
93
  return deleted_rows_count if deleted_rows_count > -1 else None
91
94
 
92
95
  def query_facets(
linkml_store/cli.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import logging
2
2
  import sys
3
3
  import warnings
4
+ from collections import defaultdict
4
5
  from pathlib import Path
5
- from typing import Optional
6
+ from typing import Optional, Tuple, Any
6
7
 
7
8
  import click
8
9
  import yaml
@@ -76,6 +77,8 @@ class ContextSettings(BaseModel):
76
77
  if name is None:
77
78
  # if len(self.database.list_collections()) > 1:
78
79
  # raise ValueError("Collection must be specified if there are multiple collections.")
80
+ if not self.database:
81
+ return None
79
82
  if not self.database.list_collections():
80
83
  return None
81
84
  name = list(self.database.list_collections())[0]
@@ -218,7 +221,10 @@ def insert(ctx, files, object, format):
218
221
  @click.option("--object", "-i", multiple=True, help="Input object as YAML")
219
222
  @click.pass_context
220
223
  def store(ctx, files, object, format):
221
- """Store objects from files (JSON, YAML, TSV) into the specified collection."""
224
+ """Store objects from files (JSON, YAML, TSV) into the database.
225
+
226
+ Note: this is similar to insert, but a collection does not need to be specified
227
+ """
222
228
  settings = ctx.obj["settings"]
223
229
  db = settings.database
224
230
  if not files and not object:
@@ -410,14 +416,6 @@ def list_collections(ctx, **kwargs):
410
416
  def fq(ctx, where, limit, columns, output_type, wide, output):
411
417
  """
412
418
  Query facets from the specified collection.
413
-
414
- :param ctx:
415
- :param where:
416
- :param limit:
417
- :param columns:
418
- :param output_type:
419
- :param output:
420
- :return:
421
419
  """
422
420
  collection = ctx.obj["settings"].collection
423
421
  where_clause = yaml.safe_load(where) if where else None
@@ -483,6 +481,41 @@ def describe(ctx, where, output_type, output, limit):
483
481
  write_output(df.describe(include="all").transpose(), output_type, target=output)
484
482
 
485
483
 
484
+ @cli.command()
485
+ @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
486
+ @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
487
+ @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
488
+ @click.option("--output", "-o", type=click.Path(), help="Output file path")
489
+ @click.option("--index", "-I", help="Attributes to index on in pivot")
490
+ @click.option("--columns", "-A", help="Attributes to use as columns in pivot")
491
+ @click.option("--values", "-V", help="Attributes to use as values in pivot")
492
+ @click.pass_context
493
+ def pivot(ctx, where, limit, index, columns, values, output_type, output):
494
+ collection = ctx.obj["settings"].collection
495
+ where_clause = yaml.safe_load(where) if where else None
496
+ column_atts = columns.split(",") if columns else None
497
+ value_atts = values.split(",") if values else None
498
+ index_atts = index.split(",") if index else None
499
+ results = collection.find(where_clause, limit=limit)
500
+ pivoted = defaultdict(dict)
501
+ for row in results.rows:
502
+ index_key = tuple([row.get(att) for att in index_atts])
503
+ column_key = tuple([row.get(att) for att in column_atts])
504
+ value_key = tuple([row.get(att) for att in value_atts])
505
+ pivoted[index_key][column_key] = value_key
506
+ pivoted_objs = []
507
+ def detuple(t: Tuple) -> Any:
508
+ if len(t) == 1:
509
+ return t[0]
510
+ return str(t)
511
+ for index_key, data in pivoted.items():
512
+ obj = {att: key for att, key in zip(index_atts, index_key)}
513
+ for column_key, value_key in data.items():
514
+ obj[detuple(column_key)] = detuple(value_key)
515
+ pivoted_objs.append(obj)
516
+ write_output(pivoted_objs, output_type, target=output)
517
+
518
+
486
519
  @cli.command()
487
520
  @click.option("--output-type", "-O", type=format_choice, default=Format.YAML.value, help="Output format")
488
521
  @click.option("--output", "-o", type=click.Path(), help="Output file path")
@@ -499,6 +532,7 @@ def describe(ctx, where, output_type, output, limit):
499
532
  "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
500
533
  )
501
534
  @click.option("--evaluation-count", "-n", type=click.INT, help="Number of examples to evaluate over")
535
+ @click.option("--evaluation-match-function", help="Name of function to use for matching objects in eval")
502
536
  @click.option("--query", "-q", type=click.STRING, help="query term")
503
537
  @click.pass_context
504
538
  def infer(
@@ -506,6 +540,7 @@ def infer(
506
540
  inference_config_file,
507
541
  query,
508
542
  evaluation_count,
543
+ evaluation_match_function,
509
544
  training_test_data_split,
510
545
  predictor_type,
511
546
  target_attribute,
@@ -549,7 +584,10 @@ def infer(
549
584
  else:
550
585
  query_obj = None
551
586
  collection = ctx.obj["settings"].collection
552
- atts = collection.class_definition().attributes.keys()
587
+ if collection:
588
+ atts = collection.class_definition().attributes.keys()
589
+ else:
590
+ atts = []
553
591
  if feature_attributes:
554
592
  features = feature_attributes.split(",")
555
593
  features = [f.strip() for f in features]
@@ -575,7 +613,8 @@ def infer(
575
613
  if training_test_data_split:
576
614
  config.train_test_split = training_test_data_split
577
615
  predictor = get_inference_engine(predictor_type, config=config)
578
- predictor.load_and_split_data(collection)
616
+ if collection:
617
+ predictor.load_and_split_data(collection)
579
618
  predictor.initialize_model()
580
619
  if export_model:
581
620
  logger.info(f"Exporting model to {export_model} in {model_format}")
@@ -584,8 +623,14 @@ def infer(
584
623
  if not export_model and not evaluation_count:
585
624
  raise ValueError("Query or evaluate must be specified if not exporting model")
586
625
  if evaluation_count:
626
+ if evaluation_match_function == "score_text_overlap":
627
+ match_function_fn = score_text_overlap
628
+ elif evaluation_match_function is not None:
629
+ raise ValueError(f"Unknown match function: {evaluation_match_function}")
630
+ else:
631
+ match_function_fn = None
587
632
  outcome = evaluate_predictor(
588
- predictor, target_attributes, evaluation_count=evaluation_count, match_function=score_text_overlap
633
+ predictor, target_attributes, evaluation_count=evaluation_count, match_function=match_function_fn
589
634
  )
590
635
  print(f"Outcome: {outcome} // accuracy: {outcome.accuracy}")
591
636
  if query_obj:
@@ -1,11 +1,13 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import TYPE_CHECKING, List
3
+ from typing import TYPE_CHECKING, List, Optional
4
4
 
5
5
  import numpy as np
6
+ from tiktoken import encoding_for_model
6
7
 
7
8
  from linkml_store.api.config import CollectionConfig
8
9
  from linkml_store.index.indexer import INDEX_ITEM, Indexer
10
+ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
9
11
 
10
12
  if TYPE_CHECKING:
11
13
  import llm
@@ -29,6 +31,7 @@ class LLMIndexer(Indexer):
29
31
  cached_embeddings_database: str = None
30
32
  cached_embeddings_collection: str = None
31
33
  cache_queries: bool = False
34
+ truncation_method: Optional[str] = None
32
35
 
33
36
  @property
34
37
  def embedding_model(self):
@@ -62,6 +65,21 @@ class LLMIndexer(Indexer):
62
65
  """
63
66
  logging.info(f"Converting {len(texts)} texts to vectors")
64
67
  model = self.embedding_model
68
+ token_limit = get_token_limit(model.model_id)
69
+ encoding = encoding_for_model("gpt-4o")
70
+
71
+ def truncate_text(text: str) -> str:
72
+ # split into chunks of 1000 characters:
73
+ parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
74
+ return render_formatted_text(
75
+ lambda x: "".join(x),
76
+ parts,
77
+ encoding,
78
+ token_limit,
79
+ )
80
+
81
+ texts = [truncate_text(text) for text in texts]
82
+
65
83
  if self.cached_embeddings_database and (cache is None or cache or self.cache_queries):
66
84
  model_id = model.model_id
67
85
  if not model_id:
@@ -88,7 +106,7 @@ class LLMIndexer(Indexer):
88
106
  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
89
107
  else:
90
108
  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
91
- texts = list(texts)
109
+
92
110
  embeddings = list([None] * len(texts))
93
111
  uncached_texts = []
94
112
  n = 0
@@ -3,6 +3,7 @@ from enum import Enum
3
3
  from typing import Any, Callable, Dict, List, Optional, Tuple
4
4
 
5
5
  import numpy as np
6
+ from linkml_store.utils.vector_utils import pairwise_cosine_similarity, mmr_diversified_search
6
7
  from pydantic import BaseModel
7
8
 
8
9
  INDEX_ITEM = np.ndarray
@@ -19,23 +20,57 @@ class TemplateSyntaxEnum(str, Enum):
19
20
  fstring = "fstring"
20
21
 
21
22
 
22
- def cosine_similarity(vector1, vector2) -> float:
23
+ class Indexer(BaseModel):
23
24
  """
24
- Calculate the cosine similarity between two vectors
25
+ An indexer operates on a collection in order to search for objects.
25
26
 
26
- :param vector1:
27
- :param vector2:
28
- :return:
29
- """
30
- dot_product = np.dot(vector1, vector2)
31
- norm1 = np.linalg.norm(vector1)
32
- norm2 = np.linalg.norm(vector2)
33
- return dot_product / (norm1 * norm2)
27
+ You should use a subclass of this; this can be looked up dynamically:
34
28
 
29
+ >>> from linkml_store.index import get_indexer
30
+ >>> indexer = get_indexer("simple")
31
+
32
+ You can customize how objects are indexed by passing in a text template.
33
+ For example, if your collection has objects with "name" and "profession" attributes,
34
+ you can index them as "{name} {profession}".
35
+
36
+ >>> indexer = get_indexer("simple", text_template="{name} :: {profession}")
37
+
38
+ By default, python fstrings are assumed.
39
+
40
+ We can test this works using the :ref:`object_to_text` method (normally
41
+ you would never need to call this directly, but it's useful for testing):
42
+
43
+ >>> obj = {"name": "John", "profession": "doctor"}
44
+ >>> indexer.object_to_text(obj)
45
+ 'John :: doctor'
46
+
47
+ You can also use Jinja2 templates; this gives more flexibility and logic,
48
+ e.g. conditional formatting:
49
+
50
+ >>> tmpl = "{{name}}{% if profession %} :: {{profession}}{% endif %}"
51
+ >>> indexer = get_indexer("simple", text_template=tmpl, text_template_syntax=TemplateSyntaxEnum.jinja2)
52
+ >>> indexer.object_to_text(obj)
53
+ 'John :: doctor'
54
+ >>> indexer.object_to_text({"name": "John"})
55
+ 'John'
56
+
57
+ You can also specify which attributes to index:
58
+
59
+ >>> indexer = get_indexer("simple", index_attributes=["name"])
60
+ >>> indexer.object_to_text(obj)
61
+ 'John'
62
+
63
+ The purpose of an indexer is to translate a collection of objects into a collection of objects
64
+ such as vectors for purposes such as search. Unless you are implementing your own indexer, you
65
+ generally don't need to use the methods that return vectors, but we can examine their behavior
66
+ to get a sense of how they work.
67
+
68
+ >>> vectors = indexer.objects_to_vectors([{"name": "Aardvark"}, {"name": "Aardwolf"}, {"name": "Zesty"}])
69
+ >>> assert pairwise_cosine_similarity(vectors[0], vectors[1]) > pairwise_cosine_similarity(vectors[0], vectors[2])
70
+
71
+ Note you should consult the documentation for the specific indexer you are using for more details on
72
+ how text is converted to vectors.
35
73
 
36
- class Indexer(BaseModel):
37
- """
38
- An indexer operates on a collection in order to search for objects.
39
74
  """
40
75
 
41
76
  name: Optional[str] = None
@@ -119,10 +154,13 @@ class Indexer(BaseModel):
119
154
  return str(obj)
120
155
 
121
156
  def search(
122
- self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None
157
+ self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None,
158
+ mmr_relevance_factor: Optional[float] = None
123
159
  ) -> List[Tuple[float, Any]]:
124
160
  """
125
- Search the index for a query string
161
+ Use the indexer to search against a database of vectors.
162
+
163
+ Note: this is a low-level method, typically you would use the :ref:`search` method on a :ref:`Collection`.
126
164
 
127
165
  :param query: The query string to search for
128
166
  :param vectors: A list of indexed items, where each item is a tuple of (id, vector)
@@ -133,13 +171,29 @@ class Indexer(BaseModel):
133
171
  # Convert the query string to a vector
134
172
  query_vector = self.text_to_vector(query, cache=False)
135
173
 
174
+ if mmr_relevance_factor is not None:
175
+ vlist = [v for _, v in vectors]
176
+ idlist = [id for id, _ in vectors]
177
+ sorted_indices = mmr_diversified_search(
178
+ query_vector, vlist,
179
+ relevance_factor=mmr_relevance_factor, top_n=limit)
180
+ results = []
181
+ # TODO: this is inefficient when limit is high
182
+ for i in range(limit):
183
+ if i >= len(sorted_indices):
184
+ break
185
+ pos = sorted_indices[i]
186
+ score = pairwise_cosine_similarity(query_vector, vlist[pos])
187
+ results.append((score, idlist[pos]))
188
+ return results
189
+
136
190
  distances = []
137
191
 
138
192
  # Iterate over each indexed item
139
193
  for item_id, item_vector in vectors:
140
194
  # Calculate the cosine similarity between the query vector and the item vector
141
195
  # distance = 1-np.linalg.norm(query_vector - item_vector)
142
- distance = cosine_similarity(query_vector, item_vector)
196
+ distance = pairwise_cosine_similarity(query_vector, item_vector)
143
197
  distances.append((distance, item_id))
144
198
 
145
199
  # Sort the distances in ascending order
@@ -20,6 +20,8 @@ def score_match(target: Optional[Any], candidate: Optional[Any], match_function:
20
20
  1.0
21
21
  >>> score_match("a", "b")
22
22
  0.0
23
+ >>> score_match("abcd", "abcde")
24
+ 0.0
23
25
  >>> score_match("a", None)
24
26
  0.0
25
27
  >>> score_match(None, "a")
@@ -52,7 +54,7 @@ def score_match(target: Optional[Any], candidate: Optional[Any], match_function:
52
54
 
53
55
  :param target:
54
56
  :param candidate:
55
- :param match_function:
57
+ :param match_function: defaults to struct
56
58
  :return:
57
59
  """
58
60
  if target == candidate:
@@ -99,7 +101,8 @@ def evaluate_predictor(
99
101
  :param predictor:
100
102
  :param target_attributes:
101
103
  :param feature_attributes:
102
- :param evaluation_count:
104
+ :param evaluation_count: max iterations
105
+ :param match_function: function to use for matching
103
106
  :return:
104
107
  """
105
108
  n = 0
@@ -113,8 +116,8 @@ def evaluate_predictor(
113
116
  else:
114
117
  test_obj = row
115
118
  result = predictor.derive(test_obj)
116
- logger.info(f"Predicted: {result.predicted_object} Expected: {expected_obj}")
117
119
  tp += score_match(result.predicted_object, expected_obj, match_function)
120
+ logger.info(f"TP={tp} MF={match_function} Predicted: {result.predicted_object} Expected: {expected_obj}")
118
121
  n += 1
119
122
  if evaluation_count is not None and n >= evaluation_count:
120
123
  break
@@ -125,6 +128,9 @@ def score_text_overlap(str1: Any, str2: Any) -> float:
125
128
  """
126
129
  Compute the overlap score between two strings.
127
130
 
131
+ >>> score_text_overlap("abc", "bcde")
132
+ 0.5
133
+
128
134
  :param str1:
129
135
  :param str2:
130
136
  :return:
@@ -1,17 +1,24 @@
1
+ import json
1
2
  import logging
2
3
  from dataclasses import dataclass
3
- from typing import Any, Optional
4
+ from pathlib import Path
5
+ from typing import ClassVar, List, Optional, TextIO, Union
4
6
 
5
7
  import yaml
6
8
  from llm import get_key
9
+ from pydantic import BaseModel
7
10
 
8
11
  from linkml_store.api.collection import OBJECT, Collection
9
12
  from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
10
- from linkml_store.inference.inference_engine import InferenceEngine
13
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
11
14
  from linkml_store.utils.object_utils import select_nested
12
15
 
13
16
  logger = logging.getLogger(__name__)
14
17
 
18
+ MAX_ITERATIONS = 5
19
+ DEFAULT_NUM_EXAMPLES = 20
20
+ DEFAULT_MMR_RELEVANCE_FACTOR = 0.8
21
+
15
22
  SYSTEM_PROMPT = """
16
23
  You are a {llm_config.role}, your task is to inference the YAML
17
24
  object output given the YAML object input. I will provide you
@@ -23,9 +30,14 @@ You should return ONLY valid YAML in your response.
23
30
  """
24
31
 
25
32
 
26
- # def select_object(obj: OBJECT, key_paths: List[str]) -> OBJECT:
27
- # return {k: obj.get(k, None) for k in keys}
28
- # return {k: object_path_get(obj, k, None) for k in key_paths}
33
+ class TrainedModel(BaseModel, extra="forbid"):
34
+ rag_collection_rows: List[OBJECT]
35
+ index_rows: List[OBJECT]
36
+ config: Optional[InferenceConfig] = None
37
+
38
+
39
+ class RAGInference(Inference):
40
+ iterations: int = 0
29
41
 
30
42
 
31
43
  @dataclass
@@ -54,14 +66,23 @@ class RAGInferenceEngine(InferenceEngine):
54
66
  >>> prediction.predicted_object
55
67
  {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
56
68
 
69
+ The "model" can be saved for later use:
70
+
71
+ >>> ie.export_model("tests/output/countries.rag_model.json")
72
+
73
+ Note in this case the model is not the underlying LLM, but the "RAG Model" which is the vectorized
74
+ representation of training set objects.
75
+
57
76
  """
58
77
 
59
- classifier: Any = None
60
- encoders: dict = None
61
78
  _model: "llm.Model" = None # noqa: F821
62
79
 
63
80
  rag_collection: Collection = None
64
81
 
82
+ PERSIST_COLS: ClassVar[List[str]] = [
83
+ "config",
84
+ ]
85
+
65
86
  def __post_init__(self):
66
87
  if not self.config:
67
88
  self.config = InferenceConfig()
@@ -81,14 +102,16 @@ class RAGInferenceEngine(InferenceEngine):
81
102
  return self._model
82
103
 
83
104
  def initialize_model(self, **kwargs):
84
- rag_collection = self.training_data.collection
85
- rag_collection.attach_indexer("llm", auto_index=False)
86
- self.rag_collection = rag_collection
105
+ logger.info(f"Initializing model {self.model}")
106
+ if self.training_data:
107
+ rag_collection = self.training_data.collection
108
+ rag_collection.attach_indexer("llm", auto_index=False)
109
+ self.rag_collection = rag_collection
87
110
 
88
111
  def object_to_text(self, object: OBJECT) -> str:
89
112
  return yaml.dump(object)
90
113
 
91
- def derive(self, object: OBJECT) -> Optional[Inference]:
114
+ def derive(self, object: OBJECT, iteration=0, additional_prompt_texts: Optional[List[str]] = None) -> Optional[RAGInference]:
92
115
  import llm
93
116
  from tiktoken import encoding_for_model
94
117
 
@@ -98,48 +121,142 @@ class RAGInferenceEngine(InferenceEngine):
98
121
  model_name = self.config.llm_config.model_name
99
122
  feature_attributes = self.config.feature_attributes
100
123
  target_attributes = self.config.target_attributes
101
- num_examples = self.config.llm_config.number_of_few_shot_examples or 5
124
+ num_examples = self.config.llm_config.number_of_few_shot_examples or DEFAULT_NUM_EXAMPLES
102
125
  query_text = self.object_to_text(object)
103
- if not self.rag_collection.indexers:
104
- raise ValueError("RAG collection must have an indexer attached")
105
- rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
106
- examples = rs.rows
107
- if not examples:
108
- raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
126
+ mmr_relevance_factor = DEFAULT_MMR_RELEVANCE_FACTOR
127
+ if not self.rag_collection:
128
+ # TODO: zero-shot mode
129
+ examples = []
130
+ else:
131
+ if not self.rag_collection.indexers:
132
+ raise ValueError("RAG collection must have an indexer attached")
133
+ rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm",
134
+ mmr_relevance_factor=mmr_relevance_factor)
135
+ examples = rs.rows
136
+ if not examples:
137
+ raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
109
138
  prompt_clauses = []
139
+ query_obj = select_nested(object, feature_attributes)
140
+ query_text = self.object_to_text(query_obj)
110
141
  for example in examples:
111
- # input_obj = {k: example.get(k, None) for k in feature_attributes}
112
142
  input_obj = select_nested(example, feature_attributes)
113
- # output_obj = {k: example.get(k, None) for k in target_attributes}
143
+ input_obj_text = self.object_to_text(input_obj)
144
+ if input_obj_text == query_text:
145
+ raise ValueError(
146
+ f"Query object {query_text} is the same as example object {input_obj_text}\n"
147
+ "This indicates possible test data leakage\n."
148
+ "TODO: allow an option that allows user to treat this as a basic lookup\n"
149
+ )
114
150
  output_obj = select_nested(example, target_attributes)
115
151
  prompt_clause = (
116
- "---\nExample:\n"
117
- f"## INPUT:\n{self.object_to_text(input_obj)}\n"
118
- f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
152
+ "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
119
153
  )
120
154
  prompt_clauses.append(prompt_clause)
121
- # query_obj = {k: object.get(k, None) for k in feature_attributes}
122
- query_obj = select_nested(object, feature_attributes)
123
- query_text = self.object_to_text(query_obj)
124
- prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
155
+
125
156
  system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
157
+ system_prompt += "\n".join(additional_prompt_texts or [])
158
+ prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
126
159
 
127
- def make_text(texts):
128
- return "\n".join(prompt_clauses) + prompt_end
160
+ def make_text(texts: List[str]):
161
+ return "\n".join(texts) + prompt_end
129
162
 
130
163
  try:
131
164
  encoding = encoding_for_model(model_name)
132
165
  except KeyError:
133
166
  encoding = encoding_for_model("gpt-4")
134
167
  token_limit = get_token_limit(model_name)
135
- prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
168
+ prompt = render_formatted_text(make_text, values=prompt_clauses,
169
+ encoding=encoding, token_limit=token_limit,
170
+ additional_text=system_prompt)
136
171
  logger.info(f"Prompt: {prompt}")
137
172
  response = model.prompt(prompt, system_prompt)
138
173
  yaml_str = response.text()
139
174
  logger.info(f"Response: {yaml_str}")
175
+ predicted_object = self._parse_yaml_payload(yaml_str, strict=True)
176
+ if self.config.validate_results:
177
+ base_collection = self.training_data.base_collection
178
+ errs = list(base_collection.iter_validate_collection([predicted_object]))
179
+ if errs:
180
+ print(f"{iteration} // FAILED TO VALIDATE: {yaml_str}")
181
+ print(f"PARSED: {predicted_object}")
182
+ print(f"ERRORS: {errs}")
183
+ if iteration > MAX_ITERATIONS:
184
+ raise ValueError(f"Validation errors: {errs}")
185
+ extra_texts = [
186
+ "Make sure results conform to the schema. Previously you provided:\n",
187
+ yaml_str,
188
+ "\nThis was invalid.\n",
189
+ "Validation errors:\n",
190
+ ] + [self.object_to_text(e) for e in errs]
191
+ return self.derive(object, iteration=iteration+1, additional_prompt_texts=extra_texts)
192
+ return RAGInference(predicted_object=predicted_object, iterations=iteration+1, query=object)
193
+
194
+ def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
195
+ if "```" in yaml_str:
196
+ yaml_str = yaml_str.split("```")[1].strip()
197
+ if yaml_str.startswith("yaml"):
198
+ yaml_str = yaml_str[4:].strip()
140
199
  try:
141
- predicted_object = yaml.safe_load(yaml_str)
142
- return Inference(predicted_object=predicted_object)
143
- except yaml.parser.ParserError as e:
144
- logger.error(f"Error parsing response: {yaml_str}\n{e}")
200
+ return yaml.safe_load(yaml_str)
201
+ except Exception as e:
202
+ if strict:
203
+ raise e
204
+ logger.error(f"Error parsing YAML: {yaml_str}\n{e}")
145
205
  return None
206
+
207
+ def export_model(
208
+ self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
209
+ ):
210
+ self.save_model(output)
211
+
212
+ def save_model(self, output: Union[str, Path]) -> None:
213
+ """
214
+ Save the trained model and related data to a file.
215
+
216
+ :param output: Path to save the model
217
+ """
218
+
219
+ # trigger index
220
+ _qr = self.rag_collection.search("*", limit=1)
221
+ assert len(_qr.ranked_rows) > 0
222
+
223
+ rows = self.rag_collection.find(limit=-1).rows
224
+
225
+ indexers = self.rag_collection.indexers
226
+ assert len(indexers) == 1
227
+ ix = self.rag_collection.indexers["llm"]
228
+ ix_coll = self.rag_collection.parent.get_collection(self.rag_collection.get_index_collection_name(ix))
229
+
230
+ ix_rows = ix_coll.find(limit=-1).rows
231
+ assert len(ix_rows) > 0
232
+ tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows, config=self.config)
233
+ # tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows)
234
+ with open(output, "w", encoding="utf-8") as f:
235
+ json.dump(tm.model_dump(), f)
236
+
237
+ @classmethod
238
+ def load_model(cls, file_path: Union[str, Path]) -> "RAGInferenceEngine":
239
+ """
240
+ Load a trained model and related data from a file.
241
+
242
+ :param file_path: Path to the saved model
243
+ :return: RAGInferenceEngine instance with loaded model
244
+ """
245
+ with open(file_path, "r", encoding="utf-8") as f:
246
+ model_data = json.load(f)
247
+ tm = TrainedModel(**model_data)
248
+ from linkml_store.api import Client
249
+
250
+ client = Client()
251
+ db = client.attach_database("duckdb", alias="training")
252
+ db.store({"data": tm.rag_collection_rows})
253
+ collection = db.get_collection("data")
254
+ ix = collection.attach_indexer("llm", auto_index=False)
255
+ assert ix.name
256
+ ix_coll_name = collection.get_index_collection_name(ix)
257
+ assert ix_coll_name
258
+ ix_coll = db.get_collection(ix_coll_name, create_if_not_exists=True)
259
+ ix_coll.insert(tm.index_rows)
260
+ ie = cls(config=tm.config)
261
+ ie.rag_collection = collection
262
+ return ie
@@ -153,7 +153,7 @@ class SklearnInferenceEngine(InferenceEngine):
153
153
  y = y_encoder.fit_transform(y.values.ravel()) # Convert to 1D numpy array
154
154
  self.transformed_targets = y_encoder.classes_
155
155
 
156
- logger.info(f"Fitting model with features: {X.columns}")
156
+ # print(f"Fitting model with features: {X.columns}")
157
157
  clf = DecisionTreeClassifier(random_state=42)
158
158
  clf.fit(X, y)
159
159
  self.classifier = clf
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import List, Optional, Tuple
2
+ from typing import List, Optional, Tuple, Any
3
3
 
4
4
  from pydantic import BaseModel, ConfigDict, Field
5
5
 
@@ -35,6 +35,8 @@ class InferenceConfig(BaseModel, extra="forbid"):
35
35
  feature_attributes: Optional[List[str]] = None
36
36
  train_test_split: Optional[Tuple[float, float]] = None
37
37
  llm_config: Optional[LLMConfig] = None
38
+ random_seed: Optional[int] = None
39
+ validate_results: Optional[bool] = None
38
40
 
39
41
  @classmethod
40
42
  def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":
@@ -57,6 +59,7 @@ class Inference(BaseModel, extra="forbid"):
57
59
  """
58
60
  Result of an inference derivation.
59
61
  """
60
-
62
+ query: Optional[OBJECT] = Field(default=None, description="The query object.")
61
63
  predicted_object: OBJECT = Field(..., description="The predicted object.")
62
64
  confidence: Optional[float] = Field(default=None, description="The confidence of the prediction.", le=1.0, ge=0.0)
65
+ explanation: Optional[Any] = Field(default=None, description="Explanation of the prediction.")
@@ -29,6 +29,7 @@ class ModelSerialization(str, Enum):
29
29
  PNG = "png"
30
30
  LINKML_EXPRESSION = "linkml_expression"
31
31
  RULE_BASED = "rulebased"
32
+ RAG_INDEX = "rag_index"
32
33
 
33
34
  @classmethod
34
35
  def from_filepath(cls, file_path: str) -> Optional["ModelSerialization"]:
@@ -58,7 +59,7 @@ class ModelSerialization(str, Enum):
58
59
 
59
60
 
60
61
  class CollectionSlice(BaseModel):
61
- model_config = ConfigDict(arbitrary_types_allowed=True)
62
+ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
62
63
 
63
64
  name: Optional[str] = None
64
65
  base_collection: Optional[Collection] = None
@@ -69,17 +70,26 @@ class CollectionSlice(BaseModel):
69
70
 
70
71
@property
def collection(self) -> "Collection":
    """
    Return the collection backing this slice.

    When ``indices`` is None the slice is the whole base collection, which is
    returned directly. Otherwise the selected rows are copied once into a
    derived collection named by ``slice_alias``, and that collection is cached
    and returned on later calls.

    An empty ``indices`` (list, range or set) means an empty slice. It must
    not fall back to the base collection, or an empty train/test split would
    silently expose every row.

    :return: the base collection, or the materialised subset collection
    """
    if self._collection is not None:
        return self._collection
    if self.indices is None:
        return self.base_collection

    rows = self.base_collection.find({}, limit=-1).rows
    subset = [rows[i] for i in self.indices]
    db = self.base_collection.parent
    subset_collection = db.get_collection(self.slice_alias, create_if_not_exists=True)
    # ensure the collection has the same schema type as the base collection;
    # this ensures that column/attribute types are preserved
    subset_collection.metadata.type = self.base_collection.target_class_name
    # clear any rows already present under this alias before inserting the subset
    subset_collection.delete_where({})
    subset_collection.insert(subset)
    self._collection = subset_collection
    return self._collection
82
88
 
89
@property
def slice_alias(self) -> str:
    """
    Name of the derived collection that holds this slice.

    :return: the base collection alias and the slice name joined by ``__rag_``
    """
    alias: str = f"{self.base_collection.alias}__rag_{self.name}"
    return alias
92
+
83
93
  def as_dataframe(self, flattened=False) -> pd.DataFrame:
84
94
  """
85
95
  Return the slice of the collection as a dataframe.
@@ -113,31 +123,28 @@ class InferenceEngine(ABC):
113
123
 
114
124
  :param collection:
115
125
  :param split:
126
+ :param randomize:
116
127
  :return:
117
128
  """
129
+ local_random = random.Random(self.config.random_seed) if self.config.random_seed else random.Random()
118
130
  split = split or self.config.train_test_split
119
131
  if not split:
120
132
  split = (0.7, 0.3)
133
+ if split[0] == 1.0:
134
+ self.training_data = CollectionSlice(name="train", base_collection=collection, indices=None)
135
+ self.testing_data = None
136
+ return
121
137
  logger.info(f"Loading and splitting data from collection {collection.alias}")
122
138
  size = collection.size()
123
139
  indices = range(size)
124
140
  if randomize:
125
- train_indices = random.sample(indices, int(size * split[0]))
141
+ train_indices = local_random.sample(indices, int(size * split[0]))
126
142
  test_indices = set(indices) - set(train_indices)
127
143
  else:
128
144
  train_indices = indices[: int(size * split[0])]
129
145
  test_indices = indices[int(size * split[0]) :]
130
146
  self.training_data = CollectionSlice(name="train", base_collection=collection, indices=train_indices)
131
147
  self.testing_data = CollectionSlice(name="test", base_collection=collection, indices=test_indices)
132
- # all_data = collection.find({}, limit=size).rows
133
- # all_data_df = nested_objects_to_dataframe(all_data)
134
- # all_data_df = collection.find({}, limit=size).rows_dataframe
135
- # randomize/shuffle order of rows in dataframe
136
- # all_data_df = all_data_df.sample(frac=1).reset_index(drop=True)
137
- # self.training_data = CollectionSlice(dataframe=all_data_df[: int(size * split[0])])
138
- # self.testing_data = CollectionSlice(dataframe=all_data_df[int(size * split[0]) : size])
139
- # self.training_data = CollectionSlice(base_collection=collection, slice=(0, int(size * split[0])))
140
- # self.testing_data = CollectionSlice(base_collection=collection, slice=(int(size * split[0]), size))
141
148
 
142
149
  def initialize_model(self, **kwargs):
143
150
  """
@@ -20,6 +20,7 @@ MODEL_TOKEN_MAPPING = {
20
20
  "gpt-3.5-turbo-instruct": 4096,
21
21
  "text-ada-001": 2049,
22
22
  "ada": 2049,
23
+ "ada-002": 8192,
23
24
  "text-babbage-001": 2040,
24
25
  "babbage": 2049,
25
26
  "text-curie-001": 2049,
@@ -0,0 +1,165 @@
1
+ import logging
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ LOL = List[List[float]]
10
+
11
def pairwise_cosine_similarity(vector1: np.ndarray, vector2: np.ndarray) -> float:
    """
    Calculate the cosine similarity between two vectors.

    >>> v100 = np.array([1, 0, 0])
    >>> v010 = np.array([0, 1, 0])
    >>> v001 = np.array([0, 0, 1])
    >>> v011 = np.array([0, 1, 1])
    >>> pairwise_cosine_similarity(v100, v010)
    0.0
    >>> pairwise_cosine_similarity(v100, v001)
    0.0
    >>> pairwise_cosine_similarity(v010, v001)
    0.0
    >>> pairwise_cosine_similarity(v100, v100)
    1.0
    >>> f"{pairwise_cosine_similarity(v010, v011):0.3f}"
    '0.707'

    A zero vector has no direction, so its similarity to anything is
    reported as 0.0 rather than NaN:

    >>> pairwise_cosine_similarity(v100, np.array([0, 0, 0]))
    0.0

    :param vector1: first vector (1-D)
    :param vector2: second vector, same length as vector1
    :return: cosine of the angle between the two vectors, as a Python float
    """
    norm_product = np.linalg.norm(vector1) * np.linalg.norm(vector2)
    if norm_product == 0:
        # Avoid 0/0 (NaN plus a RuntimeWarning) when either vector is all zeros.
        return 0.0
    # float() keeps the return type in line with the annotation and the doctests,
    # whatever scalar repr the installed numpy uses.
    return float(np.dot(vector1, vector2) / norm_product)
38
+
39
+
40
def compute_cosine_similarity_matrix(list1: List[List[float]], list2: List[List[float]]) -> np.ndarray:
    """
    Compute cosine similarity between two lists of vectors.

    Result is a matrix sim[ROW][COL] of shape (len(list1), len(list2)), where
    ROW indexes list1 and COL indexes list2.

    A vector that is all zeros has no direction; its similarity to every
    other vector is reported as 0.0 rather than NaN.

    :param list1: vectors for the rows of the result, all of the same length
    :param list2: vectors for the columns, same vector length as list1
    :return: 2-D array of cosine similarities
    """

    def _unit_rows(vectors) -> np.ndarray:
        # Scale each row to unit length; all-zero rows are left as zeros.
        matrix = np.asarray(vectors, dtype=float)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return matrix / norms

    # The dot product of unit vectors is their cosine similarity.
    return np.dot(_unit_rows(list1), _unit_rows(list2).T)
62
+
63
+
64
def top_matches(cosine_similarity_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find the top match for each row in the cosine similarity matrix.

    :param cosine_similarity_matrix: 2-D matrix of similarities, one row per query
    :return: pair (column index of each row's maximum, value of each row's maximum)
    """
    best_columns = np.argmax(cosine_similarity_matrix, axis=1)
    best_scores = np.max(cosine_similarity_matrix, axis=1)
    return best_columns, best_scores
78
+
79
+
80
def top_n_matches(
    cosine_similarity_matrix: np.ndarray, n: int = 10
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find the top n matches for each row in the cosine similarity matrix.

    :param cosine_similarity_matrix: 2-D matrix of similarities, one row per query
    :param n: maximum number of matches per row; rows with fewer columns
        return all of their columns
    :return: pair (column indices, similarity values), each with one row per
        input row, ordered from most to least similar
    """
    # Sort each row once, in descending order, and keep the first n columns
    top_n_indices = np.argsort(-cosine_similarity_matrix, axis=1)[:, :n]

    # Gather the matching values through the indices instead of sorting again
    top_n_values = np.take_along_axis(cosine_similarity_matrix, top_n_indices, axis=1)

    return top_n_indices, top_n_values
93
+
94
+
95
def mmr_diversified_search(
    query_vector: np.ndarray, document_vectors: List[np.ndarray], relevance_factor=0.5, top_n=None
) -> List[int]:
    """
    Perform diversified search using Maximal Marginal Relevance (MMR).

    Documents are picked one at a time. Each round selects the unselected
    document with the highest score::

        relevance_factor * sim(doc, query)
            - (1 - relevance_factor) * max(sim(doc, s) for s in selected)

    where sim is cosine similarity and the penalty is zero while nothing has
    been selected yet. Ties go to the lowest index. Documents whose score is
    NaN (for example all-zero vectors) are never selected.

    :param query_vector: The vector representing the query.
    :param document_vectors: The vectors representing the documents.
    :param relevance_factor: The balance parameter between relevance and diversity.
    :param top_n: The number of results to return. If None, return all.
        Values larger than the number of documents are treated as "all".
    :return: A list of indices representing the diversified order of documents.
    """
    n_docs = len(document_vectors)
    if top_n is None:
        # If no specific number of results is specified, return all
        top_n = n_docs
    # There can never be more results than documents.
    top_n = min(top_n, n_docs)
    if top_n <= 0:
        return []

    docs = np.asarray(document_vectors, dtype=float)
    query = np.asarray(query_vector, dtype=float)

    # Zero-length vectors give 0/0 here; the resulting NaN scores are filtered
    # out below, so the floating-point warnings are suppressed.
    with np.errstate(divide="ignore", invalid="ignore"):
        doc_norms = np.linalg.norm(docs, axis=1)
        similarities = np.dot(docs, query) / (doc_norms * np.linalg.norm(query))

    result_indices: List[int] = []
    available = np.ones(n_docs, dtype=bool)
    # Running maximum similarity of every document to the selected set;
    # stays None until the first document has been picked.
    max_sim_to_selected = None

    for _ in range(top_n):
        scores = relevance_factor * similarities
        if max_sim_to_selected is not None:
            scores = scores - (1 - relevance_factor) * max_sim_to_selected

        # Already-selected documents and NaN scores cannot win.
        scores = np.where(available & ~np.isnan(scores), scores, -np.inf)
        best_index = int(np.argmax(scores))  # first maximum, so lowest index on ties
        if scores[best_index] == -np.inf:
            # No candidate left; later rounds would not find one either.
            logger.warning(f"No best index found over {len(document_vectors)} documents.")
            break

        result_indices.append(best_index)
        available[best_index] = False

        # Fold the new pick into the running maximum with one vectorised pass.
        with np.errstate(divide="ignore", invalid="ignore"):
            sim_to_best = np.dot(docs, docs[best_index]) / (doc_norms * doc_norms[best_index])
        if max_sim_to_selected is None:
            max_sim_to_selected = sim_to_best
        else:
            max_sim_to_selected = np.maximum(max_sim_to_selected, sim_to_best)

    return result_indices
163
+
164
+
165
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: linkml-store
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: linkml-store
5
5
  License: MIT
6
6
  Author: Author 1
@@ -18,6 +18,7 @@ Provides-Extra: chromadb
18
18
  Provides-Extra: fastapi
19
19
  Provides-Extra: frictionless
20
20
  Provides-Extra: h5py
21
+ Provides-Extra: ibis
21
22
  Provides-Extra: llm
22
23
  Provides-Extra: map
23
24
  Provides-Extra: mongodb
@@ -34,7 +35,9 @@ Requires-Dist: duckdb (>=0.10.1)
34
35
  Requires-Dist: duckdb-engine (>=0.11.2)
35
36
  Requires-Dist: fastapi ; extra == "fastapi"
36
37
  Requires-Dist: frictionless ; extra == "frictionless"
38
+ Requires-Dist: gcsfs ; extra == "ibis"
37
39
  Requires-Dist: h5py ; extra == "h5py"
40
+ Requires-Dist: ibis-framework[duckdb,examples] (>=9.3.0) ; extra == "ibis"
38
41
  Requires-Dist: jinja2 (>=3.1.4,<4.0.0)
39
42
  Requires-Dist: jsonlines (>=4.0.0,<5.0.0)
40
43
  Requires-Dist: linkml (>=1.8.0) ; extra == "validation"
@@ -43,6 +46,7 @@ Requires-Dist: linkml_map ; extra == "map"
43
46
  Requires-Dist: linkml_renderer ; extra == "renderer"
44
47
  Requires-Dist: llm ; extra == "llm"
45
48
  Requires-Dist: matplotlib ; extra == "analytics"
49
+ Requires-Dist: multipledispatch ; extra == "ibis"
46
50
  Requires-Dist: neo4j ; extra == "neo4j"
47
51
  Requires-Dist: networkx ; extra == "neo4j"
48
52
  Requires-Dist: pandas (>=2.2.1) ; extra == "analytics"
@@ -52,6 +56,7 @@ Requires-Dist: pyarrow ; extra == "pyarrow"
52
56
  Requires-Dist: pydantic (>=2.0.0,<3.0.0)
53
57
  Requires-Dist: pymongo ; extra == "mongodb"
54
58
  Requires-Dist: pystow (>=0.5.4,<0.6.0)
59
+ Requires-Dist: ruff (>=0.6.2) ; extra == "tests"
55
60
  Requires-Dist: scikit-learn ; extra == "scipy"
56
61
  Requires-Dist: scipy ; extra == "scipy"
57
62
  Requires-Dist: seaborn ; extra == "analytics"
@@ -1,16 +1,16 @@
1
1
  linkml_store/__init__.py,sha256=jlU6WOUAn8cKIhzbTULmBTWpW9gZdEt7q_RI6KZN1bY,118
2
2
  linkml_store/api/__init__.py,sha256=3CelcFEFz0y3MkQAzhQ9JxHIt1zFk6nYZxSmYTo8YZE,226
3
3
  linkml_store/api/client.py,sha256=3klBXenQVbLjNQF3WmYfjASt3zvKOfWaCNp5aJM81Ec,12034
4
- linkml_store/api/collection.py,sha256=7JndC6A9r3OVbR9aB6d_bdaYN53XU4FpppUterygOaE,37800
4
+ linkml_store/api/collection.py,sha256=YVmfqdZaWfLAw3yzho-GEknsAiV1h5Z3O6csB_8CTY0,39407
5
5
  linkml_store/api/config.py,sha256=71pxQ5jM-ETxJWU7CzmKjsH6IEJUMP5sml381u9TYVk,5654
6
- linkml_store/api/database.py,sha256=QVvUuLQPCxB4cvsS7rXqPSfoHkhcMzP9vUcsjkuEYds,29051
7
- linkml_store/api/queries.py,sha256=w0qnNeCH6pC9WTGoEQYd300MF6o0G3atz2YxN3WecAs,2028
6
+ linkml_store/api/database.py,sha256=nvae8jnOZsQIFCsl_lRBnKcvrpJg4A10ujIKGeMyUS8,29350
7
+ linkml_store/api/queries.py,sha256=tx9fgGY5fC_2ZbIvg4BqTK_MXJwA_DI4mxr8HdQ6Vos,2075
8
8
  linkml_store/api/stores/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  linkml_store/api/stores/chromadb/__init__.py,sha256=e9BkOPuPnVQKA5PRKDulag59yGNHDP3U2_DnPSrFAKM,132
10
10
  linkml_store/api/stores/chromadb/chromadb_collection.py,sha256=RQUZx5oeotkzNihg-dlSevkiTiKY1d9x0bS63HF80W4,4270
11
11
  linkml_store/api/stores/chromadb/chromadb_database.py,sha256=dZA3LQE8-ZMhJQOzsUFyxehnKpFF7adR182aggfkaFY,3205
12
12
  linkml_store/api/stores/duckdb/__init__.py,sha256=rbQSDgNg-fdvi6-pHGYkJTST4p1qXUZBf9sFSsO3KPk,387
13
- linkml_store/api/stores/duckdb/duckdb_collection.py,sha256=yXnJpEiGK4lMyNuJykuvlKOqaV9ntqv0m0NZMOw0auk,6911
13
+ linkml_store/api/stores/duckdb/duckdb_collection.py,sha256=Rkbm_uIVIRj5576lEolsyY_3Um1h8Lf3RHn8Fy3LIgU,7036
14
14
  linkml_store/api/stores/duckdb/duckdb_database.py,sha256=GH9bcOfHpNp6r-Eu1C3W0xuYcLsqGFDH1Sh4weifGaQ,9923
15
15
  linkml_store/api/stores/duckdb/mappings.py,sha256=tDce3W1Apwammhf4LS6cRJ0m4NiJ0eB7vOI_4U5ETY8,148
16
16
  linkml_store/api/stores/filesystem/__init__.py,sha256=KjvCjdttwqMHNeGyL-gr59zRz0--HFEWWUNNCJ5hITs,347
@@ -30,30 +30,30 @@ linkml_store/api/stores/solr/solr_collection.py,sha256=ZlxC3JbVaHfSA4HuTeJTsp6qe
30
30
  linkml_store/api/stores/solr/solr_database.py,sha256=TFjqbY7jAkdrhAchbNg0E-mChSP7ogNwFExslbvX7Yo,2877
31
31
  linkml_store/api/stores/solr/solr_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  linkml_store/api/types.py,sha256=3aIQtDFMvsSmjuN5qrR2vNK5sHa6yzD_rEOPA6tHwvg,176
33
- linkml_store/cli.py,sha256=NIEU5dEkEKz3a2q4mpkdXxHX1mANd2z9oFIkNVz-wsw,27048
33
+ linkml_store/cli.py,sha256=wl8BhnPcSU6Lt-jsvN1o6086PpUAfu43n5GI6w9SGxw,29384
34
34
  linkml_store/constants.py,sha256=x4ZmDsfE9rZcL5WpA93uTKrRWzCD6GodYXviVzIvR38,112
35
35
  linkml_store/graphs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  linkml_store/graphs/graph_map.py,sha256=bYRxv8n1YPnFqE9d6JKNmRawb8EAhsPlHhBue0gvtZE,712
37
37
  linkml_store/index/__init__.py,sha256=6SQzDe-WZSSqbGNsbCDfyPTyz0s9ISDKw1dm9xgQuT4,1396
38
38
  linkml_store/index/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- linkml_store/index/implementations/llm_indexer.py,sha256=LI5f8SLF_rJY5W6wZPLaUqpyoq-VDW_KqlCBNDNm_po,4827
39
+ linkml_store/index/implementations/llm_indexer.py,sha256=y1xvfUm_rl4UEiWJbsUsEnTCma98XRB9C1XOnuaAv5o,5474
40
40
  linkml_store/index/implementations/simple_indexer.py,sha256=KnkFJtXTHnwjhD_D6ZK2rFhBID1dgCedcOVPEWAY2NU,1282
41
- linkml_store/index/indexer.py,sha256=K-TDPzdTyGFo6iG4XI_A_3IpwDbKeiTIbdr85NIL5r8,4918
41
+ linkml_store/index/indexer.py,sha256=e5dsjh2wjOTDRsfClKJAFTbcK1UC7BOGkUCOfDg9omI,7635
42
42
  linkml_store/inference/__init__.py,sha256=b8NAFNZjOYU_8gOvxdyCyoiHOOl5Ai2ckKs1tv7ZkkY,342
43
- linkml_store/inference/evaluation.py,sha256=qvsmGDBKTZBDKhpbPDe_AkcJ2LtQ8e-oUYCUGfI6IAE,5799
43
+ linkml_store/inference/evaluation.py,sha256=YDFYaEu2QLSfFq4oyARrnKfTiPLtNF8irhhspgVDfdY,6013
44
44
  linkml_store/inference/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
- linkml_store/inference/implementations/rag_inference_engine.py,sha256=MH50-6i30Y5oKgIx47-yDjsPCojYC6-lujtHFBDqIxs,5833
45
+ linkml_store/inference/implementations/rag_inference_engine.py,sha256=mN7YQI-BeZglsAnZnNIuAj-Nxg1su5efNaohooEmNmM,10622
46
46
  linkml_store/inference/implementations/rule_based_inference_engine.py,sha256=0IEY_fsHJPJy6QKbYQU_qE87RRnPOXQxPuJKXCQG8jU,6250
47
- linkml_store/inference/implementations/sklearn_inference_engine.py,sha256=HRhwnlpDJOijxvhLmdTSOq1S9xjBVCrgRT1C8uS0XZQ,13196
48
- linkml_store/inference/inference_config.py,sha256=xgl3VmueErLIOnQQn4HdC2STJNY6yKoPasWmym4ltHQ,2014
49
- linkml_store/inference/inference_engine.py,sha256=D1JlkihyNbZp7PYe5lplUbTJgyP7jL4vnxcpBio-KUs,6987
47
+ linkml_store/inference/implementations/sklearn_inference_engine.py,sha256=Sdi7CoRK3qoLJu3prgLy1Ck_zQ1gHWRKFybHe7XQ4_g,13192
48
+ linkml_store/inference/inference_config.py,sha256=EFGdigxWsfTPREbgqyJVRShN0JktCEmFLLoECrLfXSg,2282
49
+ linkml_store/inference/inference_engine.py,sha256=l2UB6cA0rW7a9qyiv8JF5Nzj8nRHGX_yqMYbiDnY1Qc,7055
50
50
  linkml_store/inference/inference_engine_registry.py,sha256=6o66gvBYBwdeAKm62zqqvfaBlcopVP_cla3L6uXGsHA,3015
51
51
  linkml_store/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
52
  linkml_store/utils/change_utils.py,sha256=O2rvSvgTKB60reLLz9mX5OWykAA_m93bwnUh5ZWa0EY,471
53
53
  linkml_store/utils/file_utils.py,sha256=rQ7-XpmI6_Kx_dhEnI98muFRr0MmgI_kZ_9cgJBf_0I,1411
54
54
  linkml_store/utils/format_utils.py,sha256=airJ2_tFsr0dTIbSHT5y0TZbDrvBBV4_qThFPFY5k8U,10925
55
55
  linkml_store/utils/io.py,sha256=JHUrWDtlZC2jtN_PQZ4ypdGIyYlftZEN3JaCvEPs44w,884
56
- linkml_store/utils/llm_utils.py,sha256=Wb4h_E8vrJZDAYHhOdMCSMcz-xxVia4nfuFqiYitZ98,2864
56
+ linkml_store/utils/llm_utils.py,sha256=3jRFUtEywoKdomKb3aCH1GdI9hQJOQo8Udb3Jy4M-Xw,2885
57
57
  linkml_store/utils/mongodb_utils.py,sha256=Rl1YmMKs1IXwSsJIViSDChbi0Oer5cBnMmjka2TeQS8,4665
58
58
  linkml_store/utils/neo4j_utils.py,sha256=y3KPmDZ8mQmePgg0lUeKkeKqzEr2rV226xxEtHc5pRg,1266
59
59
  linkml_store/utils/object_utils.py,sha256=Vib-5Ip2DlRVKLZpU-008ZZI813-vfKVSCY0TksRenM,6293
@@ -64,6 +64,7 @@ linkml_store/utils/schema_utils.py,sha256=iJiZxo5NGr7v87h4DV6V9DrDOZHSswMRuf0N4V
64
64
  linkml_store/utils/sklearn_utils.py,sha256=itPpcrsbbyOazdjmivaaZ1lyZeytm0a0hJ2AS8ziUgg,7590
65
65
  linkml_store/utils/sql_utils.py,sha256=T41w_vsc3SauTJQkDMwid_nOtKW1YOKyUuaxEf470hk,5938
66
66
  linkml_store/utils/stats_utils.py,sha256=4KqBb1bqDgAmq-1fJLLu5B2paPgoZZc3A-gnyVam4bI,1799
67
+ linkml_store/utils/vector_utils.py,sha256=Q1RlpDzavJAM9-H2m2XNU5BNUcfZkpIWeEZii2hK0PQ,5449
67
68
  linkml_store/webapi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
68
69
  linkml_store/webapi/html/__init__.py,sha256=hwp5eeBJKH65Bvv1x9Z4vsT1tLSYtb9Dq4I9r1kL1q0,69
69
70
  linkml_store/webapi/html/base.html.j2,sha256=hoiV2uaSxxrQp7VuAZBOHueH7czyJMYcPBRN6dZFYhk,693
@@ -72,8 +73,8 @@ linkml_store/webapi/html/database_details.html.j2,sha256=qtXdavbZb0mohiObI9dvJtk
72
73
  linkml_store/webapi/html/databases.html.j2,sha256=a9BCWQYfPeFhdUd31CWhB0yWhTIFXQayO08JgjyqKoc,294
73
74
  linkml_store/webapi/html/generic.html.j2,sha256=KtLaO2HUEF2Opq-OwHKgRKetNWe8IWc6JuIkxRPsywk,1018
74
75
  linkml_store/webapi/main.py,sha256=B0Da575kKR7X88N9ykm99Dem8FyBAW9f-w3A_JwUzfw,29165
75
- linkml_store-0.2.0.dist-info/LICENSE,sha256=77mDOslUnalYnuq9xQYZKtIoNEzcH9mIjvWHOKjamnE,1086
76
- linkml_store-0.2.0.dist-info/METADATA,sha256=v_KjIlu-gTOHunF0ASPHRP_utQv-ry1piX3RpfPWX1k,6743
77
- linkml_store-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
78
- linkml_store-0.2.0.dist-info/entry_points.txt,sha256=gWxVsHqx-t-UKWFHFzawQTvs4is4vC1rCF5AeKyqWWk,101
79
- linkml_store-0.2.0.dist-info/RECORD,,
76
+ linkml_store-0.2.2.dist-info/LICENSE,sha256=77mDOslUnalYnuq9xQYZKtIoNEzcH9mIjvWHOKjamnE,1086
77
+ linkml_store-0.2.2.dist-info/METADATA,sha256=_zde_tfX6AAw1ZvM1LnYOmzkQbiz6f3rQhVyBKODdnE,6977
78
+ linkml_store-0.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
79
+ linkml_store-0.2.2.dist-info/entry_points.txt,sha256=gWxVsHqx-t-UKWFHFzawQTvs4is4vC1rCF5AeKyqWWk,101
80
+ linkml_store-0.2.2.dist-info/RECORD,,