linkml-store 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of linkml-store might be problematic. Click here for more details.

@@ -226,6 +226,18 @@ class Collection(Generic[DatabaseType]):
226
226
  self._initialized = True
227
227
  patches = [{"op": "add", "path": "/0", "value": obj} for obj in objs]
228
228
  self._broadcast(patches, **kwargs)
229
+ self._post_modification_hook(**kwargs)
230
+
231
+ def _post_delete_hook(self, **kwargs):
232
+ self._post_modification_hook(**kwargs)
233
+
234
+ def _post_modification_hook(self, **kwargs):
235
+ for indexer in self.indexers.values():
236
+ ix_collection_name = self.get_index_collection_name(indexer)
237
+ ix_collection = self.parent.get_collection(ix_collection_name)
238
+ # Currently updating the source triggers complete reindexing
239
+ # TODO: make this more efficient by only deleting modified
240
+ ix_collection.delete_where({})
229
241
 
230
242
  def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
231
243
  """
@@ -476,7 +488,7 @@ class Collection(Generic[DatabaseType]):
476
488
  Now let's index, using the simple trigram-based index
477
489
 
478
490
  >>> index = get_indexer("simple")
479
- >>> collection.attach_indexer(index)
491
+ >>> _ = collection.attach_indexer(index)
480
492
 
481
493
  Now let's find all objects:
482
494
 
@@ -514,7 +526,10 @@ class Collection(Generic[DatabaseType]):
514
526
  if ix_coll.size() == 0:
515
527
  logger.info(f"Index {index_name} is empty; indexing all objects")
516
528
  all_objs = self.find(limit=-1).rows
517
- self.index_objects(all_objs, index_name, replace=True, **kwargs)
529
+ if all_objs:
530
+ # print(f"Index {index_name} is empty; indexing all objects {len(all_objs)}")
531
+ self.index_objects(all_objs, index_name, replace=True, **kwargs)
532
+ assert ix_coll.size() > 0
518
533
  qr = ix_coll.find(where=where, limit=-1, **kwargs)
519
534
  index_col = ix.index_field
520
535
  # TODO: optimize this for large indexes
@@ -648,7 +663,31 @@ class Collection(Generic[DatabaseType]):
648
663
  """
649
664
  return self.find({}, limit=1).num_rows
650
665
 
651
- def attach_indexer(self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs):
666
+ def rows_iter(self) -> Iterable[OBJECT]:
667
+ """
668
+ Return an iterator over the objects in the collection.
669
+
670
+ :return:
671
+ """
672
+ yield from self.find({}, limit=-1).rows
673
+
674
+ def rows(self) -> List[OBJECT]:
675
+ """
676
+ Return a list of objects in the collection.
677
+
678
+ :return:
679
+ """
680
+ return list(self.rows_iter())
681
+
682
+ def ranked_rows(self) -> List[Tuple[float, OBJECT]]:
683
+ """
684
+ Return a list of objects in the collection, with scores.
685
+ """
686
+ return [(n, obj) for n, obj in enumerate(self.rows_iter())]
687
+
688
+ def attach_indexer(
689
+ self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs
690
+ ) -> Indexer:
652
691
  """
653
692
  Attach an index to the collection.
654
693
 
@@ -669,8 +708,8 @@ class Collection(Generic[DatabaseType]):
669
708
  >>> full_index.name = "full"
670
709
  >>> name_index = get_indexer("simple", text_template="{name}")
671
710
  >>> name_index.name = "name"
672
- >>> collection.attach_indexer(full_index)
673
- >>> collection.attach_indexer(name_index)
711
+ >>> _ = collection.attach_indexer(full_index)
712
+ >>> _ = collection.attach_indexer(name_index)
674
713
 
675
714
  Now let's find objects using the full index, using the string "France".
676
715
  We expect the country France to be the top hit, but the score will
@@ -713,6 +752,10 @@ class Collection(Generic[DatabaseType]):
713
752
  all_objs = self.find(limit=-1).rows
714
753
  logger.info(f"Auto-indexing {len(all_objs)} objects")
715
754
  self.index_objects(all_objs, index_name, replace=True, **kwargs)
755
+ return index
756
+
757
+ def get_index_collection_name(self, indexer: Indexer) -> str:
758
+ return self._index_collection_name(indexer.name)
716
759
 
717
760
  def _index_collection_name(self, index_name: str) -> str:
718
761
  """
@@ -268,7 +268,7 @@ class Database(ABC, Generic[CollectionType]):
268
268
  metadata: Optional[CollectionConfig] = None,
269
269
  recreate_if_exists=False,
270
270
  **kwargs,
271
- ) -> CollectionType:
271
+ ) -> Collection:
272
272
  """
273
273
  Create a new collection in the current database.
274
274
 
@@ -760,6 +760,12 @@ class Database(ABC, Generic[CollectionType]):
760
760
  """
761
761
  Export a database to a file or location.
762
762
 
763
+ >>> from linkml_store.api.client import Client
764
+ >>> client = Client()
765
+ >>> db = client.attach_database("duckdb", alias="test")
766
+ >>> db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
767
+ >>> db.export_database("/tmp/iris.yaml", Format.YAML)
768
+
763
769
  :param location: location of the file
764
770
  :param target_format: target format
765
771
  :param kwargs: additional arguments
@@ -40,7 +40,9 @@ class FacetCountResult(BaseModel):
40
40
 
41
41
  class QueryResult(BaseModel):
42
42
  """
43
- A query result
43
+ A query result.
44
+
45
+ TODO: make this a subclass of Collection
44
46
  """
45
47
 
46
48
  query: Optional[Query] = None
@@ -50,8 +50,9 @@ class DuckDBCollection(Collection):
50
50
  if not isinstance(objs, list):
51
51
  objs = [objs]
52
52
  cd = self.class_definition()
53
- if not cd:
53
+ if not cd or not cd.attributes:
54
54
  cd = self.induce_class_definition_from_objects(objs)
55
+ assert cd.attributes
55
56
  table = self._sqla_table(cd)
56
57
  engine = self.parent.engine
57
58
  with engine.connect() as conn:
@@ -61,7 +62,8 @@ class DuckDBCollection(Collection):
61
62
  stmt = stmt.compile(engine)
62
63
  conn.execute(stmt)
63
64
  conn.commit()
64
- return
65
+ self._post_delete_hook()
66
+ return None
65
67
 
66
68
  def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
67
69
  logger.info(f"Deleting from {self.target_class_name} where: {where}")
@@ -87,6 +89,7 @@ class DuckDBCollection(Collection):
87
89
  if deleted_rows_count == 0 and not missing_ok:
88
90
  raise ValueError(f"No rows found for {where}")
89
91
  conn.commit()
92
+ self._post_delete_hook()
90
93
  return deleted_rows_count if deleted_rows_count > -1 else None
91
94
 
92
95
  def query_facets(
linkml_store/cli.py CHANGED
@@ -76,6 +76,8 @@ class ContextSettings(BaseModel):
76
76
  if name is None:
77
77
  # if len(self.database.list_collections()) > 1:
78
78
  # raise ValueError("Collection must be specified if there are multiple collections.")
79
+ if not self.database:
80
+ return None
79
81
  if not self.database.list_collections():
80
82
  return None
81
83
  name = list(self.database.list_collections())[0]
@@ -218,7 +220,10 @@ def insert(ctx, files, object, format):
218
220
  @click.option("--object", "-i", multiple=True, help="Input object as YAML")
219
221
  @click.pass_context
220
222
  def store(ctx, files, object, format):
221
- """Store objects from files (JSON, YAML, TSV) into the specified collection."""
223
+ """Store objects from files (JSON, YAML, TSV) into the database.
224
+
225
+ Note: this is similar to insert, but a collection does not need to be specified
226
+ """
222
227
  settings = ctx.obj["settings"]
223
228
  db = settings.database
224
229
  if not files and not object:
@@ -499,6 +504,7 @@ def describe(ctx, where, output_type, output, limit):
499
504
  "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
500
505
  )
501
506
  @click.option("--evaluation-count", "-n", type=click.INT, help="Number of examples to evaluate over")
507
+ @click.option("--evaluation-match-function", help="Name of function to use for matching objects in eval")
502
508
  @click.option("--query", "-q", type=click.STRING, help="query term")
503
509
  @click.pass_context
504
510
  def infer(
@@ -506,6 +512,7 @@ def infer(
506
512
  inference_config_file,
507
513
  query,
508
514
  evaluation_count,
515
+ evaluation_match_function,
509
516
  training_test_data_split,
510
517
  predictor_type,
511
518
  target_attribute,
@@ -549,7 +556,10 @@ def infer(
549
556
  else:
550
557
  query_obj = None
551
558
  collection = ctx.obj["settings"].collection
552
- atts = collection.class_definition().attributes.keys()
559
+ if collection:
560
+ atts = collection.class_definition().attributes.keys()
561
+ else:
562
+ atts = []
553
563
  if feature_attributes:
554
564
  features = feature_attributes.split(",")
555
565
  features = [f.strip() for f in features]
@@ -575,7 +585,8 @@ def infer(
575
585
  if training_test_data_split:
576
586
  config.train_test_split = training_test_data_split
577
587
  predictor = get_inference_engine(predictor_type, config=config)
578
- predictor.load_and_split_data(collection)
588
+ if collection:
589
+ predictor.load_and_split_data(collection)
579
590
  predictor.initialize_model()
580
591
  if export_model:
581
592
  logger.info(f"Exporting model to {export_model} in {model_format}")
@@ -584,8 +595,14 @@ def infer(
584
595
  if not export_model and not evaluation_count:
585
596
  raise ValueError("Query or evaluate must be specified if not exporting model")
586
597
  if evaluation_count:
598
+ if evaluation_match_function == "score_text_overlap":
599
+ match_function_fn = score_text_overlap
600
+ elif evaluation_match_function is not None:
601
+ raise ValueError(f"Unknown match function: {evaluation_match_function}")
602
+ else:
603
+ match_function_fn = None
587
604
  outcome = evaluate_predictor(
588
- predictor, target_attributes, evaluation_count=evaluation_count, match_function=score_text_overlap
605
+ predictor, target_attributes, evaluation_count=evaluation_count, match_function=match_function_fn
589
606
  )
590
607
  print(f"Outcome: {outcome} // accuracy: {outcome.accuracy}")
591
608
  if query_obj:
@@ -1,11 +1,13 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import TYPE_CHECKING, List
3
+ from typing import TYPE_CHECKING, List, Optional
4
4
 
5
5
  import numpy as np
6
+ from tiktoken import encoding_for_model
6
7
 
7
8
  from linkml_store.api.config import CollectionConfig
8
9
  from linkml_store.index.indexer import INDEX_ITEM, Indexer
10
+ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
9
11
 
10
12
  if TYPE_CHECKING:
11
13
  import llm
@@ -29,6 +31,7 @@ class LLMIndexer(Indexer):
29
31
  cached_embeddings_database: str = None
30
32
  cached_embeddings_collection: str = None
31
33
  cache_queries: bool = False
34
+ truncation_method: Optional[str] = None
32
35
 
33
36
  @property
34
37
  def embedding_model(self):
@@ -62,6 +65,21 @@ class LLMIndexer(Indexer):
62
65
  """
63
66
  logging.info(f"Converting {len(texts)} texts to vectors")
64
67
  model = self.embedding_model
68
+ token_limit = get_token_limit(model.model_id)
69
+ encoding = encoding_for_model("gpt-4o")
70
+
71
+ def truncate_text(text: str) -> str:
72
+ # split into chunks of 1000 chars:
73
+ parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
74
+ return render_formatted_text(
75
+ lambda x: "".join(x),
76
+ parts,
77
+ encoding,
78
+ token_limit,
79
+ )
80
+
81
+ texts = [truncate_text(text) for text in texts]
82
+
65
83
  if self.cached_embeddings_database and (cache is None or cache or self.cache_queries):
66
84
  model_id = model.model_id
67
85
  if not model_id:
@@ -88,7 +106,7 @@ class LLMIndexer(Indexer):
88
106
  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
89
107
  else:
90
108
  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
91
- texts = list(texts)
109
+
92
110
  embeddings = list([None] * len(texts))
93
111
  uncached_texts = []
94
112
  n = 0
@@ -36,6 +36,54 @@ def cosine_similarity(vector1, vector2) -> float:
36
36
  class Indexer(BaseModel):
37
37
  """
38
38
  An indexer operates on a collection in order to search for objects.
39
+
40
+ You should use a subclass of this; this can be looked up dynamically:
41
+
42
+ >>> from linkml_store.index import get_indexer
43
+ >>> indexer = get_indexer("simple")
44
+
45
+ You can customize how objects are indexed by passing in a text template.
46
+ For example, if your collection has objects with "name" and "profession" attributes,
47
+ you can index them as "{name} {profession}".
48
+
49
+ >>> indexer = get_indexer("simple", text_template="{name} :: {profession}")
50
+
51
+ By default, Python f-strings are assumed.
52
+
53
+ We can test this works using the :ref:`object_to_text` method (normally
54
+ you would never need to call this directly, but it's useful for testing):
55
+
56
+ >>> obj = {"name": "John", "profession": "doctor"}
57
+ >>> indexer.object_to_text(obj)
58
+ 'John :: doctor'
59
+
60
+ You can also use Jinja2 templates; this gives more flexibility and logic,
61
+ e.g. conditional formatting:
62
+
63
+ >>> tmpl = "{{name}}{% if profession %} :: {{profession}}{% endif %}"
64
+ >>> indexer = get_indexer("simple", text_template=tmpl, text_template_syntax=TemplateSyntaxEnum.jinja2)
65
+ >>> indexer.object_to_text(obj)
66
+ 'John :: doctor'
67
+ >>> indexer.object_to_text({"name": "John"})
68
+ 'John'
69
+
70
+ You can also specify which attributes to index:
71
+
72
+ >>> indexer = get_indexer("simple", index_attributes=["name"])
73
+ >>> indexer.object_to_text(obj)
74
+ 'John'
75
+
76
+ The purpose of an indexer is to translate a collection of objects into a collection of index items
77
+ such as vectors for purposes such as search. Unless you are implementing your own indexer, you
78
+ generally don't need to use the methods that return vectors, but we can examine their behavior
79
+ to get a sense of how they work.
80
+
81
+ >>> vectors = indexer.objects_to_vectors([{"name": "Aardvark"}, {"name": "Aardwolf"}, {"name": "Zesty"}])
82
+ >>> assert cosine_similarity(vectors[0], vectors[1]) > cosine_similarity(vectors[0], vectors[2])
83
+
84
+ Note you should consult the documentation for the specific indexer you are using for more details on
85
+ how text is converted to vectors.
86
+
39
87
  """
40
88
 
41
89
  name: Optional[str] = None
@@ -122,7 +170,9 @@ class Indexer(BaseModel):
122
170
  self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None
123
171
  ) -> List[Tuple[float, Any]]:
124
172
  """
125
- Search the index for a query string
173
+ Use the indexer to search against a database of vectors.
174
+
175
+ Note: this is a low-level method, typically you would use the :ref:`search` method on a :ref:`Collection`.
126
176
 
127
177
  :param query: The query string to search for
128
178
  :param vectors: A list of indexed items, where each item is a tuple of (id, vector)
@@ -20,6 +20,8 @@ def score_match(target: Optional[Any], candidate: Optional[Any], match_function:
20
20
  1.0
21
21
  >>> score_match("a", "b")
22
22
  0.0
23
+ >>> score_match("abcd", "abcde")
24
+ 0.0
23
25
  >>> score_match("a", None)
24
26
  0.0
25
27
  >>> score_match(None, "a")
@@ -52,7 +54,7 @@ def score_match(target: Optional[Any], candidate: Optional[Any], match_function:
52
54
 
53
55
  :param target:
54
56
  :param candidate:
55
- :param match_function:
57
+ :param match_function: defaults to struct
56
58
  :return:
57
59
  """
58
60
  if target == candidate:
@@ -99,7 +101,8 @@ def evaluate_predictor(
99
101
  :param predictor:
100
102
  :param target_attributes:
101
103
  :param feature_attributes:
102
- :param evaluation_count:
104
+ :param evaluation_count: max iterations
105
+ :param match_function: function to use for matching
103
106
  :return:
104
107
  """
105
108
  n = 0
@@ -113,8 +116,8 @@ def evaluate_predictor(
113
116
  else:
114
117
  test_obj = row
115
118
  result = predictor.derive(test_obj)
116
- logger.info(f"Predicted: {result.predicted_object} Expected: {expected_obj}")
117
119
  tp += score_match(result.predicted_object, expected_obj, match_function)
120
+ logger.info(f"TP={tp} MF={match_function} Predicted: {result.predicted_object} Expected: {expected_obj}")
118
121
  n += 1
119
122
  if evaluation_count is not None and n >= evaluation_count:
120
123
  break
@@ -125,6 +128,9 @@ def score_text_overlap(str1: Any, str2: Any) -> float:
125
128
  """
126
129
  Compute the overlap score between two strings.
127
130
 
131
+ >>> score_text_overlap("abc", "bcde")
132
+ 0.5
133
+
128
134
  :param str1:
129
135
  :param str2:
130
136
  :return:
@@ -1,13 +1,16 @@
1
+ import json
1
2
  import logging
2
3
  from dataclasses import dataclass
3
- from typing import Any, Optional
4
+ from pathlib import Path
5
+ from typing import ClassVar, List, Optional, TextIO, Union
4
6
 
5
7
  import yaml
6
8
  from llm import get_key
9
+ from pydantic import BaseModel
7
10
 
8
11
  from linkml_store.api.collection import OBJECT, Collection
9
12
  from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
10
- from linkml_store.inference.inference_engine import InferenceEngine
13
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
11
14
  from linkml_store.utils.object_utils import select_nested
12
15
 
13
16
  logger = logging.getLogger(__name__)
@@ -23,9 +26,10 @@ You should return ONLY valid YAML in your response.
23
26
  """
24
27
 
25
28
 
26
- # def select_object(obj: OBJECT, key_paths: List[str]) -> OBJECT:
27
- # return {k: obj.get(k, None) for k in keys}
28
- # return {k: object_path_get(obj, k, None) for k in key_paths}
29
+ class TrainedModel(BaseModel, extra="forbid"):
30
+ rag_collection_rows: List[OBJECT]
31
+ index_rows: List[OBJECT]
32
+ config: Optional[InferenceConfig] = None
29
33
 
30
34
 
31
35
  @dataclass
@@ -54,14 +58,23 @@ class RAGInferenceEngine(InferenceEngine):
54
58
  >>> prediction.predicted_object
55
59
  {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
56
60
 
61
+ The "model" can be saved for later use:
62
+
63
+ >>> ie.export_model("tests/output/countries.rag_model.json")
64
+
65
+ Note in this case the model is not the underlying LLM, but the "RAG Model" which is the vectorized
66
+ representation of training set objects.
67
+
57
68
  """
58
69
 
59
- classifier: Any = None
60
- encoders: dict = None
61
70
  _model: "llm.Model" = None # noqa: F821
62
71
 
63
72
  rag_collection: Collection = None
64
73
 
74
+ PERSIST_COLS: ClassVar[List[str]] = [
75
+ "config",
76
+ ]
77
+
65
78
  def __post_init__(self):
66
79
  if not self.config:
67
80
  self.config = InferenceConfig()
@@ -81,9 +94,11 @@ class RAGInferenceEngine(InferenceEngine):
81
94
  return self._model
82
95
 
83
96
  def initialize_model(self, **kwargs):
84
- rag_collection = self.training_data.collection
85
- rag_collection.attach_indexer("llm", auto_index=False)
86
- self.rag_collection = rag_collection
97
+ logger.info(f"Initializing model {self.model}")
98
+ if self.training_data:
99
+ rag_collection = self.training_data.collection
100
+ rag_collection.attach_indexer("llm", auto_index=False)
101
+ self.rag_collection = rag_collection
87
102
 
88
103
  def object_to_text(self, object: OBJECT) -> str:
89
104
  return yaml.dump(object)
@@ -100,27 +115,34 @@ class RAGInferenceEngine(InferenceEngine):
100
115
  target_attributes = self.config.target_attributes
101
116
  num_examples = self.config.llm_config.number_of_few_shot_examples or 5
102
117
  query_text = self.object_to_text(object)
103
- if not self.rag_collection.indexers:
104
- raise ValueError("RAG collection must have an indexer attached")
105
- rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
106
- examples = rs.rows
107
- if not examples:
108
- raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
118
+ if not self.rag_collection:
119
+ # TODO: zero-shot mode
120
+ examples = []
121
+ else:
122
+ if not self.rag_collection.indexers:
123
+ raise ValueError("RAG collection must have an indexer attached")
124
+ rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
125
+ examples = rs.rows
126
+ if not examples:
127
+ raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
109
128
  prompt_clauses = []
129
+ query_obj = select_nested(object, feature_attributes)
130
+ query_text = self.object_to_text(query_obj)
110
131
  for example in examples:
111
- # input_obj = {k: example.get(k, None) for k in feature_attributes}
112
132
  input_obj = select_nested(example, feature_attributes)
113
- # output_obj = {k: example.get(k, None) for k in target_attributes}
133
+ input_obj_text = self.object_to_text(input_obj)
134
+ if input_obj_text == query_text:
135
+ raise ValueError(
136
+ f"Query object {query_text} is the same as example object {input_obj_text}\n"
137
+ "This indicates possible test data leakage\n."
138
+ "TODO: allow an option that allows user to treat this as a basic lookup\n"
139
+ )
114
140
  output_obj = select_nested(example, target_attributes)
115
141
  prompt_clause = (
116
- "---\nExample:\n"
117
- f"## INPUT:\n{self.object_to_text(input_obj)}\n"
118
- f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
142
+ "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
119
143
  )
120
144
  prompt_clauses.append(prompt_clause)
121
- # query_obj = {k: object.get(k, None) for k in feature_attributes}
122
- query_obj = select_nested(object, feature_attributes)
123
- query_text = self.object_to_text(query_obj)
145
+
124
146
  prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
125
147
  system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
126
148
 
@@ -137,9 +159,74 @@ class RAGInferenceEngine(InferenceEngine):
137
159
  response = model.prompt(prompt, system_prompt)
138
160
  yaml_str = response.text()
139
161
  logger.info(f"Response: {yaml_str}")
162
+ return Inference(predicted_object=self._parse_yaml_payload(yaml_str))
163
+
164
+ def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
165
+ if "```" in yaml_str:
166
+ yaml_str = yaml_str.split("```")[1].strip()
167
+ if yaml_str.startswith("yaml"):
168
+ yaml_str = yaml_str[4:].strip()
140
169
  try:
141
- predicted_object = yaml.safe_load(yaml_str)
142
- return Inference(predicted_object=predicted_object)
143
- except yaml.parser.ParserError as e:
144
- logger.error(f"Error parsing response: {yaml_str}\n{e}")
170
+ return yaml.safe_load(yaml_str)
171
+ except Exception as e:
172
+ if strict:
173
+ raise e
174
+ logger.error(f"Error parsing YAML: {yaml_str}\n{e}")
145
175
  return None
176
+
177
+ def export_model(
178
+ self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
179
+ ):
180
+ self.save_model(output)
181
+
182
+ def save_model(self, output: Union[str, Path]) -> None:
183
+ """
184
+ Save the trained model and related data to a file.
185
+
186
+ :param output: Path to save the model
187
+ """
188
+
189
+ # trigger index
190
+ _qr = self.rag_collection.search("*", limit=1)
191
+ assert len(_qr.ranked_rows) > 0
192
+
193
+ rows = self.rag_collection.find(limit=-1).rows
194
+
195
+ indexers = self.rag_collection.indexers
196
+ assert len(indexers) == 1
197
+ ix = self.rag_collection.indexers["llm"]
198
+ ix_coll = self.rag_collection.parent.get_collection(self.rag_collection.get_index_collection_name(ix))
199
+
200
+ ix_rows = ix_coll.find(limit=-1).rows
201
+ assert len(ix_rows) > 0
202
+ tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows, config=self.config)
203
+ # tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows)
204
+ with open(output, "w", encoding="utf-8") as f:
205
+ json.dump(tm.model_dump(), f)
206
+
207
+ @classmethod
208
+ def load_model(cls, file_path: Union[str, Path]) -> "RAGInferenceEngine":
209
+ """
210
+ Load a trained model and related data from a file.
211
+
212
+ :param file_path: Path to the saved model
213
+ :return: RAGInferenceEngine instance with loaded model
214
+ """
215
+ with open(file_path, "r", encoding="utf-8") as f:
216
+ model_data = json.load(f)
217
+ tm = TrainedModel(**model_data)
218
+ from linkml_store.api import Client
219
+
220
+ client = Client()
221
+ db = client.attach_database("duckdb", alias="training")
222
+ db.store({"data": tm.rag_collection_rows})
223
+ collection = db.get_collection("data")
224
+ ix = collection.attach_indexer("llm", auto_index=False)
225
+ assert ix.name
226
+ ix_coll_name = collection.get_index_collection_name(ix)
227
+ assert ix_coll_name
228
+ ix_coll = db.get_collection(ix_coll_name, create_if_not_exists=True)
229
+ ix_coll.insert(tm.index_rows)
230
+ ie = cls(config=tm.config)
231
+ ie.rag_collection = collection
232
+ return ie
@@ -153,7 +153,7 @@ class SklearnInferenceEngine(InferenceEngine):
153
153
  y = y_encoder.fit_transform(y.values.ravel()) # Convert to 1D numpy array
154
154
  self.transformed_targets = y_encoder.classes_
155
155
 
156
- logger.info(f"Fitting model with features: {X.columns}")
156
+ # print(f"Fitting model with features: {X.columns}")
157
157
  clf = DecisionTreeClassifier(random_state=42)
158
158
  clf.fit(X, y)
159
159
  self.classifier = clf
@@ -35,6 +35,7 @@ class InferenceConfig(BaseModel, extra="forbid"):
35
35
  feature_attributes: Optional[List[str]] = None
36
36
  train_test_split: Optional[Tuple[float, float]] = None
37
37
  llm_config: Optional[LLMConfig] = None
38
+ random_seed: Optional[int] = None
38
39
 
39
40
  @classmethod
40
41
  def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":
@@ -29,6 +29,7 @@ class ModelSerialization(str, Enum):
29
29
  PNG = "png"
30
30
  LINKML_EXPRESSION = "linkml_expression"
31
31
  RULE_BASED = "rulebased"
32
+ RAG_INDEX = "rag_index"
32
33
 
33
34
  @classmethod
34
35
  def from_filepath(cls, file_path: str) -> Optional["ModelSerialization"]:
@@ -58,7 +59,7 @@ class ModelSerialization(str, Enum):
58
59
 
59
60
 
60
61
  class CollectionSlice(BaseModel):
61
- model_config = ConfigDict(arbitrary_types_allowed=True)
62
+ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
62
63
 
63
64
  name: Optional[str] = None
64
65
  base_collection: Optional[Collection] = None
@@ -69,17 +70,26 @@ class CollectionSlice(BaseModel):
69
70
 
70
71
  @property
71
72
  def collection(self) -> Collection:
73
+ if not self._collection and not self.indices:
74
+ return self.base_collection
72
75
  if not self._collection:
73
76
  rows = self.base_collection.find({}, limit=-1).rows
74
- # subset based on indices
75
77
  subset = [rows[i] for i in self.indices]
76
78
  db = self.base_collection.parent
77
- subset_name = f"{self.base_collection.alias}__rag_{self.name}"
79
+ subset_name = self.slice_alias
78
80
  subset_collection = db.get_collection(subset_name, create_if_not_exists=True)
81
+ # ensure the collection has the same schema type as the base collection;
82
+ # this ensures that column/attribute types are preserved
83
+ subset_collection.metadata.type = self.base_collection.target_class_name
84
+ subset_collection.delete_where({})
79
85
  subset_collection.insert(subset)
80
86
  self._collection = subset_collection
81
87
  return self._collection
82
88
 
89
+ @property
90
+ def slice_alias(self) -> str:
91
+ return f"{self.base_collection.alias}__rag_{self.name}"
92
+
83
93
  def as_dataframe(self, flattened=False) -> pd.DataFrame:
84
94
  """
85
95
  Return the slice of the collection as a dataframe.
@@ -113,31 +123,28 @@ class InferenceEngine(ABC):
113
123
 
114
124
  :param collection:
115
125
  :param split:
126
+ :param randomize:
116
127
  :return:
117
128
  """
129
+ local_random = random.Random(self.config.random_seed) if self.config.random_seed else random.Random()
118
130
  split = split or self.config.train_test_split
119
131
  if not split:
120
132
  split = (0.7, 0.3)
133
+ if split[0] == 1.0:
134
+ self.training_data = CollectionSlice(name="train", base_collection=collection, indices=None)
135
+ self.testing_data = None
136
+ return
121
137
  logger.info(f"Loading and splitting data from collection {collection.alias}")
122
138
  size = collection.size()
123
139
  indices = range(size)
124
140
  if randomize:
125
- train_indices = random.sample(indices, int(size * split[0]))
141
+ train_indices = local_random.sample(indices, int(size * split[0]))
126
142
  test_indices = set(indices) - set(train_indices)
127
143
  else:
128
144
  train_indices = indices[: int(size * split[0])]
129
145
  test_indices = indices[int(size * split[0]) :]
130
146
  self.training_data = CollectionSlice(name="train", base_collection=collection, indices=train_indices)
131
147
  self.testing_data = CollectionSlice(name="test", base_collection=collection, indices=test_indices)
132
- # all_data = collection.find({}, limit=size).rows
133
- # all_data_df = nested_objects_to_dataframe(all_data)
134
- # all_data_df = collection.find({}, limit=size).rows_dataframe
135
- # randomize/shuffle order of rows in dataframe
136
- # all_data_df = all_data_df.sample(frac=1).reset_index(drop=True)
137
- # self.training_data = CollectionSlice(dataframe=all_data_df[: int(size * split[0])])
138
- # self.testing_data = CollectionSlice(dataframe=all_data_df[int(size * split[0]) : size])
139
- # self.training_data = CollectionSlice(base_collection=collection, slice=(0, int(size * split[0])))
140
- # self.testing_data = CollectionSlice(base_collection=collection, slice=(int(size * split[0]), size))
141
148
 
142
149
  def initialize_model(self, **kwargs):
143
150
  """
@@ -20,6 +20,7 @@ MODEL_TOKEN_MAPPING = {
20
20
  "gpt-3.5-turbo-instruct": 4096,
21
21
  "text-ada-001": 2049,
22
22
  "ada": 2049,
23
+ "ada-002": 8192,
23
24
  "text-babbage-001": 2040,
24
25
  "babbage": 2049,
25
26
  "text-curie-001": 2049,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: linkml-store
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: linkml-store
5
5
  License: MIT
6
6
  Author: Author 1
@@ -18,6 +18,7 @@ Provides-Extra: chromadb
18
18
  Provides-Extra: fastapi
19
19
  Provides-Extra: frictionless
20
20
  Provides-Extra: h5py
21
+ Provides-Extra: ibis
21
22
  Provides-Extra: llm
22
23
  Provides-Extra: map
23
24
  Provides-Extra: mongodb
@@ -34,7 +35,9 @@ Requires-Dist: duckdb (>=0.10.1)
34
35
  Requires-Dist: duckdb-engine (>=0.11.2)
35
36
  Requires-Dist: fastapi ; extra == "fastapi"
36
37
  Requires-Dist: frictionless ; extra == "frictionless"
38
+ Requires-Dist: gcsfs ; extra == "ibis"
37
39
  Requires-Dist: h5py ; extra == "h5py"
40
+ Requires-Dist: ibis-framework[duckdb,examples] (>=9.3.0) ; extra == "ibis"
38
41
  Requires-Dist: jinja2 (>=3.1.4,<4.0.0)
39
42
  Requires-Dist: jsonlines (>=4.0.0,<5.0.0)
40
43
  Requires-Dist: linkml (>=1.8.0) ; extra == "validation"
@@ -43,6 +46,7 @@ Requires-Dist: linkml_map ; extra == "map"
43
46
  Requires-Dist: linkml_renderer ; extra == "renderer"
44
47
  Requires-Dist: llm ; extra == "llm"
45
48
  Requires-Dist: matplotlib ; extra == "analytics"
49
+ Requires-Dist: multipledispatch ; extra == "ibis"
46
50
  Requires-Dist: neo4j ; extra == "neo4j"
47
51
  Requires-Dist: networkx ; extra == "neo4j"
48
52
  Requires-Dist: pandas (>=2.2.1) ; extra == "analytics"
@@ -52,6 +56,7 @@ Requires-Dist: pyarrow ; extra == "pyarrow"
52
56
  Requires-Dist: pydantic (>=2.0.0,<3.0.0)
53
57
  Requires-Dist: pymongo ; extra == "mongodb"
54
58
  Requires-Dist: pystow (>=0.5.4,<0.6.0)
59
+ Requires-Dist: ruff (>=0.6.2) ; extra == "tests"
55
60
  Requires-Dist: scikit-learn ; extra == "scipy"
56
61
  Requires-Dist: scipy ; extra == "scipy"
57
62
  Requires-Dist: seaborn ; extra == "analytics"
@@ -1,16 +1,16 @@
1
1
  linkml_store/__init__.py,sha256=jlU6WOUAn8cKIhzbTULmBTWpW9gZdEt7q_RI6KZN1bY,118
2
2
  linkml_store/api/__init__.py,sha256=3CelcFEFz0y3MkQAzhQ9JxHIt1zFk6nYZxSmYTo8YZE,226
3
3
  linkml_store/api/client.py,sha256=3klBXenQVbLjNQF3WmYfjASt3zvKOfWaCNp5aJM81Ec,12034
4
- linkml_store/api/collection.py,sha256=7JndC6A9r3OVbR9aB6d_bdaYN53XU4FpppUterygOaE,37800
4
+ linkml_store/api/collection.py,sha256=98qUYKVJOEzC9Sl9iBqxdBWnm_4Q8UT9r5UPRb4UoAU,39300
5
5
  linkml_store/api/config.py,sha256=71pxQ5jM-ETxJWU7CzmKjsH6IEJUMP5sml381u9TYVk,5654
6
- linkml_store/api/database.py,sha256=QVvUuLQPCxB4cvsS7rXqPSfoHkhcMzP9vUcsjkuEYds,29051
7
- linkml_store/api/queries.py,sha256=w0qnNeCH6pC9WTGoEQYd300MF6o0G3atz2YxN3WecAs,2028
6
+ linkml_store/api/database.py,sha256=nvae8jnOZsQIFCsl_lRBnKcvrpJg4A10ujIKGeMyUS8,29350
7
+ linkml_store/api/queries.py,sha256=tx9fgGY5fC_2ZbIvg4BqTK_MXJwA_DI4mxr8HdQ6Vos,2075
8
8
  linkml_store/api/stores/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  linkml_store/api/stores/chromadb/__init__.py,sha256=e9BkOPuPnVQKA5PRKDulag59yGNHDP3U2_DnPSrFAKM,132
10
10
  linkml_store/api/stores/chromadb/chromadb_collection.py,sha256=RQUZx5oeotkzNihg-dlSevkiTiKY1d9x0bS63HF80W4,4270
11
11
  linkml_store/api/stores/chromadb/chromadb_database.py,sha256=dZA3LQE8-ZMhJQOzsUFyxehnKpFF7adR182aggfkaFY,3205
12
12
  linkml_store/api/stores/duckdb/__init__.py,sha256=rbQSDgNg-fdvi6-pHGYkJTST4p1qXUZBf9sFSsO3KPk,387
13
- linkml_store/api/stores/duckdb/duckdb_collection.py,sha256=yXnJpEiGK4lMyNuJykuvlKOqaV9ntqv0m0NZMOw0auk,6911
13
+ linkml_store/api/stores/duckdb/duckdb_collection.py,sha256=Rkbm_uIVIRj5576lEolsyY_3Um1h8Lf3RHn8Fy3LIgU,7036
14
14
  linkml_store/api/stores/duckdb/duckdb_database.py,sha256=GH9bcOfHpNp6r-Eu1C3W0xuYcLsqGFDH1Sh4weifGaQ,9923
15
15
  linkml_store/api/stores/duckdb/mappings.py,sha256=tDce3W1Apwammhf4LS6cRJ0m4NiJ0eB7vOI_4U5ETY8,148
16
16
  linkml_store/api/stores/filesystem/__init__.py,sha256=KjvCjdttwqMHNeGyL-gr59zRz0--HFEWWUNNCJ5hITs,347
@@ -30,30 +30,30 @@ linkml_store/api/stores/solr/solr_collection.py,sha256=ZlxC3JbVaHfSA4HuTeJTsp6qe
30
30
  linkml_store/api/stores/solr/solr_database.py,sha256=TFjqbY7jAkdrhAchbNg0E-mChSP7ogNwFExslbvX7Yo,2877
31
31
  linkml_store/api/stores/solr/solr_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  linkml_store/api/types.py,sha256=3aIQtDFMvsSmjuN5qrR2vNK5sHa6yzD_rEOPA6tHwvg,176
33
- linkml_store/cli.py,sha256=NIEU5dEkEKz3a2q4mpkdXxHX1mANd2z9oFIkNVz-wsw,27048
33
+ linkml_store/cli.py,sha256=6pcVHM_hNH7EicleoCkjMVAbrGFZu8V_k2mv3aX0SH8,27703
34
34
  linkml_store/constants.py,sha256=x4ZmDsfE9rZcL5WpA93uTKrRWzCD6GodYXviVzIvR38,112
35
35
  linkml_store/graphs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  linkml_store/graphs/graph_map.py,sha256=bYRxv8n1YPnFqE9d6JKNmRawb8EAhsPlHhBue0gvtZE,712
37
37
  linkml_store/index/__init__.py,sha256=6SQzDe-WZSSqbGNsbCDfyPTyz0s9ISDKw1dm9xgQuT4,1396
38
38
  linkml_store/index/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
- linkml_store/index/implementations/llm_indexer.py,sha256=LI5f8SLF_rJY5W6wZPLaUqpyoq-VDW_KqlCBNDNm_po,4827
39
+ linkml_store/index/implementations/llm_indexer.py,sha256=y1xvfUm_rl4UEiWJbsUsEnTCma98XRB9C1XOnuaAv5o,5474
40
40
  linkml_store/index/implementations/simple_indexer.py,sha256=KnkFJtXTHnwjhD_D6ZK2rFhBID1dgCedcOVPEWAY2NU,1282
41
- linkml_store/index/indexer.py,sha256=K-TDPzdTyGFo6iG4XI_A_3IpwDbKeiTIbdr85NIL5r8,4918
41
+ linkml_store/index/indexer.py,sha256=3P-VtIEqitnMoaFjcXIIosoU7tJInop1Qq39QRbcT-8,7107
42
42
  linkml_store/inference/__init__.py,sha256=b8NAFNZjOYU_8gOvxdyCyoiHOOl5Ai2ckKs1tv7ZkkY,342
43
- linkml_store/inference/evaluation.py,sha256=qvsmGDBKTZBDKhpbPDe_AkcJ2LtQ8e-oUYCUGfI6IAE,5799
43
+ linkml_store/inference/evaluation.py,sha256=YDFYaEu2QLSfFq4oyARrnKfTiPLtNF8irhhspgVDfdY,6013
44
44
  linkml_store/inference/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
- linkml_store/inference/implementations/rag_inference_engine.py,sha256=MH50-6i30Y5oKgIx47-yDjsPCojYC6-lujtHFBDqIxs,5833
45
+ linkml_store/inference/implementations/rag_inference_engine.py,sha256=zmcbxmVZTm8ViSp7WFs8KHRNbzWXdZQl7J7VvcIjDyU,9049
46
46
  linkml_store/inference/implementations/rule_based_inference_engine.py,sha256=0IEY_fsHJPJy6QKbYQU_qE87RRnPOXQxPuJKXCQG8jU,6250
47
- linkml_store/inference/implementations/sklearn_inference_engine.py,sha256=HRhwnlpDJOijxvhLmdTSOq1S9xjBVCrgRT1C8uS0XZQ,13196
48
- linkml_store/inference/inference_config.py,sha256=xgl3VmueErLIOnQQn4HdC2STJNY6yKoPasWmym4ltHQ,2014
49
- linkml_store/inference/inference_engine.py,sha256=D1JlkihyNbZp7PYe5lplUbTJgyP7jL4vnxcpBio-KUs,6987
47
+ linkml_store/inference/implementations/sklearn_inference_engine.py,sha256=Sdi7CoRK3qoLJu3prgLy1Ck_zQ1gHWRKFybHe7XQ4_g,13192
48
+ linkml_store/inference/inference_config.py,sha256=SbAlgQDRCWawWohe0IWX_Kvy-DFeaLYsN1HqrmLvc0k,2052
49
+ linkml_store/inference/inference_engine.py,sha256=l2UB6cA0rW7a9qyiv8JF5Nzj8nRHGX_yqMYbiDnY1Qc,7055
50
50
  linkml_store/inference/inference_engine_registry.py,sha256=6o66gvBYBwdeAKm62zqqvfaBlcopVP_cla3L6uXGsHA,3015
51
51
  linkml_store/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
52
  linkml_store/utils/change_utils.py,sha256=O2rvSvgTKB60reLLz9mX5OWykAA_m93bwnUh5ZWa0EY,471
53
53
  linkml_store/utils/file_utils.py,sha256=rQ7-XpmI6_Kx_dhEnI98muFRr0MmgI_kZ_9cgJBf_0I,1411
54
54
  linkml_store/utils/format_utils.py,sha256=airJ2_tFsr0dTIbSHT5y0TZbDrvBBV4_qThFPFY5k8U,10925
55
55
  linkml_store/utils/io.py,sha256=JHUrWDtlZC2jtN_PQZ4ypdGIyYlftZEN3JaCvEPs44w,884
56
- linkml_store/utils/llm_utils.py,sha256=Wb4h_E8vrJZDAYHhOdMCSMcz-xxVia4nfuFqiYitZ98,2864
56
+ linkml_store/utils/llm_utils.py,sha256=3jRFUtEywoKdomKb3aCH1GdI9hQJOQo8Udb3Jy4M-Xw,2885
57
57
  linkml_store/utils/mongodb_utils.py,sha256=Rl1YmMKs1IXwSsJIViSDChbi0Oer5cBnMmjka2TeQS8,4665
58
58
  linkml_store/utils/neo4j_utils.py,sha256=y3KPmDZ8mQmePgg0lUeKkeKqzEr2rV226xxEtHc5pRg,1266
59
59
  linkml_store/utils/object_utils.py,sha256=Vib-5Ip2DlRVKLZpU-008ZZI813-vfKVSCY0TksRenM,6293
@@ -72,8 +72,8 @@ linkml_store/webapi/html/database_details.html.j2,sha256=qtXdavbZb0mohiObI9dvJtk
72
72
  linkml_store/webapi/html/databases.html.j2,sha256=a9BCWQYfPeFhdUd31CWhB0yWhTIFXQayO08JgjyqKoc,294
73
73
  linkml_store/webapi/html/generic.html.j2,sha256=KtLaO2HUEF2Opq-OwHKgRKetNWe8IWc6JuIkxRPsywk,1018
74
74
  linkml_store/webapi/main.py,sha256=B0Da575kKR7X88N9ykm99Dem8FyBAW9f-w3A_JwUzfw,29165
75
- linkml_store-0.2.0.dist-info/LICENSE,sha256=77mDOslUnalYnuq9xQYZKtIoNEzcH9mIjvWHOKjamnE,1086
76
- linkml_store-0.2.0.dist-info/METADATA,sha256=v_KjIlu-gTOHunF0ASPHRP_utQv-ry1piX3RpfPWX1k,6743
77
- linkml_store-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
78
- linkml_store-0.2.0.dist-info/entry_points.txt,sha256=gWxVsHqx-t-UKWFHFzawQTvs4is4vC1rCF5AeKyqWWk,101
79
- linkml_store-0.2.0.dist-info/RECORD,,
75
+ linkml_store-0.2.1.dist-info/LICENSE,sha256=77mDOslUnalYnuq9xQYZKtIoNEzcH9mIjvWHOKjamnE,1086
76
+ linkml_store-0.2.1.dist-info/METADATA,sha256=ERSRCW1gMtcuLoWX-jKKfpgz10BBIUrGw-TvwXW3o-c,6977
77
+ linkml_store-0.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
78
+ linkml_store-0.2.1.dist-info/entry_points.txt,sha256=gWxVsHqx-t-UKWFHFzawQTvs4is4vC1rCF5AeKyqWWk,101
79
+ linkml_store-0.2.1.dist-info/RECORD,,