linkml-store 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the content of publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of linkml-store might be problematic.

@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+import sqlalchemy as sqla
+from linkml_runtime.linkml_model import ClassDefinition, SlotDefinition
+from sqlalchemy import Column, Table, delete, insert, inspect, text
+from sqlalchemy.sql.ddl import CreateTable
+
+from linkml_store.api import Collection
+from linkml_store.api.collection import DEFAULT_FACET_LIMIT, OBJECT
+from linkml_store.api.queries import Query
+from linkml_store.api.stores.duckdb.mappings import TMAP
+from linkml_store.utils.sql_utils import facet_count_sql
+
+logger = logging.getLogger(__name__)
+
+
+class FileSystemCollection(Collection):
+    _table_created: bool = None
+
+    def insert(self, objs: Union[OBJECT, List[OBJECT]], **kwargs):
+        if not isinstance(objs, list):
+            objs = [objs]
+        if not objs:
+            return
+        cd = self.class_definition()
+        if not cd:
+            cd = self.induce_class_definition_from_objects(objs)
+        self._create_table(cd)
+        table = self._sqla_table(cd)
+        logger.info(f"Inserting into: {self.alias} // T={table.name}")
+        engine = self.parent.engine
+        col_names = [c.name for c in table.columns]
+        objs = [{k: obj.get(k, None) for k in col_names} for obj in objs]
+        with engine.connect() as conn:
+            with conn.begin():
+                conn.execute(insert(table), objs)
+            conn.commit()
+
+    def delete(self, objs: Union[OBJECT, List[OBJECT]], **kwargs) -> Optional[int]:
+        if not isinstance(objs, list):
+            objs = [objs]
+        cd = self.class_definition()
+        if not cd:
+            cd = self.induce_class_definition_from_objects(objs)
+        table = self._sqla_table(cd)
+        engine = self.parent.engine
+        with engine.connect() as conn:
+            for obj in objs:
+                conditions = [table.c[k] == v for k, v in obj.items() if k in cd.attributes]
+                stmt = delete(table).where(*conditions)
+                stmt = stmt.compile(engine)
+                conn.execute(stmt)
+            conn.commit()
+        return
+
+    def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True, **kwargs) -> Optional[int]:
+        logger.info(f"Deleting from {self.target_class_name} where: {where}")
+        if where is None:
+            where = {}
+        cd = self.class_definition()
+        if not cd:
+            logger.info(f"No class definition found for {self.target_class_name}, assuming not prepopulated")
+            return 0
+        table = self._sqla_table(cd)
+        engine = self.parent.engine
+        inspector = inspect(engine)
+        table_exists = table.name in inspector.get_table_names()
+        if not table_exists:
+            logger.info(f"Table {table.name} does not exist, assuming no data")
+            return 0
+        with engine.connect() as conn:
+            conditions = [table.c[k] == v for k, v in where.items()]
+            stmt = delete(table).where(*conditions)
+            stmt = stmt.compile(engine)
+            result = conn.execute(stmt)
+            deleted_rows_count = result.rowcount
+            if deleted_rows_count == 0 and not missing_ok:
+                raise ValueError(f"No rows found for {where}")
+            conn.commit()
+        return deleted_rows_count if deleted_rows_count > -1 else None
+
+    def query_facets(
+        self, where: Dict = None, facet_columns: List[str] = None, facet_limit=DEFAULT_FACET_LIMIT, **kwargs
+    ) -> Dict[str, Dict[str, int]]:
+        results = {}
+        cd = self.class_definition()
+        with self.parent.engine.connect() as conn:
+            if not facet_columns:
+                facet_columns = list(self.class_definition().attributes.keys())
+            for col in facet_columns:
+                logger.debug(f"Faceting on {col}")
+                if isinstance(col, tuple):
+                    sd = SlotDefinition(name="PLACEHOLDER")
+                else:
+                    sd = cd.attributes[col]
+                facet_query = self._create_query(where_clause=where)
+                facet_query_str = facet_count_sql(facet_query, col, multivalued=sd.multivalued)
+                logger.debug(f"Facet query: {facet_query_str}")
+                rows = list(conn.execute(text(facet_query_str)))
+                results[col] = rows
+        return results
+
+    def _sqla_table(self, cd: ClassDefinition) -> Table:
+        schema_view = self.parent.schema_view
+        metadata_obj = sqla.MetaData()
+        cols = []
+        for att in schema_view.class_induced_slots(cd.name):
+            typ = TMAP.get(att.range, sqla.String)
+            if att.inlined:
+                typ = sqla.JSON
+            if att.multivalued:
+                typ = sqla.ARRAY(typ, dimensions=1)
+            if att.array:
+                typ = sqla.ARRAY(typ, dimensions=1)
+            col = Column(att.name, typ)
+            cols.append(col)
+        t = Table(self.alias, metadata_obj, *cols)
+        return t
+
+    def _create_table(self, cd: ClassDefinition):
+        if self._table_created or self.metadata.is_prepopulated:
+            logger.info(f"Already have table for: {cd.name}")
+            return
+        query = Query(
+            from_table="information_schema.tables", where_clause={"table_type": "BASE TABLE", "table_name": self.alias}
+        )
+        qr = self.parent.query(query)
+        if qr.num_rows > 0:
+            logger.info(f"Table already exists for {cd.name}")
+            self._table_created = True
+            self.metadata.is_prepopulated = True
+            return
+        logger.info(f"Creating table for {cd.name}")
+        t = self._sqla_table(cd)
+        ct = CreateTable(t)
+        ddl = str(ct.compile(self.parent.engine))
+        with self.parent.engine.connect() as conn:
+            conn.execute(text(ddl))
+            conn.commit()
+        self._table_created = True
+        self.metadata.is_prepopulated = True
@@ -0,0 +1,36 @@
+import logging
+from typing import Optional
+
+from linkml_store.api import Collection, Database
+from linkml_store.api.config import CollectionConfig
+from linkml_store.api.stores.duckdb import DuckDBDatabase
+from linkml_store.api.stores.filesystem.filesystem_collection import FileSystemCollection
+
+logger = logging.getLogger(__name__)
+
+
+class FileSystemDatabase(Database):
+    collection_class = FileSystemCollection
+    wrapped_database: Database = None
+
+    def __init__(self, handle: Optional[str] = None, recreate_if_exists: bool = False, **kwargs):
+        self.wrapped_database = DuckDBDatabase("duckdb:///:memory:")
+        super().__init__(handle=handle, **kwargs)
+
+    def commit(self, **kwargs):
+        # TODO: sync
+        pass
+
+    def close(self, **kwargs):
+        self.wrapped_database.close()
+
+    def create_collection(
+        self,
+        name: str,
+        alias: Optional[str] = None,
+        metadata: Optional[CollectionConfig] = None,
+        recreate_if_exists=False,
+        **kwargs,
+    ) -> Collection:
+        wd = self.wrapped_database
+        return wd.create_collection(name, alias=alias, metadata=metadata, recreate_if_exists=recreate_if_exists, **kwargs)
@@ -0,0 +1,7 @@
+"""
+Adapter for HDF5 file storage.
+
+.. warning::
+
+    Experimental support for HDF5 storage.
+"""
@@ -0,0 +1,25 @@
+"""
+Adapter for MongoDB document stores.
+
+Handles have the form: ``mongodb://<host>:<port>/<database>``
+
+To use this, you must have the ``mongodb`` extra installed:
+
+.. code-block:: bash
+
+    pip install linkml-store[mongodb]
+
+or
+
+.. code-block:: bash
+
+    pip install linkml-store[all]
+"""
+
+from linkml_store.api.stores.mongodb.mongodb_collection import MongoDBCollection
+from linkml_store.api.stores.mongodb.mongodb_database import MongoDBDatabase
+
+__all__ = [
+    "MongoDBCollection",
+    "MongoDBDatabase",
+]
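
A minimal connection sketch for this adapter; the host, port, database, and collection names below are illustrative assumptions, not part of this release:

.. code-block:: python

    from linkml_store import Client

    client = Client()
    # the handle follows the mongodb://<host>:<port>/<database> form above
    db = client.get_database("mongodb://localhost:27017/test_db")
    collection = db.get_collection("persons")  # hypothetical collection name
    collection.insert({"name": "Alice", "occupation": "welder"})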
@@ -13,6 +13,14 @@ logger = logging.getLogger(__name__)
 
 
 class MongoDBCollection(Collection):
+    """
+    Adapter for collections in a MongoDB database.
+
+    .. note::
+
+        You should not use or manipulate this class directly.
+        Instead, use the general :class:`linkml_store.api.Collection`.
+    """
 
     @property
     def mongo_collection(self) -> MongoCollection:
@@ -62,24 +70,31 @@ class MongoDBCollection(Collection):
             if isinstance(col, tuple):
                 sd = SlotDefinition(name="PLACEHOLDER")
             else:
-                sd = cd.attributes[col]
-
-            if sd.multivalued:
+                if col in cd.attributes:
+                    sd = cd.attributes[col]
+                else:
+                    logger.info(f"No schema metadata for {col}")
+                    sd = SlotDefinition(name=col)
+            group = {"$group": {"_id": f"${col}", "count": {"$sum": 1}}}
+            if isinstance(col, tuple):
+                q = {k.replace(".", ""): f"${k}" for k in col}
+                group["$group"]["_id"] = q
+            if sd and sd.multivalued:
                 facet_pipeline = [
                     {"$match": where} if where else {"$match": {}},
                     {"$unwind": f"${col}"},
-                    {"$group": {"_id": f"${col}", "count": {"$sum": 1}}},
+                    group,
                     {"$sort": {"count": -1}},
                     {"$limit": facet_limit},
                 ]
             else:
                 facet_pipeline = [
                     {"$match": where} if where else {"$match": {}},
-                    {"$group": {"_id": f"${col}", "count": {"$sum": 1}}},
+                    group,
                     {"$sort": {"count": -1}},
                     {"$limit": facet_limit},
                 ]
-
+            logger.info(f"Facet pipeline: {facet_pipeline}")
             facet_results = list(self.mongo_collection.aggregate(facet_pipeline))
             results[col] = [(result["_id"], result["count"]) for result in facet_results]
 
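For reference, a multivalued facet column produces an aggregation pipeline of this shape (the column name and filter below are illustrative):

.. code-block:: python

    facet_pipeline = [
        {"$match": {"status": "active"}},   # the where clause, if any
        {"$unwind": "$aliases"},            # only emitted for multivalued columns
        {"$group": {"_id": "$aliases", "count": {"$sum": 1}}},
        {"$sort": {"count": -1}},
        {"$limit": facet_limit},
    ]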
linkml_store/cli.py CHANGED
@@ -11,12 +11,19 @@ from pydantic import BaseModel
 from linkml_store import Client
 from linkml_store.api import Collection, Database
 from linkml_store.api.queries import Query
+from linkml_store.index import get_indexer
 from linkml_store.index.implementations.simple_indexer import SimpleIndexer
 from linkml_store.index.indexer import Indexer
-from linkml_store.utils.format_utils import Format, load_objects, render_output
+from linkml_store.utils.format_utils import Format, guess_format, load_objects, render_output
 from linkml_store.utils.object_utils import object_path_update
 
-index_type_option = click.option("--index-type", "-t")
+index_type_option = click.option(
+    "--index-type",
+    "-t",
+    default="simple",
+    show_default=True,
+    help="Type of index to create. Values: simple, llm",
+)
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +77,9 @@ class ContextSettings(BaseModel):
 format_choice = click.Choice([f.value for f in Format])
 
 
+include_internal_option = click.option("--include-internal/--no-include-internal", default=False, show_default=True)
+
+
 @click.group()
 @click.option("--database", "-d", help="Database name")
 @click.option("--collection", "-c", help="Collection name")
@@ -89,6 +99,15 @@ def cli(ctx, verbose: int, quiet: bool, stacktrace: bool, database, collection,
     if not stacktrace:
         sys.tracebacklimit = 0
     logger = logging.getLogger()
+    # Set handler for the root logger to output to the console
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+
+    # Clear existing handlers to avoid duplicate messages if function runs multiple times
+    logger.handlers = []
+
+    # Add the newly created console handler to the logger
+    logger.addHandler(console_handler)
     if verbose >= 2:
         logger.setLevel(logging.DEBUG)
     elif verbose == 1:
@@ -193,6 +212,35 @@ def store(ctx, files, object, format):
     click.echo(f"Inserted {len(objects)} objects from {object_str} into collection '{db.name}'.")
 
 
+@cli.command(name="import")
+@click.argument("files", type=click.Path(exists=True), nargs=-1)
+@click.option("--format", "-f", help="Input format")
+@click.pass_context
+def import_database(ctx, files, format):
+    """Imports a database from a dump."""
+    settings = ctx.obj["settings"]
+    db = settings.database
+    if not files:
+        files = ["-"]
+    for file_path in files:
+        db.import_database(file_path, source_format=format)
+
+
+@cli.command()
+@click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
+@click.option("--output", "-o", required=True, type=click.Path(), help="Output file path")
+@click.pass_context
+def export(ctx, output_type, output):
+    """Exports a database to a dump."""
+    settings = ctx.obj["settings"]
+    db = settings.database
+    if output_type is None:
+        output_type = guess_format(output)
+    if output_type is None:
+        raise ValueError(f"Output format must be specified; it can't be inferred from {output}.")
+    db.export_database(output, target_format=output_type)
+
+
 @cli.command()
 @click.option("--where", "-w", type=click.STRING, help="WHERE clause for the query")
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of results to return")
@@ -216,9 +264,10 @@ def query(ctx, where, limit, output_type, output):
 
 @cli.command()
 @click.pass_context
-def list_collections(ctx):
+@include_internal_option
+def list_collections(ctx, **kwargs):
     db = ctx.obj["settings"].database
-    for collection in db.list_collections():
+    for collection in db.list_collections(**kwargs):
         click.echo(collection.name)
         click.echo(render_output(collection.metadata))
 
@@ -254,7 +303,7 @@ def fq(ctx, where, limit, columns, output_type, output):
 
     def _untuple(key):
         if isinstance(key, tuple):
-            return "+".join(key)
+            return "+".join([str(x) for x in key])
         return key
 
    count_dict = {}
@@ -279,8 +328,10 @@ def _get_index(index_type=None, **kwargs) -> Indexer:
 
 @cli.command()
 @index_type_option
+@click.option("--cached-embeddings-database", "-E", help="Path to the database where embeddings are cached")
+@click.option("--text-template", "-T", help="Template for text embeddings")
 @click.pass_context
-def index(ctx, index_type):
+def index(ctx, index_type, **kwargs):
     """
     Create an index over a collection.
 
@@ -289,7 +340,7 @@
     :return:
     """
     collection = ctx.obj["settings"].collection
-    ix = _get_index(index_type)
+    ix = get_indexer(index_type, **kwargs)
     collection.attach_indexer(ix)
 
 
@@ -322,14 +373,17 @@ def schema(ctx, output_type, output):
 @click.option("--limit", "-l", type=click.INT, help="Maximum number of search results")
 @click.option("--output-type", "-O", type=format_choice, default="json", help="Output format")
 @click.option("--output", "-o", type=click.Path(), help="Output file path")
+@click.option(
+    "--auto-index/--no-auto-index", default=False, show_default=True, help="Automatically index the collection"
+)
 @index_type_option
 @click.pass_context
-def search(ctx, search_term, where, limit, index_type, output_type, output):
+def search(ctx, search_term, where, limit, index_type, output_type, output, auto_index):
     """Search objects in the specified collection."""
     collection = ctx.obj["settings"].collection
-    ix = _get_index(index_type)
+    ix = get_indexer(index_type)
     logger.info(f"Attaching index to collection {collection.name}: {ix.model_dump()}")
-    collection.attach_indexer(ix, auto_index=False)
+    collection.attach_indexer(ix, auto_index=auto_index)
     result = collection.search(search_term, where=where, limit=limit)
     output_data = render_output([{"score": row[0], **row[1]} for row in result.ranked_rows], output_type)
     if output:
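
The new --auto-index flag corresponds roughly to this API-level flow; a sketch only, with the database handle, collection name, and data invented for illustration (signatures assumed from the adapter code above):

.. code-block:: python

    from linkml_store import Client
    from linkml_store.index import get_indexer

    client = Client()
    db = client.get_database("duckdb:///:memory:")
    collection = db.create_collection("Person", alias="persons")
    collection.insert([{"name": "Alice"}, {"name": "Bob"}])

    ix = get_indexer("simple")
    collection.attach_indexer(ix, auto_index=True)  # index existing objects up front
    for score, obj in collection.search("Alice", limit=5).ranked_rows:
        print(score, obj)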
@@ -22,7 +22,7 @@ def get_indexer_class(name: str) -> Type[Indexer]:
     return INDEXER_CLASSES[name]
 
 
-def get_indexer(name: str, *args, **kwargs) -> Indexer:
+def get_indexer(name: str, **kwargs) -> Indexer:
     """
     Get an indexer by name.
 
@@ -30,4 +30,8 @@ def get_indexer(name: str, *args, **kwargs) -> Indexer:
     :param kwargs: additional arguments to pass to the indexer
     :return: the indexer
     """
-    return get_indexer_class(name)(*args, **kwargs)
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+    cls = get_indexer_class(name)
+    kwargs["name"] = name
+    indexer = cls(**kwargs)
+    return indexer
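
A short usage sketch of the new behavior; the indexer names "simple" and "llm" are the ones listed in the CLI help above:

.. code-block:: python

    from linkml_store.index import get_indexer

    # None-valued options (e.g. unset CLI flags) are filtered out before
    # instantiation, and the registry name is injected as the indexer's name
    ix = get_indexer("simple", text_template=None)
    assert ix.name == "simple"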
@@ -1,20 +1,34 @@
+import logging
+from pathlib import Path
 from typing import TYPE_CHECKING, List
 
 import numpy as np
 
+from linkml_store.api.config import CollectionConfig
 from linkml_store.index.indexer import INDEX_ITEM, Indexer
 
 if TYPE_CHECKING:
     import llm
 
 
+logger = logging.getLogger(__name__)
+
+
 class LLMIndexer(Indexer):
     """
-    A implementations index wraps the llm library
+    An indexer that wraps the llm library.
+
+    This indexer is used to convert text to vectors using the llm library.
+
+    >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+    >>> vector = indexer.text_to_vector("hello")
     """
 
     embedding_model_name: str = "ada-002"
     _embedding_model: "llm.EmbeddingModel" = None
+    cached_embeddings_database: str = None
+    cached_embeddings_collection: str = None
+    cache_queries: bool = False
 
     @property
     def embedding_model(self):
@@ -24,21 +38,85 @@ class LLMIndexer(Indexer):
             self._embedding_model = llm.get_embedding_model(self.embedding_model_name)
         return self._embedding_model
 
-    def text_to_vector(self, text: str) -> INDEX_ITEM:
+    def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
         """
         Convert a text to an indexable object
 
+        >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+        >>> vector = indexer.text_to_vector("hello")
+
         :param text:
         :return:
         """
-        return self.texts_to_vectors([text])[0]
+        return self.texts_to_vectors([text], cache=cache, **kwargs)[0]
 
-    def texts_to_vectors(self, texts: List[str]) -> List[INDEX_ITEM]:
+    def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
         """
         Use LLM to embed
 
+        >>> indexer = LLMIndexer(cached_embeddings_database="tests/input/llm_cache.db")
+        >>> vectors = indexer.texts_to_vectors(["hello", "goodbye"])
+
         :param texts:
         :return:
         """
-        embeddings = self.embedding_model.embed_multi(texts)
+        logging.info(f"Converting {len(texts)} texts to vectors")
+        model = self.embedding_model
+        if self.cached_embeddings_database and (cache is None or cache or self.cache_queries):
+            model_id = model.model_id
+            if not model_id:
+                raise ValueError("Model ID is required to cache embeddings")
+            db_path = Path(self.cached_embeddings_database)
+            coll_name = self.cached_embeddings_collection
+            if not coll_name:
+                coll_name = "all_embeddings"
+            from linkml_store import Client
+
+            embeddings_client = Client()
+            config = CollectionConfig(
+                name=coll_name,
+                type="Embeddings",
+                attributes={
+                    "text": {"range": "string"},
+                    "model_id": {"range": "string"},
+                    "embedding": {"range": "float", "array": {}},
+                },
+            )
+            embeddings_db = embeddings_client.get_database(f"duckdb:///{db_path}")
+            if coll_name in embeddings_db.list_collection_names():
+                # Load existing collection and use its model
+                embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
+            else:
+                embeddings_collection = embeddings_db.create_collection(coll_name, metadata=config)
+            texts = list(texts)
+            embeddings = list([None] * len(texts))
+            uncached_texts = []
+            n = 0
+            for i in range(len(texts)):
+                # TODO: optimize this
+                text = texts[i]
+                logger.info(f"Looking for cached embedding for {text}")
+                r = embeddings_collection.find({"text": text, "model_id": model_id})
+                if r.num_rows:
+                    embeddings[i] = r.rows[0]["embedding"]
+                    n += 1
+                    logger.info("Found")
+                else:
+                    uncached_texts.append((text, i))
+                    logger.info("NOT Found")
+            logger.info(f"Found {n} cached embeddings")
+            if uncached_texts:
+                logger.info(f"Embedding {len(uncached_texts)} uncached texts")
+                uncached_texts, uncached_indices = zip(*uncached_texts)
+                uncached_embeddings = list(model.embed_multi(uncached_texts))
+                # TODO: combine into a single insert with multiple rows
+                for i, index in enumerate(uncached_indices):
+                    logger.debug(f"Indexing text at {i}")
+                    embeddings[index] = uncached_embeddings[i]
+                    embeddings_collection.insert(
+                        {"text": uncached_texts[i], "embedding": embeddings[index], "model_id": model_id}
+                    )
+        else:
+            logger.info(f"Embedding {len(texts)} texts")
+            embeddings = model.embed_multi(texts)
         return [np.array(v, dtype=float) for v in embeddings]
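
A usage sketch of the cache, assuming the llm extra is installed and an embedding model is configured; the cache path is illustrative:

.. code-block:: python

    from linkml_store.index.implementations.llm_indexer import LLMIndexer

    indexer = LLMIndexer(cached_embeddings_database="embeddings_cache.db")
    # first call embeds via the model and writes (text, model_id, embedding)
    # rows to a DuckDB collection named "all_embeddings"
    vectors = indexer.texts_to_vectors(["hello", "goodbye"])
    # a repeat call for the same text and model is served from the cache
    vectors2 = indexer.texts_to_vectors(["hello"])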
@@ -15,7 +15,7 @@ class SimpleIndexer(Indexer):
     This uses a naive method to generate an index from text. It is not suitable for production use.
     """
 
-    def text_to_vector(self, text: str) -> INDEX_ITEM:
+    def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
         """
         This is a naive method purely for testing
 
@@ -39,5 +39,5 @@ class SimpleIndexer(Indexer):
 
         # Increment the count at the computed index
         vector[index] += 1.0
-        logger.info(f"Indexed text: {text} as {vector}")
+        logger.debug(f"Indexed text: {text} as {vector}")
         return vector
@@ -1,3 +1,5 @@
+import logging
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
@@ -5,6 +7,13 @@ from pydantic import BaseModel
 
 INDEX_ITEM = np.ndarray
 
+logger = logging.getLogger(__name__)
+
+
+class TemplateSyntaxEnum(str, Enum):
+    jinja2 = "jinja2"
+    fstring = "fstring"
+
 
 def cosine_similarity(vector1, vector2):
     dot_product = np.dot(vector1, vector2)
@@ -21,8 +30,9 @@ class Indexer(BaseModel):
     name: Optional[str] = None
     index_function: Optional[Callable] = None
     distance_function: Optional[Callable] = None
-    index_attributes: Optional[str] = None
+    index_attributes: Optional[List[str]] = None
     text_template: Optional[str] = None
+    text_template_syntax: Optional[TemplateSyntaxEnum] = None
     filter_nulls: Optional[bool] = True
     vector_default_length: Optional[int] = 1000
     index_field: Optional[str] = "__index__"
@@ -41,24 +51,25 @@ class Indexer(BaseModel):
         Convert a list of objects to indexable objects
 
         :param objs:
-        :return:
+        :return: list of vectors
         """
-        return [self.object_to_vector(obj) for obj in objs]
+        return self.texts_to_vectors([self.object_to_text(obj) for obj in objs])
 
-    def texts_to_vectors(self, texts: List[str]) -> List[INDEX_ITEM]:
+    def texts_to_vectors(self, texts: List[str], cache: bool = None, **kwargs) -> List[INDEX_ITEM]:
         """
         Convert a list of texts to indexable objects
 
         :param texts:
         :return:
         """
-        return [self.text_to_vector(text) for text in texts]
+        return [self.text_to_vector(text, cache=cache, **kwargs) for text in texts]
 
-    def text_to_vector(self, text: str) -> INDEX_ITEM:
+    def text_to_vector(self, text: str, cache: bool = None, **kwargs) -> INDEX_ITEM:
         """
         Convert a text to an indexable object
 
         :param text:
+        :param cache:
         :return:
         """
         raise NotImplementedError
@@ -71,11 +82,24 @@ class Indexer(BaseModel):
         :return:
         """
         if self.index_attributes:
+            if len(self.index_attributes) == 1 and not self.text_template:
+                return str(obj[self.index_attributes[0]])
             obj = {k: v for k, v in obj.items() if k in self.index_attributes}
         if self.filter_nulls:
             obj = {k: v for k, v in obj.items() if v is not None}
         if self.text_template:
-            return self.text_template.format(**obj)
+            syntax = self.text_template_syntax
+            if not syntax:
+                if "{%" in self.text_template or "{{" in self.text_template:
+                    logger.info("Detected Jinja2 syntax in text template")
+                    syntax = TemplateSyntaxEnum.jinja2
+            if syntax and syntax == TemplateSyntaxEnum.jinja2:
+                from jinja2 import Template
+
+                template = Template(self.text_template)
+                return template.render(**obj)
+            else:
+                return self.text_template.format(**obj)
         return str(obj)
 
     def search(
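
A sketch of the template-syntax detection, using SimpleIndexer with illustrative object fields: templates containing ``{{`` or ``{%`` are auto-detected as Jinja2, while plain ``{name}`` placeholders fall through to ``str.format``:

.. code-block:: python

    from linkml_store.index.implementations.simple_indexer import SimpleIndexer

    obj = {"name": "tuberculosis", "desc": "an infectious disease"}

    # auto-detected as Jinja2 because of the {{ ... }} delimiters
    ix = SimpleIndexer(name="test", text_template="{{ name }}: {{ desc }}")
    print(ix.object_to_text(obj))   # tuberculosis: an infectious disease

    # plain format-string placeholders take the str.format path
    ix2 = SimpleIndexer(name="test", text_template="{name}: {desc}")
    print(ix2.object_to_text(obj))  # tuberculosis: an infectious disease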
@@ -91,7 +115,7 @@
         """
 
         # Convert the query string to a vector
-        query_vector = self.text_to_vector(query)
+        query_vector = self.text_to_vector(query, cache=False)
 
         distances = []