linkml-store 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of linkml-store might be problematic.

Files changed (79)
  1. {linkml_store-0.2.0 → linkml_store-0.2.1}/PKG-INFO +6 -1
  2. {linkml_store-0.2.0 → linkml_store-0.2.1}/pyproject.toml +9 -18
  3. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/collection.py +48 -5
  4. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/database.py +7 -1
  5. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/queries.py +3 -1
  6. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/duckdb/duckdb_collection.py +5 -2
  7. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/cli.py +21 -4
  8. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/implementations/llm_indexer.py +20 -2
  9. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/indexer.py +51 -1
  10. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/evaluation.py +9 -3
  11. linkml_store-0.2.1/src/linkml_store/inference/implementations/rag_inference_engine.py +232 -0
  12. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/implementations/sklearn_inference_engine.py +1 -1
  13. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/inference_config.py +1 -0
  14. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/inference_engine.py +20 -13
  15. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/llm_utils.py +1 -0
  16. linkml_store-0.2.0/src/linkml_store/inference/implementations/rag_inference_engine.py +0 -145
  17. {linkml_store-0.2.0 → linkml_store-0.2.1}/LICENSE +0 -0
  18. {linkml_store-0.2.0 → linkml_store-0.2.1}/README.md +0 -0
  19. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/__init__.py +0 -0
  20. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/__init__.py +0 -0
  21. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/client.py +0 -0
  22. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/config.py +0 -0
  23. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/__init__.py +0 -0
  24. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/chromadb/__init__.py +0 -0
  25. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/chromadb/chromadb_collection.py +0 -0
  26. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/chromadb/chromadb_database.py +0 -0
  27. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/duckdb/__init__.py +0 -0
  28. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/duckdb/duckdb_database.py +0 -0
  29. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/duckdb/mappings.py +0 -0
  30. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/filesystem/__init__.py +0 -0
  31. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/filesystem/filesystem_collection.py +0 -0
  32. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/filesystem/filesystem_database.py +0 -0
  33. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/hdf5/__init__.py +0 -0
  34. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/hdf5/hdf5_collection.py +0 -0
  35. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/hdf5/hdf5_database.py +0 -0
  36. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/mongodb/__init__.py +0 -0
  37. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/mongodb/mongodb_collection.py +0 -0
  38. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/mongodb/mongodb_database.py +0 -0
  39. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/neo4j/__init__.py +0 -0
  40. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/neo4j/neo4j_collection.py +0 -0
  41. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/neo4j/neo4j_database.py +0 -0
  42. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/solr/__init__.py +0 -0
  43. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/solr/solr_collection.py +0 -0
  44. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/solr/solr_database.py +0 -0
  45. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/solr/solr_utils.py +0 -0
  46. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/types.py +0 -0
  47. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/constants.py +0 -0
  48. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/graphs/__init__.py +0 -0
  49. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/graphs/graph_map.py +0 -0
  50. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/__init__.py +0 -0
  51. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/implementations/__init__.py +0 -0
  52. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/implementations/simple_indexer.py +0 -0
  53. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/__init__.py +0 -0
  54. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/implementations/__init__.py +0 -0
  55. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/implementations/rule_based_inference_engine.py +0 -0
  56. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/inference_engine_registry.py +0 -0
  57. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/__init__.py +0 -0
  58. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/change_utils.py +0 -0
  59. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/file_utils.py +0 -0
  60. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/format_utils.py +0 -0
  61. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/io.py +0 -0
  62. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/mongodb_utils.py +0 -0
  63. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/neo4j_utils.py +0 -0
  64. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/object_utils.py +0 -0
  65. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/pandas_utils.py +0 -0
  66. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/patch_utils.py +0 -0
  67. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/query_utils.py +0 -0
  68. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/schema_utils.py +0 -0
  69. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/sklearn_utils.py +0 -0
  70. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/sql_utils.py +0 -0
  71. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/stats_utils.py +0 -0
  72. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/__init__.py +0 -0
  73. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/__init__.py +0 -0
  74. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/base.html.j2 +0 -0
  75. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/collection_details.html.j2 +0 -0
  76. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/database_details.html.j2 +0 -0
  77. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/databases.html.j2 +0 -0
  78. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/html/generic.html.j2 +0 -0
  79. {linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/webapi/main.py +0 -0

{linkml_store-0.2.0 → linkml_store-0.2.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: linkml-store
- Version: 0.2.0
+ Version: 0.2.1
  Summary: linkml-store
  License: MIT
  Author: Author 1
@@ -18,6 +18,7 @@ Provides-Extra: chromadb
  Provides-Extra: fastapi
  Provides-Extra: frictionless
  Provides-Extra: h5py
+ Provides-Extra: ibis
  Provides-Extra: llm
  Provides-Extra: map
  Provides-Extra: mongodb
@@ -34,7 +35,9 @@ Requires-Dist: duckdb (>=0.10.1)
  Requires-Dist: duckdb-engine (>=0.11.2)
  Requires-Dist: fastapi ; extra == "fastapi"
  Requires-Dist: frictionless ; extra == "frictionless"
+ Requires-Dist: gcsfs ; extra == "ibis"
  Requires-Dist: h5py ; extra == "h5py"
+ Requires-Dist: ibis-framework[duckdb,examples] (>=9.3.0) ; extra == "ibis"
  Requires-Dist: jinja2 (>=3.1.4,<4.0.0)
  Requires-Dist: jsonlines (>=4.0.0,<5.0.0)
  Requires-Dist: linkml (>=1.8.0) ; extra == "validation"
@@ -43,6 +46,7 @@ Requires-Dist: linkml_map ; extra == "map"
  Requires-Dist: linkml_renderer ; extra == "renderer"
  Requires-Dist: llm ; extra == "llm"
  Requires-Dist: matplotlib ; extra == "analytics"
+ Requires-Dist: multipledispatch ; extra == "ibis"
  Requires-Dist: neo4j ; extra == "neo4j"
  Requires-Dist: networkx ; extra == "neo4j"
  Requires-Dist: pandas (>=2.2.1) ; extra == "analytics"
@@ -52,6 +56,7 @@ Requires-Dist: pyarrow ; extra == "pyarrow"
  Requires-Dist: pydantic (>=2.0.0,<3.0.0)
  Requires-Dist: pymongo ; extra == "mongodb"
  Requires-Dist: pystow (>=0.5.4,<0.6.0)
+ Requires-Dist: ruff (>=0.6.2) ; extra == "tests"
  Requires-Dist: scikit-learn ; extra == "scipy"
  Requires-Dist: scipy ; extra == "scipy"
  Requires-Dist: seaborn ; extra == "analytics"

{linkml_store-0.2.0 → linkml_store-0.2.1}/pyproject.toml

@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "linkml-store"
- version = "0.2.0"
+ version = "0.2.1"
  description = "linkml-store"
  authors = ["Author 1 <author@org.org>"]
  license = "MIT"
@@ -20,6 +20,7 @@ seaborn = { version = "*", optional = true }
  plotly = { version = "*", optional = true }
  pystow = "^0.5.4"
  black = { version=">=24.0.0", optional = true }
+ ruff = { version=">=0.6.2", optional = true }
  llm = { version="*", optional = true }
  tiktoken = { version="*", optional = true }
  pymongo = { version="*", optional = true }
@@ -35,6 +36,9 @@ linkml = { version=">=1.8.0", optional = true }
  linkml_map = { version="*", optional = true }
  linkml_renderer = { version="*", optional = true }
  frictionless = { version="*", optional = true }
+ ibis-framework = { version=">=9.3.0", extras = ["duckdb", "examples"], optional = true }
+ gcsfs = { version="*", optional = true }
+ multipledispatch = { version="*" }
  pandas = ">=2.2.1"
  jinja2 = "^3.1.4"
  jsonlines = "^4.0.0"
@@ -69,7 +73,7 @@ numpy = [
  [tool.poetry.extras]
  analytics = ["pandas", "matplotlib", "seaborn", "plotly"]
  app = ["streamlit"]
- tests = ["black"]
+ tests = ["black", "ruff"]
  llm = ["llm", "tiktoken"]
  mongodb = ["pymongo"]
  neo4j = ["neo4j", "py2neo", "networkx"]
@@ -82,6 +86,7 @@ renderer = ["linkml_renderer"]
  fastapi = ["fastapi", "uvicorn"]
  frictionless = ["frictionless"]
  scipy = ["scipy", "scikit-learn"]
+ ibis = ["ibis-framework", "multipledispatch", "gcsfs"]

  [tool.poetry.scripts]
  linkml-store = "linkml_store.cli:cli"
@@ -119,27 +124,13 @@ extend-exclude = [
  ]
  force-exclude = true
  line-length = 120
- extend-ignore = ["E203"]
- select = [
+ lint.extend-ignore = ["E203"]
+ lint.select = [
      "E", # pycodestyle errors
      "F", # Pyflakes
      "I", # isort
  ]
- # Assume Python 3.8
- target-version = "py38"

- [tool.ruff.per-file-ignores]
- # These templates can have long lines
- "linkml/generators/sqlalchemy/sqlalchemy_declarative_template.py" = ["E501"]
- "linkml/generators/sqlalchemy/sqlalchemy_imperative_template.py" = ["E501"]
-
- # Notebooks can have unsorted imports
- "tests/test_notebooks/input/*" = ["E402"]
-
-
- [tool.ruff.mccabe]
- # Unlike Flake8, default to a complexity level of 10.
- max-complexity = 10


  [tool.codespell]

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/collection.py

@@ -226,6 +226,18 @@ class Collection(Generic[DatabaseType]):
          self._initialized = True
          patches = [{"op": "add", "path": "/0", "value": obj} for obj in objs]
          self._broadcast(patches, **kwargs)
+         self._post_modification_hook(**kwargs)
+
+     def _post_delete_hook(self, **kwargs):
+         self._post_modification_hook(**kwargs)
+
+     def _post_modification_hook(self, **kwargs):
+         for indexer in self.indexers.values():
+             ix_collection_name = self.get_index_collection_name(indexer)
+             ix_collection = self.parent.get_collection(ix_collection_name)
+             # Currently updating the source triggers complete reindexing
+             # TODO: make this more efficient by only deleting modified
+             ix_collection.delete_where({})

      def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
          """
@@ -476,7 +488,7 @@ class Collection(Generic[DatabaseType]):
          Now let's index, using the simple trigram-based index

          >>> index = get_indexer("simple")
-         >>> collection.attach_indexer(index)
+         >>> _ = collection.attach_indexer(index)

          Now let's find all objects:

@@ -514,7 +526,10 @@ class Collection(Generic[DatabaseType]):
          if ix_coll.size() == 0:
              logger.info(f"Index {index_name} is empty; indexing all objects")
              all_objs = self.find(limit=-1).rows
-             self.index_objects(all_objs, index_name, replace=True, **kwargs)
+             if all_objs:
+                 # print(f"Index {index_name} is empty; indexing all objects {len(all_objs)}")
+                 self.index_objects(all_objs, index_name, replace=True, **kwargs)
+             assert ix_coll.size() > 0
          qr = ix_coll.find(where=where, limit=-1, **kwargs)
          index_col = ix.index_field
          # TODO: optimize this for large indexes
@@ -648,7 +663,31 @@ class Collection(Generic[DatabaseType]):
          """
          return self.find({}, limit=1).num_rows

-     def attach_indexer(self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs):
+     def rows_iter(self) -> Iterable[OBJECT]:
+         """
+         Return an iterator over the objects in the collection.
+
+         :return:
+         """
+         yield from self.find({}, limit=-1).rows
+
+     def rows(self) -> List[OBJECT]:
+         """
+         Return a list of objects in the collection.
+
+         :return:
+         """
+         return list(self.rows_iter())
+
+     def ranked_rows(self) -> List[Tuple[float, OBJECT]]:
+         """
+         Return a list of objects in the collection, with scores.
+         """
+         return [(n, obj) for n, obj in enumerate(self.rows_iter())]
+
+     def attach_indexer(
+         self, index: Union[Indexer, str], name: Optional[str] = None, auto_index=True, **kwargs
+     ) -> Indexer:
          """
          Attach an index to the collection.

@@ -669,8 +708,8 @@ class Collection(Generic[DatabaseType]):
          >>> full_index.name = "full"
          >>> name_index = get_indexer("simple", text_template="{name}")
          >>> name_index.name = "name"
-         >>> collection.attach_indexer(full_index)
-         >>> collection.attach_indexer(name_index)
+         >>> _ = collection.attach_indexer(full_index)
+         >>> _ = collection.attach_indexer(name_index)

          Now let's find objects using the full index, using the string "France".
          We expect the country France to be the top hit, but the score will
@@ -713,6 +752,10 @@ class Collection(Generic[DatabaseType]):
              all_objs = self.find(limit=-1).rows
              logger.info(f"Auto-indexing {len(all_objs)} objects")
              self.index_objects(all_objs, index_name, replace=True, **kwargs)
+         return index
+
+     def get_index_collection_name(self, indexer: Indexer) -> str:
+         return self._index_collection_name(indexer.name)

      def _index_collection_name(self, index_name: str) -> str:
          """

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/database.py

@@ -268,7 +268,7 @@ class Database(ABC, Generic[CollectionType]):
          metadata: Optional[CollectionConfig] = None,
          recreate_if_exists=False,
          **kwargs,
-     ) -> CollectionType:
+     ) -> Collection:
          """
          Create a new collection in the current database.

@@ -760,6 +760,12 @@ class Database(ABC, Generic[CollectionType]):
          """
          Export a database to a file or location.

+         >>> from linkml_store.api.client import Client
+         >>> client = Client()
+         >>> db = client.attach_database("duckdb", alias="test")
+         >>> db.import_database("tests/input/iris.csv", Format.CSV, collection_name="iris")
+         >>> db.export_database("/tmp/iris.yaml", Format.YAML)
+
          :param location: location of the file
          :param target_format: target format
          :param kwargs: additional arguments

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/queries.py

@@ -40,7 +40,9 @@ class FacetCountResult(BaseModel):

  class QueryResult(BaseModel):
      """
-     A query result
+     A query result.
+
+     TODO: make this a subclass of Collection
      """

      query: Optional[Query] = None

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/api/stores/duckdb/duckdb_collection.py

@@ -50,8 +50,9 @@ class DuckDBCollection(Collection):
          if not isinstance(objs, list):
              objs = [objs]
          cd = self.class_definition()
-         if not cd:
+         if not cd or not cd.attributes:
              cd = self.induce_class_definition_from_objects(objs)
+         assert cd.attributes
          table = self._sqla_table(cd)
          engine = self.parent.engine
          with engine.connect() as conn:
@@ -61,7 +62,8 @@
              stmt = stmt.compile(engine)
              conn.execute(stmt)
              conn.commit()
-         return
+         self._post_delete_hook()
+         return None

      def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
          logger.info(f"Deleting from {self.target_class_name} where: {where}")
@@ -87,6 +89,7 @@
              if deleted_rows_count == 0 and not missing_ok:
                  raise ValueError(f"No rows found for {where}")
              conn.commit()
+         self._post_delete_hook()
          return deleted_rows_count if deleted_rows_count > -1 else None

      def query_facets(
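
Both DuckDB delete paths now end with a call to _post_delete_hook(), which (per the collection.py hunk above) empties every derived index collection so it is rebuilt lazily on the next search. A minimal sketch of the hook pattern in isolation (illustrative classes, not the library's own):

    class BaseCollection:
        def __init__(self):
            self.index_rows = []  # stands in for a derived index collection

        def _post_delete_hook(self, **kwargs):
            self._post_modification_hook(**kwargs)

        def _post_modification_hook(self, **kwargs):
            # invalidate derived data; it is rebuilt lazily on the next search
            self.index_rows.clear()

    class DuckDBLikeCollection(BaseCollection):
        def delete_where(self, where=None):
            # ... perform the SQL DELETE against the backing table ...
            self._post_delete_hook()  # keep indexes consistent with the data
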

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/cli.py

@@ -76,6 +76,8 @@ class ContextSettings(BaseModel):
          if name is None:
              # if len(self.database.list_collections()) > 1:
              #     raise ValueError("Collection must be specified if there are multiple collections.")
+             if not self.database:
+                 return None
              if not self.database.list_collections():
                  return None
              name = list(self.database.list_collections())[0]
@@ -218,7 +220,10 @@ def insert(ctx, files, object, format):
  @click.option("--object", "-i", multiple=True, help="Input object as YAML")
  @click.pass_context
  def store(ctx, files, object, format):
-     """Store objects from files (JSON, YAML, TSV) into the specified collection."""
+     """Store objects from files (JSON, YAML, TSV) into the database.
+
+     Note: this is similar to insert, but a collection does not need to be specified.
+     """
      settings = ctx.obj["settings"]
      db = settings.database
      if not files and not object:
@@ -499,6 +504,7 @@ def describe(ctx, where, output_type, output, limit):
      "--predictor-type", "-t", default="sklearn", show_default=True, type=click.STRING, help="Type of predictor"
  )
  @click.option("--evaluation-count", "-n", type=click.INT, help="Number of examples to evaluate over")
+ @click.option("--evaluation-match-function", help="Name of function to use for matching objects in eval")
  @click.option("--query", "-q", type=click.STRING, help="query term")
  @click.pass_context
  def infer(
@@ -506,6 +512,7 @@
      inference_config_file,
      query,
      evaluation_count,
+     evaluation_match_function,
      training_test_data_split,
      predictor_type,
      target_attribute,
@@ -549,7 +556,10 @@
      else:
          query_obj = None
      collection = ctx.obj["settings"].collection
-     atts = collection.class_definition().attributes.keys()
+     if collection:
+         atts = collection.class_definition().attributes.keys()
+     else:
+         atts = []
      if feature_attributes:
          features = feature_attributes.split(",")
          features = [f.strip() for f in features]
@@ -575,7 +585,8 @@
      if training_test_data_split:
          config.train_test_split = training_test_data_split
      predictor = get_inference_engine(predictor_type, config=config)
-     predictor.load_and_split_data(collection)
+     if collection:
+         predictor.load_and_split_data(collection)
      predictor.initialize_model()
      if export_model:
          logger.info(f"Exporting model to {export_model} in {model_format}")
@@ -584,8 +595,14 @@
      if not export_model and not evaluation_count:
          raise ValueError("Query or evaluate must be specified if not exporting model")
      if evaluation_count:
+         if evaluation_match_function == "score_text_overlap":
+             match_function_fn = score_text_overlap
+         elif evaluation_match_function is not None:
+             raise ValueError(f"Unknown match function: {evaluation_match_function}")
+         else:
+             match_function_fn = None
          outcome = evaluate_predictor(
-             predictor, target_attributes, evaluation_count=evaluation_count, match_function=score_text_overlap
+             predictor, target_attributes, evaluation_count=evaluation_count, match_function=match_function_fn
          )
          print(f"Outcome: {outcome} // accuracy: {outcome.accuracy}")
      if query_obj:
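
The new --evaluation-match-function option is resolved by name: "score_text_overlap" maps to the function of the same name, any other value raises, and omitting it keeps the default strict comparison. The equivalent direct call, as a sketch (predictor stands in for any initialized InferenceEngine; the target attributes are hypothetical):

    from linkml_store.inference.evaluation import evaluate_predictor, score_text_overlap

    outcome = evaluate_predictor(
        predictor,                          # assumed: engine with test data already loaded
        ["capital", "continent"],           # hypothetical target attributes
        evaluation_count=10,
        match_function=score_text_overlap,  # None falls back to strict matching
    )
    print(f"Outcome: {outcome} // accuracy: {outcome.accuracy}")
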

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/implementations/llm_indexer.py

@@ -1,11 +1,13 @@
  import logging
  from pathlib import Path
- from typing import TYPE_CHECKING, List
+ from typing import TYPE_CHECKING, List, Optional

  import numpy as np
+ from tiktoken import encoding_for_model

  from linkml_store.api.config import CollectionConfig
  from linkml_store.index.indexer import INDEX_ITEM, Indexer
+ from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text

  if TYPE_CHECKING:
      import llm
@@ -29,6 +31,7 @@ class LLMIndexer(Indexer):
      cached_embeddings_database: str = None
      cached_embeddings_collection: str = None
      cache_queries: bool = False
+     truncation_method: Optional[str] = None

      @property
      def embedding_model(self):
@@ -62,6 +65,21 @@
          """
          logging.info(f"Converting {len(texts)} texts to vectors")
          model = self.embedding_model
+         token_limit = get_token_limit(model.model_id)
+         encoding = encoding_for_model("gpt-4o")
+
+         def truncate_text(text: str) -> str:
+             # split into chunks of 1000 characters each:
+             parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
+             return render_formatted_text(
+                 lambda x: "".join(x),
+                 parts,
+                 encoding,
+                 token_limit,
+             )
+
+         texts = [truncate_text(text) for text in texts]
+
          if self.cached_embeddings_database and (cache is None or cache or self.cache_queries):
              model_id = model.model_id
              if not model_id:
@@ -88,7 +106,7 @@
                  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
              else:
                  embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
-         texts = list(texts)
+
          embeddings = list([None] * len(texts))
          uncached_texts = []
          n = 0
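
The truncation added to texts_to_vectors splits each text into 1000-character chunks and hands them to render_formatted_text together with the model's token limit. Assuming render_formatted_text drops trailing chunks until the rendered text fits the budget (a plausible reading of its use here, not a documented contract), a standalone equivalent looks like:

    import tiktoken

    def truncate_to_token_limit(text: str, token_limit: int) -> str:
        # assumption: any fixed tiktoken encoding is adequate for measuring length here
        encoding = tiktoken.get_encoding("o200k_base")
        # split into chunks of 1000 characters each, mirroring the code above
        parts = [text[i : i + 1000] for i in range(0, len(text), 1000)]
        while parts and len(encoding.encode("".join(parts))) > token_limit:
            parts.pop()  # discard the last chunk and re-measure
        return "".join(parts)
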

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/index/indexer.py

@@ -36,6 +36,54 @@ def cosine_similarity(vector1, vector2) -> float:
  class Indexer(BaseModel):
      """
      An indexer operates on a collection in order to search for objects.
+
+     You should use a subclass of this; the appropriate subclass can be looked up dynamically:
+
+     >>> from linkml_store.index import get_indexer
+     >>> indexer = get_indexer("simple")
+
+     You can customize how objects are indexed by passing in a text template.
+     For example, if your collection has objects with "name" and "profession" attributes,
+     you can index them as "{name} {profession}".
+
+     >>> indexer = get_indexer("simple", text_template="{name} :: {profession}")
+
+     By default, python fstrings are assumed.
+
+     We can test this works using the :ref:`object_to_text` method (normally
+     you would never need to call this directly, but it's useful for testing):
+
+     >>> obj = {"name": "John", "profession": "doctor"}
+     >>> indexer.object_to_text(obj)
+     'John :: doctor'
+
+     You can also use Jinja2 templates; this gives more flexibility and logic,
+     e.g. conditional formatting:
+
+     >>> tmpl = "{{name}}{% if profession %} :: {{profession}}{% endif %}"
+     >>> indexer = get_indexer("simple", text_template=tmpl, text_template_syntax=TemplateSyntaxEnum.jinja2)
+     >>> indexer.object_to_text(obj)
+     'John :: doctor'
+     >>> indexer.object_to_text({"name": "John"})
+     'John'
+
+     You can also specify which attributes to index:
+
+     >>> indexer = get_indexer("simple", index_attributes=["name"])
+     >>> indexer.object_to_text(obj)
+     'John'
+
+     The purpose of an indexer is to translate a collection of objects into derived objects,
+     such as vectors, for purposes such as search. Unless you are implementing your own indexer, you
+     generally don't need to use the methods that return vectors, but we can examine their behavior
+     to get a sense of how they work.
+
+     >>> vectors = indexer.objects_to_vectors([{"name": "Aardvark"}, {"name": "Aardwolf"}, {"name": "Zesty"}])
+     >>> assert cosine_similarity(vectors[0], vectors[1]) > cosine_similarity(vectors[0], vectors[2])
+
+     Note you should consult the documentation for the specific indexer you are using for more details on
+     how text is converted to vectors.
+
      """

      name: Optional[str] = None
@@ -122,7 +170,9 @@
          self, query: str, vectors: List[Tuple[str, INDEX_ITEM]], limit: Optional[int] = None
      ) -> List[Tuple[float, Any]]:
          """
-         Search the index for a query string
+         Use the indexer to search against a database of vectors.
+
+         Note: this is a low-level method; typically you would use the :ref:`search` method on a :ref:`Collection`.

          :param query: The query string to search for
          :param vectors: A list of indexed items, where each item is a tuple of (id, vector)
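
The docstring above leans on cosine_similarity to show that similar names embed closer together. For reference, the standard definition it names (a numpy-based sketch; the module's own helper may differ in detail):

    import numpy as np

    def cosine_similarity(vector1, vector2) -> float:
        # dot product over the product of the norms
        v1 = np.asarray(vector1, dtype=float)
        v2 = np.asarray(vector2, dtype=float)
        return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

    assert cosine_similarity([1.0, 0.0], [1.0, 0.0]) == 1.0
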

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/evaluation.py

@@ -20,6 +20,8 @@ def score_match(target: Optional[Any], candidate: Optional[Any], match_function:
      1.0
      >>> score_match("a", "b")
      0.0
+     >>> score_match("abcd", "abcde")
+     0.0
      >>> score_match("a", None)
      0.0
      >>> score_match(None, "a")
@@ -52,7 +54,7 @@

      :param target:
      :param candidate:
-     :param match_function:
+     :param match_function: defaults to strict equality
      :return:
      """
      if target == candidate:
@@ -99,7 +101,8 @@
      :param predictor:
      :param target_attributes:
      :param feature_attributes:
-     :param evaluation_count:
+     :param evaluation_count: max iterations
+     :param match_function: function to use for matching
      :return:
      """
      n = 0
@@ -113,8 +116,8 @@
          else:
              test_obj = row
          result = predictor.derive(test_obj)
-         logger.info(f"Predicted: {result.predicted_object} Expected: {expected_obj}")
          tp += score_match(result.predicted_object, expected_obj, match_function)
+         logger.info(f"TP={tp} MF={match_function} Predicted: {result.predicted_object} Expected: {expected_obj}")
          n += 1
          if evaluation_count is not None and n >= evaluation_count:
              break
@@ -125,6 +128,9 @@ def score_text_overlap(str1: Any, str2: Any) -> float:
      """
      Compute the overlap score between two strings.

+     >>> score_text_overlap("abc", "bcde")
+     0.5
+
      :param str1:
      :param str2:
      :return:
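
The new doctest pins down score_text_overlap's behavior: "abc" and "bcde" share the block "bc", and 2 / max(3, 4) = 0.5. One implementation consistent with that doctest (an assumption about the actual scoring rule, which this diff does not show):

    from difflib import SequenceMatcher

    def score_text_overlap(str1, str2) -> float:
        # longest common block relative to the longer string's length
        s1, s2 = str(str1), str(str2)
        match = SequenceMatcher(None, s1, s2).find_longest_match(0, len(s1), 0, len(s2))
        return match.size / max(len(s1), len(s2))

    assert score_text_overlap("abc", "bcde") == 0.5  # "bc" is 2 of 4 characters
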

linkml_store-0.2.1/src/linkml_store/inference/implementations/rag_inference_engine.py

@@ -0,0 +1,232 @@
+ import json
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import ClassVar, List, Optional, TextIO, Union
+
+ import yaml
+ from llm import get_key
+ from pydantic import BaseModel
+
+ from linkml_store.api.collection import OBJECT, Collection
+ from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
+ from linkml_store.inference.inference_engine import InferenceEngine, ModelSerialization
+ from linkml_store.utils.object_utils import select_nested
+
+ logger = logging.getLogger(__name__)
+
+ SYSTEM_PROMPT = """
+ You are a {llm_config.role}, your task is to infer the YAML
+ object output given the YAML object input. I will provide you
+ with a collection of examples that will provide guidance both
+ on the desired structure of the response, as well as the kind
+ of content.
+
+ You should return ONLY valid YAML in your response.
+ """
+
+
+ class TrainedModel(BaseModel, extra="forbid"):
+     rag_collection_rows: List[OBJECT]
+     index_rows: List[OBJECT]
+     config: Optional[InferenceConfig] = None
+
+
+ @dataclass
+ class RAGInferenceEngine(InferenceEngine):
+     """
+     AI Retrieval Augmented Generation (RAG) based predictor.
+
+     >>> from linkml_store.api.client import Client
+     >>> from linkml_store.utils.format_utils import Format
+     >>> from linkml_store.inference.inference_config import LLMConfig
+     >>> client = Client()
+     >>> db = client.attach_database("duckdb", alias="test")
+     >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
+     >>> db.list_collection_names()
+     ['countries']
+     >>> collection = db.get_collection("countries")
+     >>> features = ["name"]
+     >>> targets = ["code", "capital", "continent", "languages"]
+     >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
+     >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
+     >>> ie = RAGInferenceEngine(config=config)
+     >>> ie.load_and_split_data(collection)
+     >>> ie.initialize_model()
+     >>> prediction = ie.derive({"name": "Uruguay"})
+     >>> prediction.predicted_object
+     {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
+
+     The "model" can be saved for later use:
+
+     >>> ie.export_model("tests/output/countries.rag_model.json")
+
+     Note in this case the model is not the underlying LLM, but the "RAG model", which is the vectorized
+     representation of the training set objects.
+     """
+
+     _model: "llm.Model" = None  # noqa: F821
+
+     rag_collection: Collection = None
+
+     PERSIST_COLS: ClassVar[List[str]] = [
+         "config",
+     ]
+
+     def __post_init__(self):
+         if not self.config:
+             self.config = InferenceConfig()
+         if not self.config.llm_config:
+             self.config.llm_config = LLMConfig()
+
+     @property
+     def model(self) -> "llm.Model":  # noqa: F821
+         import llm
+
+         if self._model is None:
+             self._model = llm.get_model(self.config.llm_config.model_name)
+             if self._model.needs_key:
+                 key = get_key(None, key_alias=self._model.needs_key)
+                 self._model.key = key
+
+         return self._model
+
+     def initialize_model(self, **kwargs):
+         logger.info(f"Initializing model {self.model}")
+         if self.training_data:
+             rag_collection = self.training_data.collection
+             rag_collection.attach_indexer("llm", auto_index=False)
+             self.rag_collection = rag_collection
+
+     def object_to_text(self, object: OBJECT) -> str:
+         return yaml.dump(object)
+
+     def derive(self, object: OBJECT) -> Optional[Inference]:
+         import llm
+         from tiktoken import encoding_for_model
+
+         from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
+
+         model: llm.Model = self.model
+         model_name = self.config.llm_config.model_name
+         feature_attributes = self.config.feature_attributes
+         target_attributes = self.config.target_attributes
+         num_examples = self.config.llm_config.number_of_few_shot_examples or 5
+         query_text = self.object_to_text(object)
+         if not self.rag_collection:
+             # TODO: zero-shot mode
+             examples = []
+         else:
+             if not self.rag_collection.indexers:
+                 raise ValueError("RAG collection must have an indexer attached")
+             rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
+             examples = rs.rows
+             if not examples:
+                 raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
+         prompt_clauses = []
+         query_obj = select_nested(object, feature_attributes)
+         query_text = self.object_to_text(query_obj)
+         for example in examples:
+             input_obj = select_nested(example, feature_attributes)
+             input_obj_text = self.object_to_text(input_obj)
+             if input_obj_text == query_text:
+                 raise ValueError(
+                     f"Query object {query_text} is the same as example object {input_obj_text}\n"
+                     "This indicates possible test data leakage.\n"
+                     "TODO: allow an option that allows user to treat this as a basic lookup\n"
+                 )
+             output_obj = select_nested(example, target_attributes)
+             prompt_clause = (
+                 "---\nExample:\n" f"## INPUT:\n{input_obj_text}\n" f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
+             )
+             prompt_clauses.append(prompt_clause)
+
+         prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
+         system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
+
+         def make_text(texts):
+             return "\n".join(prompt_clauses) + prompt_end
+
+         try:
+             encoding = encoding_for_model(model_name)
+         except KeyError:
+             encoding = encoding_for_model("gpt-4")
+         token_limit = get_token_limit(model_name)
+         prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
+         logger.info(f"Prompt: {prompt}")
+         response = model.prompt(prompt, system_prompt)
+         yaml_str = response.text()
+         logger.info(f"Response: {yaml_str}")
+         return Inference(predicted_object=self._parse_yaml_payload(yaml_str))
+
+     def _parse_yaml_payload(self, yaml_str: str, strict=False) -> Optional[OBJECT]:
+         if "```" in yaml_str:
+             yaml_str = yaml_str.split("```")[1].strip()
+             if yaml_str.startswith("yaml"):
+                 yaml_str = yaml_str[4:].strip()
+         try:
+             return yaml.safe_load(yaml_str)
+         except Exception as e:
+             if strict:
+                 raise e
+             logger.error(f"Error parsing YAML: {yaml_str}\n{e}")
+             return None
+
+     def export_model(
+         self, output: Optional[Union[str, Path, TextIO]], model_serialization: ModelSerialization = None, **kwargs
+     ):
+         self.save_model(output)
+
+     def save_model(self, output: Union[str, Path]) -> None:
+         """
+         Save the trained model and related data to a file.
+
+         :param output: Path to save the model
+         """
+
+         # trigger index
+         _qr = self.rag_collection.search("*", limit=1)
+         assert len(_qr.ranked_rows) > 0
+
+         rows = self.rag_collection.find(limit=-1).rows
+
+         indexers = self.rag_collection.indexers
+         assert len(indexers) == 1
+         ix = self.rag_collection.indexers["llm"]
+         ix_coll = self.rag_collection.parent.get_collection(self.rag_collection.get_index_collection_name(ix))
+
+         ix_rows = ix_coll.find(limit=-1).rows
+         assert len(ix_rows) > 0
+         tm = TrainedModel(rag_collection_rows=rows, index_rows=ix_rows, config=self.config)
+         with open(output, "w", encoding="utf-8") as f:
+             json.dump(tm.model_dump(), f)
+
+     @classmethod
+     def load_model(cls, file_path: Union[str, Path]) -> "RAGInferenceEngine":
+         """
+         Load a trained model and related data from a file.
+
+         :param file_path: Path to the saved model
+         :return: RAGInferenceEngine instance with loaded model
+         """
+         with open(file_path, "r", encoding="utf-8") as f:
+             model_data = json.load(f)
+         tm = TrainedModel(**model_data)
+         from linkml_store.api import Client
+
+         client = Client()
+         db = client.attach_database("duckdb", alias="training")
+         db.store({"data": tm.rag_collection_rows})
+         collection = db.get_collection("data")
+         ix = collection.attach_indexer("llm", auto_index=False)
+         assert ix.name
+         ix_coll_name = collection.get_index_collection_name(ix)
+         assert ix_coll_name
+         ix_coll = db.get_collection(ix_coll_name, create_if_not_exists=True)
+         ix_coll.insert(tm.index_rows)
+         ie = cls(config=tm.config)
+         ie.rag_collection = collection
+         return ie
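
A notable behavioral change over the removed 0.2.0 implementation (shown at the end of this diff): _parse_yaml_payload strips a Markdown code fence before parsing, and with strict=False it returns None on malformed YAML instead of raising. On a typical fenced LLM reply (assuming the engine can be constructed from a bare config, as load_model does):

    from linkml_store.inference.implementations.rag_inference_engine import RAGInferenceEngine
    from linkml_store.inference.inference_config import InferenceConfig

    engine = RAGInferenceEngine(config=InferenceConfig())
    reply = "```yaml\ncapital: Montevideo\ncode: UY\n```"
    # the fence and the leading "yaml" tag are stripped, then yaml.safe_load runs
    assert engine._parse_yaml_payload(reply) == {"capital": "Montevideo", "code": "UY"}
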

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/implementations/sklearn_inference_engine.py

@@ -153,7 +153,7 @@ class SklearnInferenceEngine(InferenceEngine):
          y = y_encoder.fit_transform(y.values.ravel())  # Convert to 1D numpy array
          self.transformed_targets = y_encoder.classes_

-         logger.info(f"Fitting model with features: {X.columns}")
+         # print(f"Fitting model with features: {X.columns}")
          clf = DecisionTreeClassifier(random_state=42)
          clf.fit(X, y)
          self.classifier = clf

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/inference_config.py

@@ -35,6 +35,7 @@ class InferenceConfig(BaseModel, extra="forbid"):
      feature_attributes: Optional[List[str]] = None
      train_test_split: Optional[Tuple[float, float]] = None
      llm_config: Optional[LLMConfig] = None
+     random_seed: Optional[int] = None

      @classmethod
      def from_file(cls, file_path: str, format: Optional[Format] = None) -> "InferenceConfig":

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/inference/inference_engine.py

@@ -29,6 +29,7 @@ class ModelSerialization(str, Enum):
      PNG = "png"
      LINKML_EXPRESSION = "linkml_expression"
      RULE_BASED = "rulebased"
+     RAG_INDEX = "rag_index"

      @classmethod
      def from_filepath(cls, file_path: str) -> Optional["ModelSerialization"]:
@@ -58,7 +59,7 @@


  class CollectionSlice(BaseModel):
-     model_config = ConfigDict(arbitrary_types_allowed=True)
+     model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

      name: Optional[str] = None
      base_collection: Optional[Collection] = None
@@ -69,17 +70,26 @@

      @property
      def collection(self) -> Collection:
+         if not self._collection and not self.indices:
+             return self.base_collection
          if not self._collection:
              rows = self.base_collection.find({}, limit=-1).rows
-             # subset based on indices
              subset = [rows[i] for i in self.indices]
              db = self.base_collection.parent
-             subset_name = f"{self.base_collection.alias}__rag_{self.name}"
+             subset_name = self.slice_alias
              subset_collection = db.get_collection(subset_name, create_if_not_exists=True)
+             # ensure the collection has the same schema type as the base collection;
+             # this ensures that column/attribute types are preserved
+             subset_collection.metadata.type = self.base_collection.target_class_name
+             subset_collection.delete_where({})
              subset_collection.insert(subset)
              self._collection = subset_collection
          return self._collection

+     @property
+     def slice_alias(self) -> str:
+         return f"{self.base_collection.alias}__rag_{self.name}"
+
      def as_dataframe(self, flattened=False) -> pd.DataFrame:
          """
          Return the slice of the collection as a dataframe.
@@ -113,31 +123,28 @@ class InferenceEngine(ABC):

          :param collection:
          :param split:
+         :param randomize:
          :return:
          """
+         local_random = random.Random(self.config.random_seed) if self.config.random_seed else random.Random()
          split = split or self.config.train_test_split
          if not split:
              split = (0.7, 0.3)
+         if split[0] == 1.0:
+             self.training_data = CollectionSlice(name="train", base_collection=collection, indices=None)
+             self.testing_data = None
+             return
          logger.info(f"Loading and splitting data from collection {collection.alias}")
          size = collection.size()
          indices = range(size)
          if randomize:
-             train_indices = random.sample(indices, int(size * split[0]))
+             train_indices = local_random.sample(indices, int(size * split[0]))
              test_indices = set(indices) - set(train_indices)
          else:
              train_indices = indices[: int(size * split[0])]
              test_indices = indices[int(size * split[0]) :]
          self.training_data = CollectionSlice(name="train", base_collection=collection, indices=train_indices)
          self.testing_data = CollectionSlice(name="test", base_collection=collection, indices=test_indices)
-         # all_data = collection.find({}, limit=size).rows
-         # all_data_df = nested_objects_to_dataframe(all_data)
-         # all_data_df = collection.find({}, limit=size).rows_dataframe
-         # randomize/shuffle order of rows in dataframe
-         # all_data_df = all_data_df.sample(frac=1).reset_index(drop=True)
-         # self.training_data = CollectionSlice(dataframe=all_data_df[: int(size * split[0])])
-         # self.testing_data = CollectionSlice(dataframe=all_data_df[int(size * split[0]) : size])
-         # self.training_data = CollectionSlice(base_collection=collection, slice=(0, int(size * split[0])))
-         # self.testing_data = CollectionSlice(base_collection=collection, slice=(int(size * split[0]), size))

      def initialize_model(self, **kwargs):
          """

{linkml_store-0.2.0 → linkml_store-0.2.1}/src/linkml_store/utils/llm_utils.py

@@ -20,6 +20,7 @@ MODEL_TOKEN_MAPPING = {
      "gpt-3.5-turbo-instruct": 4096,
      "text-ada-001": 2049,
      "ada": 2049,
+     "ada-002": 8192,
      "text-babbage-001": 2040,
      "babbage": 2049,
      "text-curie-001": 2049,

linkml_store-0.2.0/src/linkml_store/inference/implementations/rag_inference_engine.py

@@ -1,145 +0,0 @@
- import logging
- from dataclasses import dataclass
- from typing import Any, Optional
-
- import yaml
- from llm import get_key
-
- from linkml_store.api.collection import OBJECT, Collection
- from linkml_store.inference.inference_config import Inference, InferenceConfig, LLMConfig
- from linkml_store.inference.inference_engine import InferenceEngine
- from linkml_store.utils.object_utils import select_nested
-
- logger = logging.getLogger(__name__)
-
- SYSTEM_PROMPT = """
- You are a {llm_config.role}, your task is to infer the YAML
- object output given the YAML object input. I will provide you
- with a collection of examples that will provide guidance both
- on the desired structure of the response, as well as the kind
- of content.
-
- You should return ONLY valid YAML in your response.
- """
-
-
- # def select_object(obj: OBJECT, key_paths: List[str]) -> OBJECT:
- #     return {k: obj.get(k, None) for k in keys}
- #     return {k: object_path_get(obj, k, None) for k in key_paths}
-
-
- @dataclass
- class RAGInferenceEngine(InferenceEngine):
-     """
-     AI Retrieval Augmented Generation (RAG) based predictor.
-
-     >>> from linkml_store.api.client import Client
-     >>> from linkml_store.utils.format_utils import Format
-     >>> from linkml_store.inference.inference_config import LLMConfig
-     >>> client = Client()
-     >>> db = client.attach_database("duckdb", alias="test")
-     >>> db.import_database("tests/input/countries/countries.jsonl", Format.JSONL, collection_name="countries")
-     >>> db.list_collection_names()
-     ['countries']
-     >>> collection = db.get_collection("countries")
-     >>> features = ["name"]
-     >>> targets = ["code", "capital", "continent", "languages"]
-     >>> llm_config = LLMConfig(model_name="gpt-4o-mini",)
-     >>> config = InferenceConfig(target_attributes=targets, feature_attributes=features, llm_config=llm_config)
-     >>> ie = RAGInferenceEngine(config=config)
-     >>> ie.load_and_split_data(collection)
-     >>> ie.initialize_model()
-     >>> prediction = ie.derive({"name": "Uruguay"})
-     >>> prediction.predicted_object
-     {'capital': 'Montevideo', 'code': 'UY', 'continent': 'South America', 'languages': ['Spanish']}
-
-     """
-
-     classifier: Any = None
-     encoders: dict = None
-     _model: "llm.Model" = None  # noqa: F821
-
-     rag_collection: Collection = None
-
-     def __post_init__(self):
-         if not self.config:
-             self.config = InferenceConfig()
-         if not self.config.llm_config:
-             self.config.llm_config = LLMConfig()
-
-     @property
-     def model(self) -> "llm.Model":  # noqa: F821
-         import llm
-
-         if self._model is None:
-             self._model = llm.get_model(self.config.llm_config.model_name)
-             if self._model.needs_key:
-                 key = get_key(None, key_alias=self._model.needs_key)
-                 self._model.key = key
-
-         return self._model
-
-     def initialize_model(self, **kwargs):
-         rag_collection = self.training_data.collection
-         rag_collection.attach_indexer("llm", auto_index=False)
-         self.rag_collection = rag_collection
-
-     def object_to_text(self, object: OBJECT) -> str:
-         return yaml.dump(object)
-
-     def derive(self, object: OBJECT) -> Optional[Inference]:
-         import llm
-         from tiktoken import encoding_for_model
-
-         from linkml_store.utils.llm_utils import get_token_limit, render_formatted_text
-
-         model: llm.Model = self.model
-         model_name = self.config.llm_config.model_name
-         feature_attributes = self.config.feature_attributes
-         target_attributes = self.config.target_attributes
-         num_examples = self.config.llm_config.number_of_few_shot_examples or 5
-         query_text = self.object_to_text(object)
-         if not self.rag_collection.indexers:
-             raise ValueError("RAG collection must have an indexer attached")
-         rs = self.rag_collection.search(query_text, limit=num_examples, index_name="llm")
-         examples = rs.rows
-         if not examples:
-             raise ValueError(f"No examples found for {query_text}; size = {self.rag_collection.size()}")
-         prompt_clauses = []
-         for example in examples:
-             # input_obj = {k: example.get(k, None) for k in feature_attributes}
-             input_obj = select_nested(example, feature_attributes)
-             # output_obj = {k: example.get(k, None) for k in target_attributes}
-             output_obj = select_nested(example, target_attributes)
-             prompt_clause = (
-                 "---\nExample:\n"
-                 f"## INPUT:\n{self.object_to_text(input_obj)}\n"
-                 f"## OUTPUT:\n{self.object_to_text(output_obj)}\n"
-             )
-             prompt_clauses.append(prompt_clause)
-         # query_obj = {k: object.get(k, None) for k in feature_attributes}
-         query_obj = select_nested(object, feature_attributes)
-         query_text = self.object_to_text(query_obj)
-         prompt_end = "---\nQuery:\n" f"## INPUT:\n{query_text}\n" "## OUTPUT:\n"
-         system_prompt = SYSTEM_PROMPT.format(llm_config=self.config.llm_config)
-
-         def make_text(texts):
-             return "\n".join(prompt_clauses) + prompt_end
-
-         try:
-             encoding = encoding_for_model(model_name)
-         except KeyError:
-             encoding = encoding_for_model("gpt-4")
-         token_limit = get_token_limit(model_name)
-         prompt = render_formatted_text(make_text, prompt_clauses, encoding, token_limit)
-         logger.info(f"Prompt: {prompt}")
-         response = model.prompt(prompt, system_prompt)
-         yaml_str = response.text()
-         logger.info(f"Response: {yaml_str}")
-         try:
-             predicted_object = yaml.safe_load(yaml_str)
-             return Inference(predicted_object=predicted_object)
-         except yaml.parser.ParserError as e:
-             logger.error(f"Error parsing response: {yaml_str}\n{e}")
-             return None