linkml-store 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
Potentially problematic release: this version of linkml-store has been flagged as possibly problematic.
- linkml_store/api/client.py +30 -5
- linkml_store/api/collection.py +175 -21
- linkml_store/api/config.py +6 -2
- linkml_store/api/database.py +230 -18
- linkml_store/api/stores/chromadb/__init__.py +5 -1
- linkml_store/api/stores/duckdb/__init__.py +9 -0
- linkml_store/api/stores/duckdb/duckdb_collection.py +6 -4
- linkml_store/api/stores/duckdb/duckdb_database.py +19 -5
- linkml_store/api/stores/duckdb/mappings.py +1 -0
- linkml_store/api/stores/filesystem/__init__.py +16 -0
- linkml_store/api/stores/filesystem/filesystem_collection.py +142 -0
- linkml_store/api/stores/filesystem/filesystem_database.py +36 -0
- linkml_store/api/stores/hdf5/__init__.py +7 -0
- linkml_store/api/stores/mongodb/__init__.py +25 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +21 -6
- linkml_store/cli.py +64 -10
- linkml_store/index/__init__.py +6 -2
- linkml_store/index/implementations/llm_indexer.py +83 -5
- linkml_store/index/implementations/simple_indexer.py +2 -2
- linkml_store/index/indexer.py +32 -8
- linkml_store/utils/format_utils.py +52 -2
- {linkml_store-0.1.7.dist-info → linkml_store-0.1.8.dist-info}/METADATA +4 -1
- linkml_store-0.1.8.dist-info/RECORD +45 -0
- linkml_store-0.1.7.dist-info/RECORD +0 -42
- {linkml_store-0.1.7.dist-info → linkml_store-0.1.8.dist-info}/LICENSE +0 -0
- {linkml_store-0.1.7.dist-info → linkml_store-0.1.8.dist-info}/WHEEL +0 -0
- {linkml_store-0.1.7.dist-info → linkml_store-0.1.8.dist-info}/entry_points.txt +0 -0
linkml_store/api/stores/filesystem/filesystem_collection.py
ADDED

```diff
@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+import sqlalchemy as sqla
+from linkml_runtime.linkml_model import ClassDefinition, SlotDefinition
+from sqlalchemy import Column, Table, delete, insert, inspect, text
+from sqlalchemy.sql.ddl import CreateTable
+
+from linkml_store.api import Collection
+from linkml_store.api.collection import DEFAULT_FACET_LIMIT, OBJECT
+from linkml_store.api.queries import Query
+from linkml_store.api.stores.duckdb.mappings import TMAP
+from linkml_store.utils.sql_utils import facet_count_sql
+
+logger = logging.getLogger(__name__)
+
+
+class FileSystemCollection(Collection):
+    _table_created: bool = None
+
+    def insert(self, objs: Union[OBJECT, List[OBJECT]], **kwargs):
+        if not isinstance(objs, list):
+            objs = [objs]
+        if not objs:
+            return
+        cd = self.class_definition()
+        if not cd:
+            cd = self.induce_class_definition_from_objects(objs)
+        self._create_table(cd)
+        table = self._sqla_table(cd)
+        logger.info(f"Inserting into: {self.alias} // T={table.name}")
+        engine = self.parent.engine
+        col_names = [c.name for c in table.columns]
+        objs = [{k: obj.get(k, None) for k in col_names} for obj in objs]
+        with engine.connect() as conn:
+            with conn.begin():
+                conn.execute(insert(table), objs)
+            conn.commit()
+
+    def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
+        if not isinstance(objs, list):
+            objs = [objs]
+        cd = self.class_definition()
+        if not cd:
+            cd = self.induce_class_definition_from_objects(objs)
+        table = self._sqla_table(cd)
+        engine = self.parent.engine
+        with engine.connect() as conn:
+            for obj in objs:
+                conditions = [table.c[k] == v for k, v in obj.items() if k in cd.attributes]
+                stmt = delete(table).where(*conditions)
+                stmt = stmt.compile(engine)
+                conn.execute(stmt)
+            conn.commit()
+        return
+
+    def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
+        logger.info(f"Deleting from {self.target_class_name} where: {where}")
+        if where is None:
+            where = {}
+        cd = self.class_definition()
+        if not cd:
+            logger.info(f"No class definition found for {self.target_class_name}, assuming not prepopulated")
+            return 0
+        table = self._sqla_table(cd)
+        engine = self.parent.engine
+        inspector = inspect(engine)
+        table_exists = table.name in inspector.get_table_names()
+        if not table_exists:
+            logger.info(f"Table {table.name} does not exist, assuming no data")
+            return 0
+        with engine.connect() as conn:
+            conditions = [table.c[k] == v for k, v in where.items()]
+            stmt = delete(table).where(*conditions)
+            stmt = stmt.compile(engine)
+            result = conn.execute(stmt)
+            deleted_rows_count = result.rowcount
+            if deleted_rows_count == 0 and not missing_ok:
+                raise ValueError(f"No rows found for {where}")
+            conn.commit()
+        return deleted_rows_count if deleted_rows_count > -1 else None
+
+    def query_facets(
+        self, where: Dict = None, facet_columns: List[str] = None, facet_limit=DEFAULT_FACET_LIMIT, **kwargs
+    ) -> Dict[str, Dict[str, int]]:
+        results = {}
+        cd = self.class_definition()
+        with self.parent.engine.connect() as conn:
+            if not facet_columns:
+                facet_columns = list(self.class_definition().attributes.keys())
+            for col in facet_columns:
+                logger.debug(f"Faceting on {col}")
+                if isinstance(col, tuple):
+                    sd = SlotDefinition(name="PLACEHOLDER")
+                else:
+                    sd = cd.attributes[col]
+                facet_query = self._create_query(where_clause=where)
+                facet_query_str = facet_count_sql(facet_query, col, multivalued=sd.multivalued)
+                logger.debug(f"Facet query: {facet_query_str}")
+                rows = list(conn.execute(text(facet_query_str)))
+                results[col] = rows
+        return results
+
+    def _sqla_table(self, cd: ClassDefinition) -> Table:
+        schema_view = self.parent.schema_view
+        metadata_obj = sqla.MetaData()
+        cols = []
+        for att in schema_view.class_induced_slots(cd.name):
+            typ = TMAP.get(att.range, sqla.String)
+            if att.inlined:
+                typ = sqla.JSON
+            if att.multivalued:
+                typ = sqla.ARRAY(typ, dimensions=1)
+            if att.array:
+                typ = sqla.ARRAY(typ, dimensions=1)
+            col = Column(att.name, typ)
+            cols.append(col)
+        t = Table(self.alias, metadata_obj, *cols)
+        return t
+
+    def _create_table(self, cd: ClassDefinition):
+        if self._table_created or self.metadata.is_prepopulated:
+            logger.info(f"Already have table for: {cd.name}")
+            return
+        query = Query(
+            from_table="information_schema.tables", where_clause={"table_type": "BASE TABLE", "table_name": self.alias}
+        )
+        qr = self.parent.query(query)
+        if qr.num_rows > 0:
+            logger.info(f"Table already exists for {cd.name}")
+            self._table_created = True
+            self.metadata.is_prepopulated = True
+            return
+        logger.info(f"Creating table for {cd.name}")
+        t = self._sqla_table(cd)
+        ct = CreateTable(t)
+        ddl = str(ct.compile(self.parent.engine))
+        with self.parent.engine.connect() as conn:
+            conn.execute(text(ddl))
+            conn.commit()
+        self._table_created = True
+        self.metadata.is_prepopulated = True
```
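The insert path above normalizes each incoming object to the table's column set before a single bulk insert, so extra keys are dropped and missing keys become NULL. A minimal standalone sketch of the same pattern using plain SQLAlchemy, with an in-memory SQLite engine and a hypothetical `person` table standing in for the induced LinkML class:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
# Hypothetical two-column table standing in for the induced class definition.
person = Table("person", metadata, Column("id", Integer), Column("name", String))
metadata.create_all(engine)

objs = [{"id": 1, "name": "a", "extra": "dropped"}, {"id": 2}]
col_names = [c.name for c in person.columns]
# Same normalization as FileSystemCollection.insert: restrict each object to
# the known columns, defaulting absent keys to None.
rows = [{k: obj.get(k, None) for k in col_names} for obj in objs]

with engine.connect() as conn:
    with conn.begin():
        # executemany-style bulk insert of the normalized rows
        conn.execute(insert(person), rows)
```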
linkml_store/api/stores/filesystem/filesystem_database.py
ADDED

```diff
@@ -0,0 +1,36 @@
+import logging
+from typing import Optional
+
+from linkml_store.api import Collection, Database
+from linkml_store.api.config import CollectionConfig
+from linkml_store.api.stores.duckdb import DuckDBDatabase
+from linkml_store.api.stores.filesystem.filesystem_collection import FileSystemCollection
+
+logger = logging.getLogger(__name__)
+
+
+class FileSystemDatabase(Database):
+    collection_class = FileSystemCollection
+    wrapped_database: Database = None
+
+    def __init__(self, handle: Optional[str] = None, recreate_if_exists: bool = False, **kwargs):
+        self.wrapped_database = DuckDBDatabase("duckdb:///:memory:")
+        super().__init__(handle=handle, **kwargs)
+
+    def commit(self, **kwargs):
+        # TODO: sync
+        pass
+
+    def close(self, **kwargs):
+        self.wrapped_database.close()
+
+    def create_collection(
+        self,
+        name: str,
+        alias: Optional[str] = None,
+        metadata: Optional[CollectionConfig] = None,
+        recreate_if_exists=False,
+        **kwargs,
+    ) -> Collection:
+        wd = self.wrapped_database
+        wd.create_collection()
```
linkml_store/api/stores/mongodb/__init__.py
ADDED

```diff
@@ -0,0 +1,25 @@
+"""
+Adapter for MongoDB document store.
+
+Handles have the form: ``mongodb://<host>:<port>/<database>``
+
+To use this, you must have the `pymongo` extra installed.
+
+.. code-block:: bash
+
+    pip install linkml-store[mongodb]
+
+or
+
+.. code-block:: bash
+
+    pip install linkml-store[all]
+"""
+
+from linkml_store.api.stores.mongodb.mongodb_collection import MongoDBCollection
+from linkml_store.api.stores.mongodb.mongodb_database import MongoDBDatabase
+
+__all__ = [
+    "MongoDBCollection",
+    "MongoDBDatabase",
+]
```
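A hedged usage sketch of the handle form documented above, going through the generic client API rather than this module directly. It assumes a mongod reachable on localhost:27017; the database, collection, and object names are hypothetical:

```python
from linkml_store import Client

client = Client()
# Handle form: mongodb://<host>:<port>/<database>
db = client.get_database("mongodb://localhost:27017/demo")
coll = db.create_collection("Person")
coll.insert({"name": "Alice", "aliases": ["A", "Al"]})
result = coll.find({"name": "Alice"})
print(result.num_rows)
```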
linkml_store/api/stores/mongodb/mongodb_collection.py
CHANGED

```diff
@@ -13,6 +13,14 @@ logger = logging.getLogger(__name__)
 
 
 class MongoDBCollection(Collection):
+    """
+    Adapter for collections in a MongoDB database.
+
+    .. note::
+
+        You should not use or manipulate this class directly.
+        Instead, use the general :class:`linkml_store.api.Collection`
+    """
 
     @property
     def mongo_collection(self) -> MongoCollection:
@@ -62,24 +70,31 @@ class MongoDBCollection(Collection):
             if isinstance(col, tuple):
                 sd = SlotDefinition(name="PLACEHOLDER")
             else:
-
-
-
+                if col in cd.attributes:
+                    sd = cd.attributes[col]
+                else:
+                    logger.info(f"No schema metadata for {col}")
+                    sd = SlotDefinition(name=col)
+            group = {"$group": {"_id": f"${col}", "count": {"$sum": 1}}}
+            if isinstance(col, tuple):
+                q = {k.replace(".", ""): f"${k}" for k in col}
+                group["$group"]["_id"] = q
+            if sd and sd.multivalued:
                 facet_pipeline = [
                     {"$match": where} if where else {"$match": {}},
                     {"$unwind": f"${col}"},
-
+                    group,
                     {"$sort": {"count": -1}},
                     {"$limit": facet_limit},
                 ]
             else:
                 facet_pipeline = [
                     {"$match": where} if where else {"$match": {}},
-
+                    group,
                     {"$sort": {"count": -1}},
                     {"$limit": facet_limit},
                 ]
-
+            logger.info(f"Facet pipeline: {facet_pipeline}")
             facet_results = list(self.mongo_collection.aggregate(facet_pipeline))
             results[col] = [(result["_id"], result["count"]) for result in facet_results]
 
```
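The new faceting logic builds one `$group` stage per facet column; for a tuple-valued column the group `_id` becomes a document keyed by each member path. A self-contained sketch of the pipeline shape (values illustrative, no server needed):

```python
def build_group(col):
    # Scalar column: group on the "$col" path.
    group = {"$group": {"_id": f"${col}", "count": {"$sum": 1}}}
    if isinstance(col, tuple):
        # Compound facet: one key per member path, dots stripped from key names.
        group["$group"]["_id"] = {k.replace(".", ""): f"${k}" for k in col}
    return group

where = {"status": "active"}
facet_limit = 20
pipeline = [
    {"$match": where} if where else {"$match": {}},
    build_group(("country", "city")),
    {"$sort": {"count": -1}},
    {"$limit": facet_limit},
]
print(pipeline)
```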
linkml_store/cli.py
CHANGED

```diff
@@ -11,12 +11,19 @@ from pydantic import BaseModel
 from linkml_store import Client
 from linkml_store.api import Collection, Database
 from linkml_store.api.queries import Query
+from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
-from linkml_store.utils.format_utils import Format, load_objects, render_output
+from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output
 from linkml_store.utils.object_utils import object_path_update
 
-index_type_option = click.option(
+index_type_option = click.option(
+    "--index-type",
+    "-t",
+    default="simple",
+    show_default=True,
+    help="Type of index to create. Values: simple, llm",
+)
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +77,9 @@ class ContextSettings(BaseModel):
 format_choice = click.Choice([f.value for f in Format])
 
 
+include_internal_option = click.option("--include-internal/--no-include-internal", default=False, show_default=True)
+
+
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
@@ -89,6 +99,15 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if not stacktrace:
         sys.tracebacklimit = 0
     logger = logging.getLogger()
+    # Set handler for the root logger to output to the console
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+
+    # Clear existing handlers to avoid duplicate messages if function runs multiple times
+    logger.handlers = []
+
+    # Add the newly created console handler to the logger
+    logger.addHandler(console_handler)
     if verbose >= 2:
         logger.setLevel(logging.DEBUG)
     elif verbose == 1:
@@ -193,6 +212,35 @@ def store(ctx, files, object, format):
     click.echo(f"Inserted {len(objects)} objects from {object_str} into collection '{db.name}'.")
 
 
+@cli.command(name="import")
+@click.argument("files", type=click.Path(exists=True), nargs=-1)
+@click.option("--format", "-f", help="Input format")
+@click.pass_context
+def import_database(ctx, files, format):
+    """Imports a database from a dump."""
+    settings = ctx.obj["settings"]
+    db = settings.database
+    if not files and not object:
+        files = ["-"]
+    for file_path in files:
+        db.import_database(file_path, source_format=format)
+
+
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", required=True, type=click.Path(), help="Output file path")
+@click.pass_context
+def export(ctx, output_type, output):
+    """Exports a database to a dump."""
+    settings = ctx.obj["settings"]
+    db = settings.database
+    if output_type is None:
+        output_type = guess_format(output)
+    if output_type is None:
+        raise ValueError(f"Output format must be specified can't be inferred from {output}.")
+    db.export_database(output, target_format=output_type)
+
+
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
@@ -216,9 +264,10 @@ def query(ctx, where, limit, output_type, output):
 
 @cli.command()
 @click.pass_context
-
+@include_internal_option
+def list_collections(ctx, **kwargs):
     db = ctx.obj["settings"].database
-    for collection in db.list_collections():
+    for collection in db.list_collections(**kwargs):
         click.echo(collection.name)
         click.echo(render_output(collection.metadata))
 
@@ -254,7 +303,7 @@ def fq(ctx, where, limit, columns, output_type, output):
 
     def _untuple(key):
         if isinstance(key, tuple):
-            return "+".join(key)
+            return "+".join([str(x) for x in key])
        return key
 
     count_dict = {}
@@ -279,8 +328,10 @@ def _get_index(index_type=None, **kwargs) -> Indexer:
 
 @cli.command()
 @index_type_option
+@click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
+@click.option("--text-template", "-T", help="Template for text embeddings")
 @click.pass_context
-def index(ctx, index_type):
+def index(ctx, index_type, **kwargs):
     """
     Create an index over a collection.
 
@@ -289,7 +340,7 @@ def index(ctx, index_type):
     :return:
     """
     collection = ctx.obj["settings"].collection
-    ix =
+    ix = get_indexer(index_type, **kwargs)
     collection.attach_indexer(ix)
 
 
@@ -322,14 +373,17 @@ def schema(ctx, output_type, output):
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of search results")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option(
+    "--auto-index/--no-auto-index", default=False, show_default=True, help="Automatically index the collection"
+)
 @index_type_option
 @click.pass_context
-def search(ctx, search_term, where, limit, index_type, output_type, output):
+def search(ctx, search_term, where, limit, index_type, output_type, output, auto_index):
     """Search objects in the specified collection."""
     collection = ctx.obj["settings"].collection
-    ix =
+    ix = get_indexer(index_type)
     logger.info(f"Attaching index to collection {collection.name}: {ix.model_dump()}")
-    collection.attach_indexer(ix, auto_index=
+    collection.attach_indexer(ix, auto_index=auto_index)
     result = collection.search(search_term, where=where, limit=limit)
     output_data = render_output([{"score": row[0], **row[1]} for row in result.ranked_rows], output_type)
     if output:
```
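A hedged sketch of driving the new `import` and `export` commands programmatically through click's test runner. The DuckDB handle and file names are hypothetical, and passing a handle via `-d` is an assumption based on the group's `--database` option:

```python
from click.testing import CliRunner

from linkml_store.cli import cli

runner = CliRunner()
# Load a dump into a DuckDB-backed database (hypothetical paths).
runner.invoke(cli, ["-d", "duckdb:///tmp/demo.db", "import", "dump.json"])
# Export it again; -O names the dump format explicitly.
result = runner.invoke(cli, ["-d", "duckdb:///tmp/demo.db", "export", "-o", "dump.yaml", "-O", "yaml"])
print(result.output)
```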
linkml_store/index/__init__.py
CHANGED

```diff
@@ -22,7 +22,7 @@ def get_indexer_class(name: str) -> Type[Indexer]:
     return INDEXER_CLASSES[name]
 
 
-def get_indexer(name: str,
+def get_indexer(name: str, **kwargs) -> Indexer:
     """
     Get an indexer by name.
 
@@ -30,4 +30,8 @@ def get_indexer(name: str, *args, **kwargs) -> Indexer:
     :param kwargs: additional arguments to pass to the indexer
     :return: the indexer
     """
-
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+    cls = get_indexer_class(name)
+    kwargs["name"] = name
+    indexer = cls(**kwargs)
+    return indexer
```
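A short sketch of the factory behavior added above: `None`-valued keyword arguments are stripped before instantiation, so unset CLI options do not clobber indexer defaults, and the registered name is stored on the instance ("simple" and "llm" are the names referenced in the CLI help):

```python
from linkml_store.index import get_indexer

# text_template=None is filtered out, leaving the indexer's own default intact.
ix = get_indexer("simple", text_template=None)
print(type(ix).__name__)  # SimpleIndexer
print(ix.name)            # "simple", injected by the factory
```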
linkml_store/index/implementations/llm_indexer.py
CHANGED

```diff
@@ -1,20 +1,34 @@
+import logging
+from pathlib import Path
 from typing import TYPE_CHECKING, List
 
 import numpy as np
 
+from linkml_store.api.config import CollectionConfig
 from linkml_store.index.indexer import INDEX_ITEM, Indexer
 
 if TYPE_CHECKING:
     import llm
 
 
+logger = logging.getLogger(__name__)
+
+
 class LLMIndexer(Indexer):
     """
-
+    An indexer that wraps the llm library.
+
+    This indexer is used to convert text to vectors using the llm library.
+
+    >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+    >>> vector = indexer.text_to_vector("hello")
     """
 
     embedding_model_name: str = "ada-002"
     _embedding_model: "llm.EmbeddingModel" = None
+    cached_embeddings_database: str = None
+    cached_embeddings_collection: str = None
+    cache_queries: bool = False
 
     @property
     def embedding_model(self):
@@ -24,21 +38,85 @@ class LLMIndexer(Indexer):
         self._embedding_model = llm.get_embedding_model(self.embedding_model_name)
         return self._embedding_model
 
-    def text_to_vector(self, text: str) -> INDEX_ITEM:
+    def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
         """
         Convert a text to an indexable object
 
+        >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+        >>> vector = indexer.text_to_vector("hello")
+
         :param text:
         :return:
         """
-        return self.texts_to_vectors([text])[0]
+        return self.texts_to_vectors([text], cache=cache, **kwargs)[0]
 
-    def texts_to_vectors(self, texts: List[str]) -> List[INDEX_ITEM]:
+    def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
         """
         Use LLM to embed
 
+        >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+        >>> vectors = indexer.texts_to_vectors(["hello", "goodbye"])
+
         :param texts:
         :return:
         """
-
+        logging.info(f"Converting {len(texts)} texts to vectors")
+        model = self.embedding_model
+        if self.cached_embeddings_database and (cache is None or cache or self.cache_queries):
+            model_id = model.model_id
+            if not model_id:
+                raise ValueError("Model ID is required to cache embeddings")
+            db_path = Path(self.cached_embeddings_database)
+            coll_name = self.cached_embeddings_collection
+            if not coll_name:
+                coll_name = "all_embeddings"
+            from linkml_store import Client
+
+            embeddings_client = Client()
+            config = CollectionConfig(
+                name=coll_name,
+                type="Embeddings",
+                attributes={
+                    "text": {"range": "string"},
+                    "model_id": {"range": "string"},
+                    "embedding": {"range": "float", "array": {}},
+                },
+            )
+            embeddings_db = embeddings_client.get_database(f"duckdb:///{db_path}")
+            if coll_name in embeddings_db.list_collection_names():
+                # Load existing collection and use its model
+                embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
+            else:
+                embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
+            texts = list(texts)
+            embeddings = list([None] * len(texts))
+            uncached_texts = []
+            n = 0
+            for i in range(len(texts)):
+                # TODO: optimize this
+                text = texts[i]
+                logger.info(f"Looking for cached embedding for {text}")
+                r = embeddings_collection.find({"text": text, "model_id": model_id})
+                if r.num_rows:
+                    embeddings[i] = r.rows[0]["embedding"]
+                    n += 1
+                    logger.info("Found")
+                else:
+                    uncached_texts.append((text, i))
+                    logger.info("NOT Found")
+            logger.info(f"Found {n} cached embeddings")
+            if uncached_texts:
+                logger.info(f"Embedding {len(uncached_texts)} uncached texts")
+                uncached_texts, uncached_indices = zip(*uncached_texts)
+                uncached_embeddings = list(model.embed_multi(uncached_texts))
+                # TODO: combine into a single insert with multiple rows
+                for i, index in enumerate(uncached_indices):
+                    logger.debug(f"Indexing text at {i}")
+                    embeddings[index] = uncached_embeddings[i]
+                    embeddings_collection.insert(
+                        {"text": uncached_texts[i], "embedding": embeddings[index], "model_id": model_id}
+                    )
+        else:
+            logger.info(f"Embedding {len(texts)} texts")
+            embeddings = model.embed_multi(texts)
         return [np.array(v, dtype=float) for v in embeddings]
```
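The cache bookkeeping above preserves input order: hits fill their original slots, misses are batch-embedded once, then written back by saved index. A pure-Python sketch of that pattern, with hypothetical `fake_cache`/`fake_embed` stand-ins for the collection lookup and `model.embed_multi`:

```python
fake_cache = {"hello": [0.1, 0.2]}  # stand-in for the cached-embeddings collection

def fake_embed(batch):
    # stand-in for model.embed_multi(batch)
    return [[0.0, 0.0] for _ in batch]

texts = ["hello", "goodbye"]
embeddings = [None] * len(texts)
uncached = []
for i, text in enumerate(texts):
    if text in fake_cache:
        embeddings[i] = fake_cache[text]  # cache hit keeps its position
    else:
        uncached.append((text, i))        # remember the original index
if uncached:
    batch, indices = zip(*uncached)
    for vec, idx in zip(fake_embed(batch), indices):
        embeddings[idx] = vec             # write misses back by saved index
print(embeddings)
```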
linkml_store/index/implementations/simple_indexer.py
CHANGED

```diff
@@ -15,7 +15,7 @@ class SimpleIndexer(Indexer):
     This uses a naive method to generate an index from text. It is not suitable for production use.
     """
 
-    def text_to_vector(self, text: str) -> INDEX_ITEM:
+    def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
         """
         This is a naive method purely for testing
 
@@ -39,5 +39,5 @@ class SimpleIndexer(Indexer):
 
         # Increment the count at the computed index
         vector[index] += 1.0
-        logger.
+        logger.debug(f"Indexed text: {text} as {vector}")
         return vector
```
linkml_store/index/indexer.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
1
3
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -5,6 +7,13 @@ from pydantic import BaseModel
|
|
|
5
7
|
|
|
6
8
|
INDEX_ITEM = np.ndarray
|
|
7
9
|
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TemplateSyntaxEnum(str, Enum):
|
|
14
|
+
jinja2 = "jinja2"
|
|
15
|
+
fstring = "fstring"
|
|
16
|
+
|
|
8
17
|
|
|
9
18
|
def cosine_similarity(vector1, vector2):
|
|
10
19
|
dot_product = np.dot(vector1, vector2)
|
|
@@ -21,8 +30,9 @@ class Indexer(BaseModel):
|
|
|
21
30
|
name: Optional[str] = None
|
|
22
31
|
index_function: Optional[Callable] = None
|
|
23
32
|
distance_function: Optional[Callable] = None
|
|
24
|
-
index_attributes: Optional[str] = None
|
|
33
|
+
index_attributes: Optional[List[str]] = None
|
|
25
34
|
text_template: Optional[str] = None
|
|
35
|
+
text_template_syntax: Optional[TemplateSyntaxEnum] = None
|
|
26
36
|
filter_nulls: Optional[bool] = True
|
|
27
37
|
vector_default_length: Optional[int] = 1000
|
|
28
38
|
index_field: Optional[str] = "__index__"
|
|
@@ -41,24 +51,25 @@ class Indexer(BaseModel):
|
|
|
41
51
|
Convert a list of objects to indexable objects
|
|
42
52
|
|
|
43
53
|
:param objs:
|
|
44
|
-
:return:
|
|
54
|
+
:return: list of vectors
|
|
45
55
|
"""
|
|
46
|
-
return [self.
|
|
56
|
+
return self.texts_to_vectors([self.object_to_text(obj) for obj in objs])
|
|
47
57
|
|
|
48
|
-
def texts_to_vectors(self, texts: List[str]) -> List[INDEX_ITEM]:
|
|
58
|
+
def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
|
|
49
59
|
"""
|
|
50
60
|
Convert a list of texts to indexable objects
|
|
51
61
|
|
|
52
62
|
:param texts:
|
|
53
63
|
:return:
|
|
54
64
|
"""
|
|
55
|
-
return [self.text_to_vector(text) for text in texts]
|
|
65
|
+
return [self.text_to_vector(text, cache=cache, **kwargs) for text in texts]
|
|
56
66
|
|
|
57
|
-
def text_to_vector(self, text: str) -> INDEX_ITEM:
|
|
67
|
+
def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
|
|
58
68
|
"""
|
|
59
69
|
Convert a text to an indexable object
|
|
60
70
|
|
|
61
71
|
:param text:
|
|
72
|
+
:param cache:
|
|
62
73
|
:return:
|
|
63
74
|
"""
|
|
64
75
|
raise NotImplementedError
|
|
@@ -71,11 +82,24 @@ class Indexer(BaseModel):
|
|
|
71
82
|
:return:
|
|
72
83
|
"""
|
|
73
84
|
if self.index_attributes:
|
|
85
|
+
if len(self.index_attributes) == 1 and not self.text_template:
|
|
86
|
+
return str(obj[self.index_attributes[0]])
|
|
74
87
|
obj = {k: v for k, v in obj.items() if k in self.index_attributes}
|
|
75
88
|
if self.filter_nulls:
|
|
76
89
|
obj = {k: v for k, v in obj.items() if v is not None}
|
|
77
90
|
if self.text_template:
|
|
78
|
-
|
|
91
|
+
syntax = self.text_template_syntax
|
|
92
|
+
if not syntax:
|
|
93
|
+
if "{%" in self.text_template or "{{" in self.text_template:
|
|
94
|
+
logger.info("Detected Jinja2 syntax in text template")
|
|
95
|
+
syntax = TemplateSyntaxEnum.jinja2
|
|
96
|
+
if syntax and syntax == TemplateSyntaxEnum.jinja2:
|
|
97
|
+
from jinja2 import Template
|
|
98
|
+
|
|
99
|
+
template = Template(self.text_template)
|
|
100
|
+
return template.render(**obj)
|
|
101
|
+
else:
|
|
102
|
+
return self.text_template.format(**obj)
|
|
79
103
|
return str(obj)
|
|
80
104
|
|
|
81
105
|
def search(
|
|
@@ -91,7 +115,7 @@ class Indexer(BaseModel):
|
|
|
91
115
|
"""
|
|
92
116
|
|
|
93
117
|
# Convert the query string to a vector
|
|
94
|
-
query_vector = self.text_to_vector(query)
|
|
118
|
+
query_vector = self.text_to_vector(query, cache=False)
|
|
95
119
|
|
|
96
120
|
distances = []
|
|
97
121
|
|