langroid 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,12 +35,10 @@ class ToolMessage(ABC, BaseModel):
         request (str): name of agent method to map to.
         purpose (str): purpose of agent method, expressed in general terms.
             (This is used when auto-generating the tool instruction to the LLM)
-        result (str): example of result of agent method.
     """

     request: str
     purpose: str
-    result: str = ""

     class Config:
         arbitrary_types_allowed = False
@@ -48,7 +46,7 @@ class ToolMessage(ABC, BaseModel):
         validate_assignment = True
         # do not include these fields in the generated schema
         # since we don't require the LLM to specify them
-        schema_extra = {"exclude": {"purpose", "result"}}
+        schema_extra = {"exclude": {"purpose"}}

    @classmethod
    def instructions(cls) -> str:
@@ -110,13 +108,13 @@ class ToolMessage(ABC, BaseModel):
         return "\n\n".join(examples_jsons)

     def to_json(self) -> str:
-        return self.json(indent=4, exclude={"result", "purpose"})
+        return self.json(indent=4, exclude={"purpose"})

     def json_example(self) -> str:
-        return self.json(indent=4, exclude={"result", "purpose"})
+        return self.json(indent=4, exclude={"purpose"})

     def dict_example(self) -> Dict[str, Any]:
-        return self.dict(exclude={"result", "purpose"})
+        return self.dict(exclude={"purpose"})

     @classmethod
     def default_value(cls, f: str) -> Any:
@@ -220,9 +218,7 @@ class ToolMessage(ABC, BaseModel):
         if "description" not in parameters["properties"][name]:
             parameters["properties"][name]["description"] = description

-        excludes = (
-            ["result", "purpose"] if request else ["request", "result", "purpose"]
-        )
+        excludes = ["purpose"] if request else ["request", "purpose"]
         # exclude 'excludes' from parameters["properties"]:
         parameters["properties"] = {
             field: details
@@ -263,5 +259,5 @@ class ToolMessage(ABC, BaseModel):
         Returns:
             Dict[str, Any]: simplified schema
         """
-        schema = generate_simple_schema(cls, exclude=["result", "purpose"])
+        schema = generate_simple_schema(cls, exclude=["purpose"])
         return schema
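
Taken together, these hunks drop the unused `result` field from `ToolMessage`, so only `purpose` (and, for unnamed requests, `request`) needs to be excluded from the LLM-facing schema and examples. A minimal sketch of defining a tool after this change (the `SquareTool` class and its fields are illustrative, not part of the package):

```python
from langroid.agent.tool_message import ToolMessage

class SquareTool(ToolMessage):
    # `request` names the agent method this tool maps to
    request: str = "square"
    purpose: str = "To compute the square of a <number>."
    number: int  # tool-specific field the LLM must supply

# `result` no longer exists on ToolMessage, so the example JSON
# shown to the LLM only needs to exclude `purpose`:
print(SquareTool(number=3).json_example())
```
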
@@ -133,12 +133,15 @@ class AzureGPT(OpenAIGPT):
         """
         Handles the setting of the GPT-4 model in the configuration.
         This function checks the `model_version` in the configuration.
-        If the version is not set, it raises a ValueError indicating that the model
-        version needs to be specified in the ``.env`` file.
-        It sets `OpenAIChatModel.GPT4_TURBO` if the version is
-        '1106-Preview', otherwise, it defaults to setting `OpenAIChatModel.GPT4`.
+        If the version is not set, it raises a ValueError indicating
+        that the model version needs to be specified in the ``.env``
+        file. It sets `OpenAIChatModel.GPT4o` if the version is
+        '2024-05-13', `OpenAIChatModel.GPT4_TURBO` if the version is
+        '1106-Preview', otherwise, it defaults to setting
+        `OpenAIChatModel.GPT4`.
         """
         VERSION_1106_PREVIEW = "1106-Preview"
+        VERSION_GPT4o = "2024-05-13"

         if self.config.model_version == "":
             raise ValueError(
@@ -146,7 +149,9 @@ class AzureGPT(OpenAIGPT):
                 "Please set it to the chat model version used in your deployment."
             )

-        if self.config.model_version == VERSION_1106_PREVIEW:
+        if self.config.model_version == VERSION_GPT4o:
+            self.config.chat_model = OpenAIChatModel.GPT4o
+        elif self.config.model_version == VERSION_1106_PREVIEW:
             self.config.chat_model = OpenAIChatModel.GPT4_TURBO
         else:
             self.config.chat_model = OpenAIChatModel.GPT4
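
A hedged sketch of the resulting version-to-model mapping (the `AzureConfig` class name and import paths are assumptions; a real run also needs Azure credentials in `.env`):

```python
from langroid.language_models.azure_openai import AzureConfig, AzureGPT
from langroid.language_models.openai_gpt import OpenAIChatModel

# model_version normally comes from the .env file; it is set inline here
# purely for illustration (Azure credentials are omitted).
llm = AzureGPT(AzureConfig(model_version="2024-05-13"))
assert llm.config.chat_model == OpenAIChatModel.GPT4o
# "1106-Preview" -> OpenAIChatModel.GPT4_TURBO; any other version -> GPT4
```
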
@@ -234,6 +234,7 @@ class LLMResponse(BaseModel):
             # in this case we ignore message, since all information is in function_call
             msg = ""
             args = self.function_call.arguments
+            recipient = ""
             if isinstance(args, dict):
                 recipient = args.get("recipient", "")
             return recipient, msg
@@ -1,4 +1,3 @@
-import ast
 import hashlib
 import json
 import logging
@@ -49,6 +48,7 @@ from langroid.language_models.utils import (
     async_retry_with_exponential_backoff,
     retry_with_exponential_backoff,
 )
+from langroid.parsing.parse_json import parse_imperfect_json
 from langroid.pydantic_v1 import BaseModel
 from langroid.utils.configuration import settings
 from langroid.utils.constants import Colors
@@ -797,11 +797,24 @@ class OpenAIGPT(LanguageModel):
         args = {}
         if has_function and function_args != "":
             try:
-                args = ast.literal_eval(function_args.strip())
-            except (SyntaxError, ValueError):
+                stripped_fn_args = function_args.strip()
+                dict_or_list = parse_imperfect_json(stripped_fn_args)
+                if not isinstance(dict_or_list, dict):
+                    raise ValueError(
+                        f"""
+                        Invalid function args: {stripped_fn_args}
+                        parsed as {dict_or_list},
+                        which is not a valid dict.
+                        """
+                    )
+                args = dict_or_list
+            except (SyntaxError, ValueError) as e:
                 logging.warning(
-                    f"Parsing OpenAI function args failed: {function_args};"
-                    " treating args as normal message"
+                    f"""
+                    Parsing OpenAI function args failed: {function_args};
+                    treating args as normal message. Error detail:
+                    {e}
+                    """
                 )
                 has_function = False
                 completion = completion + function_args
@@ -1,5 +1,6 @@
+import ast
 import json
-from typing import Any, Iterator, List
+from typing import Any, Dict, Iterator, List, Union

 import yaml
 from pyparsing import nestedExpr, originalTextFor
@@ -73,6 +74,31 @@ def add_quotes(s: str) -> str:
     return s


+def parse_imperfect_json(json_string: str) -> Union[Dict[str, Any], List[Any]]:
+    if not json_string.strip():
+        raise ValueError("Empty string is not valid JSON")
+
+    # First, try parsing with ast.literal_eval
+    try:
+        result = ast.literal_eval(json_string)
+        if isinstance(result, (dict, list)):
+            return result
+    except (ValueError, SyntaxError):
+        pass
+
+    # If ast.literal_eval fails or returns non-dict/list, try json.loads
+    try:
+        quoted = add_quotes(json_string)
+        result = json.loads(quoted)
+        if isinstance(result, (dict, list)):
+            return result
+    except json.JSONDecodeError:
+        pass
+
+    # If all methods fail, raise ValueError
+    raise ValueError(f"Unable to parse as JSON: {json_string}")
+
+
 def repair_newlines(s: str) -> str:
     """
     Attempt to load as json, and if it fails, try with newlines replaced by space.
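
Illustrative behavior of the new `parse_imperfect_json` helper (the example inputs are made up):

```python
from langroid.parsing.parse_json import parse_imperfect_json

print(parse_imperfect_json('{"a": 1, "b": [2, 3]}'))  # strict JSON -> {'a': 1, 'b': [2, 3]}
print(parse_imperfect_json("{'a': 1, 'b': None}"))    # Python literal, rescued by ast.literal_eval
try:
    parse_imperfect_json("not json at all")
except ValueError as e:
    print(e)  # Unable to parse as JSON: not json at all
```
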
@@ -9,8 +9,6 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
-    get_args,
-    get_origin,
     no_type_check,
 )

@@ -313,54 +311,6 @@ def pydantic_obj_from_flat_dict(
     return model(**nested_data)


-def clean_schema(model: Type[BaseModel], excludes: List[str] = []) -> Dict[str, Any]:
-    """
-    Generate a simple schema for a given Pydantic model,
-    including inherited fields, with an option to exclude certain fields.
-    Handles cases where fields are Lists or other generic types and includes
-    field descriptions if available.
-
-    Args:
-        model (Type[BaseModel]): The Pydantic model class.
-        excludes (List[str]): A list of field names to exclude.
-
-    Returns:
-        Dict[str, Any]: A dictionary representing the simple schema.
-    """
-    schema = {}
-
-    for field_name, field_info in model.__fields__.items():
-        if field_name in excludes:
-            continue
-
-        field_type = field_info.outer_type_
-        description = field_info.field_info.description or ""
-
-        # Handle generic types like List[...]
-        if get_origin(field_type):
-            inner_types = get_args(field_type)
-            inner_type_names = [
-                t.__name__ if hasattr(t, "__name__") else str(t) for t in inner_types
-            ]
-            field_type_str = (
-                f"{get_origin(field_type).__name__}" f'[{", ".join(inner_type_names)}]'
-            )
-            schema[field_name] = {"type": field_type_str, "description": description}
-        elif issubclass(field_type, BaseModel):
-            # Directly use the nested model's schema,
-            # integrating it into the current level
-            nested_schema = clean_schema(field_type, excludes)
-            schema[field_name] = {**nested_schema, "description": description}
-        else:
-            # For basic types, use 'type'
-            schema[field_name] = {
-                "type": field_type.__name__,
-                "description": description,
-            }
-
-    return schema
-
-
 @contextmanager
 def temp_update(
     pydantic_object: BaseModel, updates: Dict[str, Any]
@@ -1,14 +1,14 @@
 import copy
 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Type

 import numpy as np
 import pandas as pd

 from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
-from langroid.mytypes import Document
+from langroid.mytypes import DocMetaData, Document
 from langroid.pydantic_v1 import BaseSettings
 from langroid.utils.algorithms.graph import components, topological_sort
 from langroid.utils.configuration import settings
@@ -32,6 +32,9 @@ class VectorStoreConfig(BaseSettings):
     timeout: int = 60
     host: str = "127.0.0.1"
     port: int = 6333
+    # used when parsing search results back as Document objects
+    document_class: Type[Document] = Document
+    metadata_class: Type[DocMetaData] = DocMetaData
     # compose_file: str = "langroid/vector_store/docker-compose-qdrant.yml"

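These two config hooks let any vector store parse search results back into a custom `Document` subclass. A minimal sketch (the `MyMeta`/`MyDoc` classes are illustrative):

```python
from langroid.mytypes import DocMetaData, Document
from langroid.vector_store.qdrantdb import QdrantDBConfig

class MyMeta(DocMetaData):
    author: str = ""  # extra metadata field beyond the defaults

class MyDoc(Document):
    metadata: MyMeta

config = QdrantDBConfig(
    collection_name="my-docs",
    document_class=MyDoc,    # search results are parsed back as MyDoc
    metadata_class=MyMeta,   # ...with MyMeta metadata
)
```
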
@@ -113,8 +116,7 @@ class VectorStore(ABC):
         """

         self.config.collection_name = collection_name
-        if collection_name not in self.list_collections() or replace:
-            self.create_collection(collection_name, replace=replace)
+        self.config.replace_collection = replace

     @abstractmethod
     def create_collection(self, collection_name: str, replace: bool = False) -> None:
@@ -8,7 +8,7 @@ from langroid.embedding_models.base import (
 )
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.exceptions import LangroidImportError
-from langroid.mytypes import DocMetaData, Document
+from langroid.mytypes import Document
 from langroid.utils.configuration import settings
 from langroid.utils.output.printing import print_long_text
 from langroid.vector_store.base import VectorStore, VectorStoreConfig
@@ -200,7 +200,9 @@ class ChromaDB(VectorStore):
             else:
                 m["window_ids"] = m["window_ids"].split(",")
         docs = [
-            Document(content=d, metadata=DocMetaData(**m))
+            self.config.document_class(
+                content=d, metadata=self.config.metadata_class(**m)
+            )
             for d, m in zip(contents, metadatas)
         ]
         return docs
@@ -32,13 +32,7 @@ from langroid.utils.configuration import settings
 from langroid.utils.pydantic_utils import (
     dataframe_to_document_model,
     dataframe_to_documents,
-    extend_document_class,
-    extra_metadata,
-    flatten_pydantic_instance,
-    flatten_pydantic_model,
-    nested_dict_from_flat,
 )
-from langroid.utils.system import pydantic_major_version
 from langroid.vector_store.base import VectorStore, VectorStoreConfig

 try:
@@ -58,10 +52,6 @@ class LanceDBConfig(VectorStoreConfig):
     storage_path: str = ".lancedb/data"
     embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
     distance: str = "cosine"
-    # document_class is used to store in lancedb with right schema,
-    # and also to retrieve the right type of Documents when searching.
-    document_class: Type[Document] = Document
-    flatten: bool = False  # flatten Document class into LanceSchema ?


 class LanceDB(VectorStore):
@@ -78,7 +68,6 @@ class LanceDB(VectorStore):
         self.port = config.port
         self.is_from_dataframe = False  # were docs ingested from a dataframe?
         self.df_metadata_columns: List[str] = []  # metadata columns from dataframe
-        self._setup_schemas(config.document_class)

         load_dotenv()
         if self.config.cloud:
@@ -104,40 +93,6 @@ class LanceDB(VectorStore):
             uri=new_storage_path,
         )

-        # Note: Only create collection if a non-null collection name is provided.
-        # This is useful to delay creation of vecdb until we have a suitable
-        # collection name (e.g. we could get it from the url or folder path).
-        if config.collection_name is not None:
-            self.create_collection(
-                config.collection_name, replace=config.replace_collection
-            )
-
-    def _setup_schemas(self, doc_cls: Type[Document] | None) -> None:
-        try:
-            doc_cls = doc_cls or self.config.document_class
-            self.unflattened_schema = self._create_lance_schema(doc_cls)
-            self.schema = (
-                self._create_flat_lance_schema(doc_cls)
-                if self.config.flatten
-                else self.unflattened_schema
-            )
-        except (AttributeError, TypeError) as e:
-            pydantic_version = pydantic_major_version()
-            if pydantic_version > 1:
-                raise ValueError(
-                    f"""
-                    {e}
-                    ====
-                    You are using Pydantic v{pydantic_version},
-                    which is not yet compatible with Langroid's LanceDB integration.
-                    To use Lancedb with Langroid, please install the
-                    latest pydantic 1.x instead of pydantic v2, e.g.
-                    pip install "pydantic<2.0.0"
-                    """
-                )
-            else:
-                raise e
-
     def clear_empty_collections(self) -> int:
         coll_names = self.list_collections()
         n_deletes = 0
@@ -234,91 +189,8 @@ class LanceDB(VectorStore):
         )  # type: ignore
         return NewModel  # type: ignore

-    def _create_flat_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
-        """
-        Flat version of the lance_schema, as nested Pydantic schemas are not yet
-        supported by LanceDB.
-        """
-        if not has_lancedb:
-            raise LangroidImportError("lancedb", "lancedb")
-        lance_model = self._create_lance_schema(doc_cls)
-        FlatModel = flatten_pydantic_model(lance_model, base_model=LanceModel)
-        return FlatModel
-
     def create_collection(self, collection_name: str, replace: bool = False) -> None:
-        """
-        Create a collection with the given name, optionally replacing an existing
-        collection if `replace` is True.
-        Args:
-            collection_name (str): Name of the collection to create.
-            replace (bool): Whether to replace an existing collection
-                with the same name. Defaults to False.
-        """
-        self.config.collection_name = collection_name
-        collections = self.list_collections()
-        if collection_name in collections:
-            coll = self.client.open_table(collection_name)
-            if coll.head().shape[0] > 0:
-                logger.warning(f"Non-empty Collection {collection_name} already exists")
-                if not replace:
-                    logger.warning("Not replacing collection")
-                    return
-                else:
-                    logger.warning("Recreating fresh collection")
-        try:
-            self.client.create_table(
-                collection_name, schema=self.schema, mode="overwrite"
-            )
-        except (AttributeError, TypeError) as e:
-            pydantic_version = pydantic_major_version()
-            if pydantic_version > 1:
-                raise ValueError(
-                    f"""
-                    {e}
-                    ====
-                    You are using Pydantic v{pydantic_version},
-                    which is not yet compatible with Langroid's LanceDB integration.
-                    To use Lancedb with Langroid, please install the
-                    latest pydantic 1.x instead of pydantic v2, e.g.
-                    pip install "pydantic<2.0.0"
-                    """
-                )
-            else:
-                raise e
-
-        if settings.debug:
-            level = logger.getEffectiveLevel()
-            logger.setLevel(logging.INFO)
-            logger.setLevel(level)
-
-    def _maybe_set_doc_class_schema(self, doc: Document) -> None:
-        """
-        Set the config.document_class and self.schema based on doc if needed
-        Args:
-            doc: an instance of Document, to be added to a collection
-        """
-        extra_metadata_fields = extra_metadata(doc, self.config.document_class)
-        if len(extra_metadata_fields) > 0:
-            logger.warning(
-                f"""
-                Added documents contain extra metadata fields:
-                {extra_metadata_fields}
-                which were not present in the original config.document_class.
-                Trying to change document_class and corresponding schemas.
-                Overriding LanceDBConfig.document_class with an auto-generated
-                Pydantic class that includes these extra fields.
-                If this fails, or you see odd results, it is recommended that you
-                define a subclass of Document, with metadata of class derived from
-                DocMetaData, with extra fields defined via
-                `Field(..., description="...")` declarations,
-                and set this document class as the value of the
-                LanceDBConfig.document_class attribute.
-                """
-            )
-
-            doc_cls = extend_document_class(doc)
-            self.config.document_class = doc_cls
-            self._setup_schemas(doc_cls)
+        self.config.replace_collection = replace

     def add_documents(self, documents: Sequence[Document]) -> None:
         super().maybe_add_ids(documents)
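
With these removals, `create_collection` no longer eagerly creates a LanceDB table; it only records the replace intent, and the table itself is created lazily when documents are first added (next hunk). A hedged sketch of the new flow, assuming the `langroid.vector_store.lancedb` module path:

```python
from langroid.vector_store.lancedb import LanceDB, LanceDBConfig

vecdb = LanceDB(LanceDBConfig())
vecdb.set_collection("my-table", replace=True)  # records name + replace intent only
# The underlying LanceDB table is created (or dropped and re-created,
# since replace=True) on the first add_documents() call.
```
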
@@ -329,39 +201,52 @@ class LanceDB(VectorStore):
         coll_name = self.config.collection_name
         if coll_name is None:
             raise ValueError("No collection name set, cannot ingest docs")
-        self._maybe_set_doc_class_schema(documents[0])
+        # self._maybe_set_doc_class_schema(documents[0])
+        table_exists = False
         if (
-            coll_name not in colls
-            or self.client.open_table(coll_name).head(1).shape[0] == 0
+            coll_name in colls
+            and self.client.open_table(coll_name).head(1).shape[0] > 0
         ):
-            # collection either doesn't exist or is empty, so replace it
-            self.create_collection(coll_name, replace=True)
+            # collection exists and is not empty:
+            # if replace_collection is True, we'll overwrite the existing collection,
+            # else we'll append to it.
+            if self.config.replace_collection:
+                self.client.drop_table(coll_name)
+            else:
+                table_exists = True

         ids = [str(d.id()) for d in documents]
         # don't insert all at once, batch in chunks of b,
         # else we get an API error
         b = self.config.batch_size

-        def make_batches() -> Generator[List[BaseModel], None, None]:
+        def make_batches() -> Generator[List[Dict[str, Any]], None, None]:
             for i in range(0, len(ids), b):
                 batch = [
-                    self.unflattened_schema(
+                    dict(
                         id=ids[i + j],
                         vector=embedding_vecs[i + j],
                         **doc.dict(),
                     )
                     for j, doc in enumerate(documents[i : i + b])
                 ]
-                if self.config.flatten:
-                    batch = [
-                        flatten_pydantic_instance(instance)  # type: ignore
-                        for instance in batch
-                    ]
                 yield batch

-        tbl = self.client.open_table(self.config.collection_name)
         try:
-            tbl.add(make_batches())
+            if table_exists:
+                tbl = self.client.open_table(coll_name)
+                tbl.add(make_batches())
+            else:
+                batch_gen = make_batches()
+                batch = next(batch_gen)
+                # use first batch to create table...
+                tbl = self.client.create_table(
+                    coll_name,
+                    data=batch,
+                    mode="create",
+                )
+                # ... and add the rest
+                tbl.add(batch_gen)
         except Exception as e:
             logger.error(
                 f"""
@@ -427,7 +312,6 @@ class LanceDB(VectorStore):
                 exclude=["vector"],
             )
             self.config.document_class = doc_cls  # type: ignore
-            self._setup_schemas(doc_cls)  # type: ignore
         else:
             # collection exists and is not empty, so append to it
             tbl = self.client.open_table(self.config.collection_name)
@@ -452,35 +336,19 @@ class LanceDB(VectorStore):
         return self._records_to_docs(records)

     def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
-        if self.config.flatten:
-            docs = [
-                self.unflattened_schema(**nested_dict_from_flat(rec)) for rec in records
-            ]
-        else:
-            try:
-                docs = [self.schema(**rec) for rec in records]
-            except ValidationError as e:
-                raise ValueError(
-                    f"""
-                    Error validating LanceDB result: {e}
-                    HINT: This could happen when you're re-using an
-                    existing LanceDB store with a different schema.
-                    Try deleting your local lancedb storage at `{self.config.storage_path}`
-                    re-ingesting your documents and/or replacing the collections.
-                    """
-                )
-
-        doc_cls = self.config.document_class
-        doc_cls_field_names = doc_cls.__fields__.keys()
-        return [
-            doc_cls(
-                **{
-                    field_name: getattr(doc, field_name)
-                    for field_name in doc_cls_field_names
-                }
+        try:
+            docs = [self.config.document_class(**rec) for rec in records]
+        except ValidationError as e:
+            raise ValueError(
+                f"""
+                Error validating LanceDB result: {e}
+                HINT: This could happen when you're re-using an
+                existing LanceDB store with a different schema.
+                Try deleting your local lancedb storage at `{self.config.storage_path}`
+                re-ingesting your documents and/or replacing the collections.
+                """
             )
-            for doc in docs
-        ]
+        return docs

     def get_all_documents(self, where: str = "") -> List[Document]:
         if self.config.collection_name is None:
@@ -1,6 +1,7 @@
 """
 Momento Vector Index.
 https://docs.momentohq.com/vector-index/develop/api-reference
+DEPRECATED: API is unstable.
 """

 from __future__ import annotations
@@ -63,6 +63,7 @@ def is_valid_uuid(uuid_to_test: str) -> bool:

 class QdrantDBConfig(VectorStoreConfig):
     cloud: bool = True
+    docker: bool = False
     collection_name: str | None = "temp"
     storage_path: str = ".qdrant/data"
     embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
@@ -102,7 +103,19 @@ class QdrantDB(VectorStore):
         load_dotenv()
         key = os.getenv("QDRANT_API_KEY")
         url = os.getenv("QDRANT_API_URL")
-        if config.cloud and None in [key, url]:
+        if config.docker:
+            if url is None:
+                logger.warning(
+                    f"""The QDRANT_API_URL env variable must be set to use
+                    QdrantDB in local docker mode. Please set this
+                    value in your .env file.
+                    Switching to local storage at {config.storage_path}
+                    """
+                )
+                config.cloud = False
+            else:
+                config.cloud = True
+        elif config.cloud and None in [key, url]:
             logger.warning(
                 f"""QDRANT_API_KEY, QDRANT_API_URL env variable must be set to use
                 QdrantDB in cloud mode. Please set these values
@@ -111,6 +124,7 @@ class QdrantDB(VectorStore):
                 """
             )
             config.cloud = False
+
         if config.cloud:
             self.client = QdrantClient(
                 url=url,
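
A hedged sketch of the new docker mode (the URL and port are illustrative; in practice `QDRANT_API_URL` lives in `.env`):

```python
import os
from langroid.vector_store.qdrantdb import QdrantDB, QdrantDBConfig

os.environ["QDRANT_API_URL"] = "http://localhost:6333"  # a local dockerized Qdrant
vecdb = QdrantDB(QdrantDBConfig(docker=True))
# With docker=True and a URL set, the client connects to that URL;
# if the URL is missing, it falls back to local storage_path with a warning.
```
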
@@ -366,7 +380,11 @@ class QdrantDB(VectorStore):
                 with_payload=True,
                 with_vectors=False,
             )
-            docs += [Document(**record.payload) for record in results]  # type: ignore
+            docs += [
+                self.config.document_class(**record.payload)  # type: ignore
+                for record in results
+            ]
             if next_page_offset is None:
                 break
             offset = next_page_offset  # type: ignore
@@ -385,7 +403,7 @@ class QdrantDB(VectorStore):
             # Note the records may NOT be in the order of the ids,
             # so we re-order them here.
             id2payload = {record.id: record.payload for record in records}
-            ordered_payloads = [id2payload[id] for id in _ids]
+            ordered_payloads = [id2payload[id] for id in _ids if id in id2payload]
             docs = [Document(**payload) for payload in ordered_payloads]  # type: ignore
             return docs

@@ -437,7 +455,7 @@ class QdrantDB(VectorStore):
         ]  # 2D list -> 1D list
         scores = [match.score for match in search_result if match is not None]
         docs = [
-            Document(**(match.payload))  # type: ignore
+            self.config.document_class(**(match.payload))  # type: ignore
             for match in search_result
             if match is not None
         ]