langroid 0.1.252__py3-none-any.whl → 0.1.253__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +16 -15
- langroid/agent/__init__.py +1 -0
- langroid/agent/callbacks/chainlit.py +5 -12
- langroid/agent/special/__init__.py +3 -2
- langroid/agent/special/doc_chat_agent.py +36 -56
- langroid/agent/special/neo4j/csv_kg_chat.py +2 -2
- langroid/agent/special/sql/__init__.py +1 -2
- langroid/agent/special/sql/sql_chat_agent.py +10 -4
- langroid/agent/special/sql/utils/__init__.py +4 -5
- langroid/agent/special/sql/utils/description_extractors.py +7 -2
- langroid/agent/special/sql/utils/populate_metadata.py +6 -1
- langroid/agent/special/table_chat_agent.py +2 -2
- langroid/agent/tool_message.py +14 -3
- langroid/agent/tools/__init__.py +2 -3
- langroid/agent/tools/duckduckgo_search_tool.py +2 -2
- langroid/agent/tools/google_search_tool.py +2 -2
- langroid/agent/tools/metaphor_search_tool.py +2 -2
- langroid/agent/tools/retrieval_tool.py +2 -2
- langroid/agent/tools/run_python_code.py +2 -2
- langroid/agent/tools/segment_extract_tool.py +2 -2
- langroid/cachedb/base.py +10 -2
- langroid/cachedb/momento_cachedb.py +10 -4
- langroid/cachedb/redis_cachedb.py +2 -3
- langroid/embedding_models/__init__.py +1 -0
- langroid/exceptions.py +57 -0
- langroid/language_models/__init__.py +1 -0
- langroid/language_models/base.py +2 -3
- langroid/language_models/openai_gpt.py +15 -14
- langroid/language_models/prompt_formatter/__init__.py +4 -3
- langroid/parsing/document_parser.py +20 -4
- langroid/parsing/parser.pyi +56 -0
- langroid/utils/logging.py +7 -3
- langroid/utils/output/__init__.py +1 -2
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +7 -2
- langroid/vector_store/__init__.py +33 -17
- langroid/vector_store/chromadb.py +2 -8
- langroid/vector_store/lancedb.py +36 -5
- langroid/vector_store/meilisearch.py +21 -11
- langroid/vector_store/momento.py +31 -14
- {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/METADATA +31 -29
- {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/RECORD +44 -42
- {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/LICENSE +0 -0
- {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/WHEEL +0 -0
langroid/exceptions.py
CHANGED
@@ -1,3 +1,60 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
|
1
4
|
class InfiniteLoopException(Exception):
    """Exception whose message defaults to "Infinite loop detected"."""

    def __init__(
        self,
        message: str = "Infinite loop detected",
        *args: object,
    ) -> None:
        """
        Args:
            message (str): Text stored as the first element of `args`.
            *args (object): Extra positional values forwarded unchanged to
                `Exception`.
        """
        super().__init__(message, *args)
|
7
|
+
|
8
|
+
|
9
|
+
class LangroidImportError(ImportError):
    """ImportError whose message tells the user how to install what is missing."""

    def __init__(
        self,
        package: Optional[str] = None,
        extra: Optional[str] = None,
        error: str = "",
        *args: object,
    ) -> None:
        """
        Generate helpful warning when attempting to import package or module.

        Args:
            package (Optional[str]): The name of the package to import. Only
                used to build the default first line of the message.
            extra (Optional[str]): The name of the langroid extra required for
                this import. When falsy, the message only asks the user to
                install the package into the same virtual environment.
            error (str): The error message to display. Depending on context, we
                can set this by capturing the ImportError message. When empty,
                a default "... is not installed by default" line is used.
            *args (object): Extra positional values forwarded to `ImportError`.
        """
        import textwrap

        if error == "":
            # Always give the help text below a subject, even when the caller
            # did not name a package ("If you want to use it ..." needs one).
            subject = package if package is not None else "A required package"
            error = f"{subject} is not installed by default with Langroid.\n"

        if extra:
            # dedent first, substitute afterwards: the layout of the message
            # must not depend on how deeply this literal is indented in source.
            install_help = textwrap.dedent(
                """
                If you want to use it, please install langroid
                with the `{extra}` extra, for example:

                If you are using pip:
                pip install "langroid[{extra}]"

                For multiple extras, you can separate them with commas:
                pip install "langroid[{extra},another-extra]"

                If you are using Poetry:
                poetry add langroid --extras "{extra}"

                For multiple extras with Poetry, list them with spaces:
                poetry add langroid --extras "{extra} another-extra"

                If you are working within the langroid dev env (which uses Poetry),
                you can do:
                poetry install -E "{extra}"
                or if you want to include multiple extras:
                poetry install -E "{extra} another-extra"
                """
            ).format(extra=extra)
        else:
            install_help = textwrap.dedent(
                """
                If you want to use it, please install it in the same
                virtual environment as langroid.
                """
            )

        super().__init__(error + install_help, *args)
|
langroid/language_models/base.py
CHANGED
@@ -10,8 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
|
10
10
|
import aiohttp
|
11
11
|
from pydantic import BaseModel, BaseSettings, Field
|
12
12
|
|
13
|
-
from langroid.cachedb.
|
14
|
-
from langroid.cachedb.redis_cachedb import RedisCacheConfig
|
13
|
+
from langroid.cachedb.base import CacheDBConfig
|
15
14
|
from langroid.mytypes import Document
|
16
15
|
from langroid.parsing.agent_chats import parse_message
|
17
16
|
from langroid.parsing.parse_json import top_level_json_field
|
@@ -49,7 +48,7 @@ class LLMConfig(BaseSettings):
|
|
49
48
|
# use chat model for completion? For OpenAI models, this MUST be set to True!
|
50
49
|
use_chat_for_completion: bool = True
|
51
50
|
stream: bool = True # stream output from API?
|
52
|
-
cache_config: None |
|
51
|
+
cache_config: None | CacheDBConfig = None
|
53
52
|
|
54
53
|
# Dict of model -> (input/prompt cost, output/completion cost)
|
55
54
|
chat_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
|
@@ -28,8 +28,9 @@ from pydantic import BaseModel
|
|
28
28
|
from rich import print
|
29
29
|
from rich.markup import escape
|
30
30
|
|
31
|
-
from langroid.cachedb.
|
31
|
+
from langroid.cachedb.base import CacheDB
|
32
32
|
from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
|
33
|
+
from langroid.exceptions import LangroidImportError
|
33
34
|
from langroid.language_models.base import (
|
34
35
|
LanguageModel,
|
35
36
|
LLMConfig,
|
@@ -280,14 +281,7 @@ class OpenAIGPTConfig(LLMConfig):
|
|
280
281
|
try:
|
281
282
|
import litellm
|
282
283
|
except ImportError:
|
283
|
-
raise
|
284
|
-
"""
|
285
|
-
litellm not installed. Please install it via:
|
286
|
-
pip install litellm.
|
287
|
-
Or when installing langroid, install it with the `litellm` extra:
|
288
|
-
pip install langroid[litellm]
|
289
|
-
"""
|
290
|
-
)
|
284
|
+
raise LangroidImportError("litellm", "litellm")
|
291
285
|
litellm.telemetry = False
|
292
286
|
litellm.drop_params = True # drop un-supported params without crashing
|
293
287
|
self.seed = None # some local mdls don't support seed
|
@@ -482,17 +476,24 @@ class OpenAIGPT(LanguageModel):
|
|
482
476
|
timeout=Timeout(self.config.timeout),
|
483
477
|
)
|
484
478
|
|
485
|
-
self.cache:
|
479
|
+
self.cache: CacheDB
|
486
480
|
if settings.cache_type == "momento":
|
487
|
-
|
488
|
-
|
481
|
+
from langroid.cachedb.momento_cachedb import (
|
482
|
+
MomentoCache,
|
483
|
+
MomentoCacheConfig,
|
484
|
+
)
|
485
|
+
|
486
|
+
if config.cache_config is None or not isinstance(
|
487
|
+
config.cache_config,
|
488
|
+
MomentoCacheConfig,
|
489
489
|
):
|
490
490
|
# switch to fresh momento config if needed
|
491
491
|
config.cache_config = MomentoCacheConfig()
|
492
492
|
self.cache = MomentoCache(config.cache_config)
|
493
493
|
elif "redis" in settings.cache_type:
|
494
|
-
if config.cache_config is None or isinstance(
|
495
|
-
config.cache_config,
|
494
|
+
if config.cache_config is None or not isinstance(
|
495
|
+
config.cache_config,
|
496
|
+
RedisCacheConfig,
|
496
497
|
):
|
497
498
|
# switch to fresh redis config if needed
|
498
499
|
config.cache_config = RedisCacheConfig(
|
@@ -1,9 +1,10 @@
|
|
1
|
+
from . import base
|
2
|
+
from . import llama2_formatter
|
1
3
|
from .base import PromptFormatter
|
2
4
|
from .llama2_formatter import Llama2Formatter
|
3
|
-
from ..config import PromptFormatterConfig
|
5
|
+
from ..config import PromptFormatterConfig
|
6
|
+
from ..config import Llama2FormatterConfig
|
4
7
|
|
5
|
-
from . import base
|
6
|
-
from . import llama2_formatter
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"PromptFormatter",
|
@@ -5,9 +5,19 @@ from enum import Enum
|
|
5
5
|
from io import BytesIO
|
6
6
|
from typing import Any, Generator, List, Tuple
|
7
7
|
|
8
|
-
import
|
8
|
+
from langroid.exceptions import LangroidImportError
|
9
|
+
|
10
|
+
try:
|
11
|
+
import fitz
|
12
|
+
except ImportError:
|
13
|
+
raise LangroidImportError("PyMuPDF", "pdf-parsers")
|
14
|
+
|
15
|
+
try:
|
16
|
+
import pypdf
|
17
|
+
except ImportError:
|
18
|
+
raise LangroidImportError("pypdf", "pdf-parsers")
|
19
|
+
|
9
20
|
import pdfplumber
|
10
|
-
import pypdf
|
11
21
|
import requests
|
12
22
|
from bs4 import BeautifulSoup
|
13
23
|
from PIL import Image
|
@@ -456,7 +466,10 @@ class ImagePdfParser(DocumentParser):
|
|
456
466
|
def iterate_pages(
|
457
467
|
self,
|
458
468
|
) -> Generator[Tuple[int, "Image"], None, None]: # type: ignore
|
459
|
-
|
469
|
+
try:
|
470
|
+
from pdf2image import convert_from_bytes
|
471
|
+
except ImportError:
|
472
|
+
raise LangroidImportError("pdf2image", "pdf-parsers")
|
460
473
|
|
461
474
|
images = convert_from_bytes(self.doc_bytes.getvalue())
|
462
475
|
for i, image in enumerate(images):
|
@@ -472,7 +485,10 @@ class ImagePdfParser(DocumentParser):
|
|
472
485
|
Returns:
|
473
486
|
str: Extracted text from the image.
|
474
487
|
"""
|
475
|
-
|
488
|
+
try:
|
489
|
+
import pytesseract
|
490
|
+
except ImportError:
|
491
|
+
raise LangroidImportError("pytesseract", "pdf-parsers")
|
476
492
|
|
477
493
|
text = pytesseract.image_to_string(page)
|
478
494
|
return self.fix_text(text)
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Literal
|
3
|
+
|
4
|
+
from _typeshed import Incomplete
|
5
|
+
from pydantic import BaseSettings
|
6
|
+
|
7
|
+
from langroid.mytypes import Document as Document
|
8
|
+
from langroid.parsing.para_sentence_split import (
|
9
|
+
create_chunks as create_chunks,
|
10
|
+
)
|
11
|
+
from langroid.parsing.para_sentence_split import (
|
12
|
+
remove_extra_whitespace as remove_extra_whitespace,
|
13
|
+
)
|
14
|
+
|
15
|
+
logger: Incomplete
|
16
|
+
|
17
|
+
class Splitter(str, Enum):
|
18
|
+
TOKENS: str
|
19
|
+
PARA_SENTENCE: str
|
20
|
+
SIMPLE: str
|
21
|
+
|
22
|
+
class PdfParsingConfig(BaseSettings):
|
23
|
+
library: Literal["fitz", "pdfplumber", "pypdf", "unstructured", "pdf2image"]
|
24
|
+
|
25
|
+
class DocxParsingConfig(BaseSettings):
|
26
|
+
library: Literal["python-docx", "unstructured"]
|
27
|
+
|
28
|
+
class DocParsingConfig(BaseSettings):
|
29
|
+
library: Literal["unstructured"]
|
30
|
+
|
31
|
+
class ParsingConfig(BaseSettings):
|
32
|
+
splitter: str
|
33
|
+
chunk_size: int
|
34
|
+
overlap: int
|
35
|
+
max_chunks: int
|
36
|
+
min_chunk_chars: int
|
37
|
+
discard_chunk_chars: int
|
38
|
+
n_similar_docs: int
|
39
|
+
n_neighbor_ids: int
|
40
|
+
separators: list[str]
|
41
|
+
token_encoding_model: str
|
42
|
+
pdf: PdfParsingConfig
|
43
|
+
docx: DocxParsingConfig
|
44
|
+
doc: DocParsingConfig
|
45
|
+
|
46
|
+
class Parser:
|
47
|
+
config: Incomplete
|
48
|
+
tokenizer: Incomplete
|
49
|
+
def __init__(self, config: ParsingConfig) -> None: ...
|
50
|
+
def num_tokens(self, text: str) -> int: ...
|
51
|
+
def add_window_ids(self, chunks: list[Document]) -> None: ...
|
52
|
+
def split_simple(self, docs: list[Document]) -> list[Document]: ...
|
53
|
+
def split_para_sentence(self, docs: list[Document]) -> list[Document]: ...
|
54
|
+
def split_chunk_tokens(self, docs: list[Document]) -> list[Document]: ...
|
55
|
+
def chunk_tokens(self, text: str) -> list[str]: ...
|
56
|
+
def split(self, docs: list[Document]) -> list[Document]: ...
|
langroid/utils/logging.py
CHANGED
@@ -31,7 +31,11 @@ def setup_colored_logging() -> None:
|
|
31
31
|
# logger.setLevel(logging.DEBUG)
|
32
32
|
|
33
33
|
|
34
|
-
def setup_logger(
|
34
|
+
def setup_logger(
|
35
|
+
name: str,
|
36
|
+
level: int = logging.INFO,
|
37
|
+
terminal: bool = False,
|
38
|
+
) -> logging.Logger:
|
35
39
|
"""
|
36
40
|
Set up a logger of module `name` at a desired level.
|
37
41
|
Args:
|
@@ -42,7 +46,7 @@ def setup_logger(name: str, level: int = logging.INFO) -> logging.Logger:
|
|
42
46
|
"""
|
43
47
|
logger = logging.getLogger(name)
|
44
48
|
logger.setLevel(level)
|
45
|
-
if not logger.hasHandlers():
|
49
|
+
if not logger.hasHandlers() and terminal:
|
46
50
|
handler = logging.StreamHandler()
|
47
51
|
formatter = logging.Formatter(
|
48
52
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
@@ -73,7 +77,7 @@ def setup_file_logger(
|
|
73
77
|
) -> logging.Logger:
|
74
78
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
75
79
|
file_mode = "a" if append else "w"
|
76
|
-
logger = setup_logger(name)
|
80
|
+
logger = setup_logger(name, terminal=False)
|
77
81
|
handler = logging.FileHandler(filename, mode=file_mode)
|
78
82
|
handler.setLevel(logging.INFO)
|
79
83
|
if log_format:
|
@@ -1,5 +1,4 @@
|
|
1
1
|
from . import printing
|
2
|
-
|
3
2
|
from .printing import (
|
4
3
|
shorten_text,
|
5
4
|
print_long_text,
|
@@ -7,9 +6,9 @@ from .printing import (
|
|
7
6
|
SuppressLoggerWarnings,
|
8
7
|
PrintColored,
|
9
8
|
)
|
10
|
-
|
11
9
|
from .status import status
|
12
10
|
|
11
|
+
|
13
12
|
__all__ = [
|
14
13
|
"printing",
|
15
14
|
"shorten_text",
|
@@ -0,0 +1,41 @@
|
|
1
|
+
def extract_markdown_references(md_string: str) -> list[int]:
    """
    Extracts markdown references (e.g., [^1], [^2]) from a string and returns
    them as a sorted list of integers.

    Args:
        md_string (str): The markdown string containing references.

    Returns:
        list[int]: A sorted list of unique integers from the markdown references.
            Empty when the string contains no `[^<number>]` marker.
    """
    import re

    # Each match is the digit run captured inside a [^<number>] marker; the
    # set comprehension drops repeated references before sorting.
    return sorted({int(num) for num in re.findall(r"\[\^(\d+)\]", md_string)})
|
18
|
+
|
19
|
+
|
20
|
+
def format_footnote_text(content: str, width: int = 80) -> str:
    """
    Formats the content part of a footnote (i.e. not the first line that
    appears right after the reference [^4]).
    It wraps the text so that no line, indentation included, is longer than
    the specified width, and indents every line as required for markdown
    footnotes.

    Args:
        content (str): The text of the footnote to be formatted.
        width (int): Maximum width of the text lines, counting the indent.

    Returns:
        str: Properly formatted markdown footnote text; the empty string when
            `content` is empty or whitespace-only.
    """
    import textwrap

    # Markdown treats lines indented by four spaces as a continuation of the
    # footnote they follow.
    indent = "    "
    # Handing the indent to textwrap makes it count towards `width`; adding it
    # after wrapping would let every line overshoot by len(indent).
    wrapped_lines = textwrap.wrap(
        content,
        width,
        initial_indent=indent,
        subsequent_indent=indent,
    )
    if not wrapped_lines:
        return ""
    return "\n".join(wrapped_lines)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
import sys
|
3
3
|
from contextlib import contextmanager
|
4
|
-
from typing import Any, Iterator, Optional
|
4
|
+
from typing import Any, Iterator, Optional, Type
|
5
5
|
|
6
6
|
from rich import print as rprint
|
7
7
|
from rich.text import Text
|
@@ -89,6 +89,11 @@ class SuppressLoggerWarnings:
|
|
89
89
|
# Set the logging level to 'ERROR' to suppress warnings
|
90
90
|
self.logger.setLevel(logging.ERROR)
|
91
91
|
|
92
|
-
def __exit__(
|
92
|
+
def __exit__(
|
93
|
+
self,
|
94
|
+
exc_type: Optional[Type[BaseException]],
|
95
|
+
exc_value: Optional[BaseException],
|
96
|
+
traceback: Any,
|
97
|
+
) -> None:
|
93
98
|
# Reset the logging level to its original value
|
94
99
|
self.logger.setLevel(self.original_level)
|
@@ -1,25 +1,9 @@
|
|
1
1
|
from . import base
|
2
2
|
|
3
3
|
from . import qdrantdb
|
4
|
-
from . import meilisearch
|
5
|
-
from . import lancedb
|
6
4
|
|
7
5
|
from .base import VectorStoreConfig, VectorStore
|
8
6
|
from .qdrantdb import QdrantDBConfig, QdrantDB
|
9
|
-
from .meilisearch import MeiliSearch, MeiliSearchConfig
|
10
|
-
from .lancedb import LanceDB, LanceDBConfig
|
11
|
-
|
12
|
-
has_chromadb = False
|
13
|
-
try:
|
14
|
-
from . import chromadb
|
15
|
-
from .chromadb import ChromaDBConfig, ChromaDB
|
16
|
-
|
17
|
-
chromadb # silence linters
|
18
|
-
ChromaDB
|
19
|
-
ChromaDBConfig
|
20
|
-
has_chromadb = True
|
21
|
-
except ImportError:
|
22
|
-
pass
|
23
7
|
|
24
8
|
__all__ = [
|
25
9
|
"base",
|
@@ -36,5 +20,37 @@ __all__ = [
|
|
36
20
|
"LanceDBConfig",
|
37
21
|
]
|
38
22
|
|
39
|
-
|
23
|
+
|
24
|
+
try:
|
25
|
+
from . import meilisearch
|
26
|
+
from .meilisearch import MeiliSearch, MeiliSearchConfig
|
27
|
+
|
28
|
+
meilisearch
|
29
|
+
MeiliSearch
|
30
|
+
MeiliSearchConfig
|
31
|
+
__all__.extend(["meilisearch", "MeiliSearch", "MeiliSearchConfig"])
|
32
|
+
except ImportError:
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
try:
|
37
|
+
from . import lancedb
|
38
|
+
from .lancedb import LanceDB, LanceDBConfig
|
39
|
+
|
40
|
+
lancedb
|
41
|
+
LanceDB
|
42
|
+
LanceDBConfig
|
43
|
+
__all__.extend(["lancedb", "LanceDB", "LanceDBConfig"])
|
44
|
+
except ImportError:
|
45
|
+
pass
|
46
|
+
|
47
|
+
try:
|
48
|
+
from . import chromadb
|
49
|
+
from .chromadb import ChromaDBConfig, ChromaDB
|
50
|
+
|
51
|
+
chromadb # silence linters
|
52
|
+
ChromaDB
|
53
|
+
ChromaDBConfig
|
40
54
|
__all__.extend(["chromadb", "ChromaDBConfig", "ChromaDB"])
|
55
|
+
except ImportError:
|
56
|
+
pass
|
@@ -7,6 +7,7 @@ from langroid.embedding_models.base import (
|
|
7
7
|
EmbeddingModelsConfig,
|
8
8
|
)
|
9
9
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
10
|
+
from langroid.exceptions import LangroidImportError
|
10
11
|
from langroid.mytypes import DocMetaData, Document
|
11
12
|
from langroid.utils.configuration import settings
|
12
13
|
from langroid.utils.output.printing import print_long_text
|
@@ -29,14 +30,7 @@ class ChromaDB(VectorStore):
|
|
29
30
|
try:
|
30
31
|
import chromadb
|
31
32
|
except ImportError:
|
32
|
-
raise
|
33
|
-
"""
|
34
|
-
ChromaDB is not installed by default with Langroid.
|
35
|
-
If you want to use it, please install it with the `chromadb` extra, e.g.
|
36
|
-
pip install "langroid[chromadb]"
|
37
|
-
or an equivalent command.
|
38
|
-
"""
|
39
|
-
)
|
33
|
+
raise LangroidImportError("chromadb", "chromadb")
|
40
34
|
self.config = config
|
41
35
|
emb_model = EmbeddingModel.create(config.embedding)
|
42
36
|
self.embedding_fn = emb_model.embedding_fn()
|
langroid/vector_store/lancedb.py
CHANGED
@@ -1,18 +1,31 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import logging
|
2
|
-
from typing import
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Any,
|
7
|
+
Dict,
|
8
|
+
Generator,
|
9
|
+
List,
|
10
|
+
Optional,
|
11
|
+
Sequence,
|
12
|
+
Tuple,
|
13
|
+
Type,
|
14
|
+
)
|
3
15
|
|
4
|
-
import lancedb
|
5
16
|
import pandas as pd
|
6
17
|
from dotenv import load_dotenv
|
7
|
-
from lancedb.pydantic import LanceModel, Vector
|
8
|
-
from lancedb.query import LanceVectorQueryBuilder
|
9
18
|
from pydantic import BaseModel, ValidationError, create_model
|
10
19
|
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
from lancedb.query import LanceVectorQueryBuilder
|
22
|
+
|
11
23
|
from langroid.embedding_models.base import (
|
12
24
|
EmbeddingModel,
|
13
25
|
EmbeddingModelsConfig,
|
14
26
|
)
|
15
27
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
28
|
+
from langroid.exceptions import LangroidImportError
|
16
29
|
from langroid.mytypes import Document, EmbeddingFunction
|
17
30
|
from langroid.utils.configuration import settings
|
18
31
|
from langroid.utils.pydantic_utils import (
|
@@ -26,6 +39,14 @@ from langroid.utils.pydantic_utils import (
|
|
26
39
|
)
|
27
40
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
28
41
|
|
42
|
+
try:
|
43
|
+
import lancedb
|
44
|
+
from lancedb.pydantic import LanceModel, Vector
|
45
|
+
|
46
|
+
has_lancedb = True
|
47
|
+
except ImportError:
|
48
|
+
has_lancedb = False
|
49
|
+
|
29
50
|
logger = logging.getLogger(__name__)
|
30
51
|
|
31
52
|
|
@@ -44,6 +65,9 @@ class LanceDBConfig(VectorStoreConfig):
|
|
44
65
|
class LanceDB(VectorStore):
|
45
66
|
def __init__(self, config: LanceDBConfig = LanceDBConfig()):
|
46
67
|
super().__init__(config)
|
68
|
+
if not has_lancedb:
|
69
|
+
raise LangroidImportError("lancedb", "lancedb")
|
70
|
+
|
47
71
|
self.config: LanceDBConfig = config
|
48
72
|
emb_model = EmbeddingModel.create(config.embedding)
|
49
73
|
self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
|
@@ -170,6 +194,9 @@ class LanceDB(VectorStore):
|
|
170
194
|
if not issubclass(doc_cls, Document):
|
171
195
|
raise ValueError("DocClass must be a subclass of Document")
|
172
196
|
|
197
|
+
if not has_lancedb:
|
198
|
+
raise LangroidImportError("lancedb", "lancedb")
|
199
|
+
|
173
200
|
n = self.embedding_dim
|
174
201
|
|
175
202
|
# Prepare fields for the new model
|
@@ -193,6 +220,8 @@ class LanceDB(VectorStore):
|
|
193
220
|
Flat version of the lance_schema, as nested Pydantic schemas are not yet
|
194
221
|
supported by LanceDB.
|
195
222
|
"""
|
223
|
+
if not has_lancedb:
|
224
|
+
raise LangroidImportError("lancedb", "lancedb")
|
196
225
|
lance_model = self._create_lance_schema(doc_cls)
|
197
226
|
FlatModel = flatten_pydantic_model(lance_model, base_model=LanceModel)
|
198
227
|
return FlatModel
|
@@ -368,7 +397,9 @@ class LanceDB(VectorStore):
|
|
368
397
|
def delete_collection(self, collection_name: str) -> None:
|
369
398
|
self.client.drop_table(collection_name, ignore_missing=True)
|
370
399
|
|
371
|
-
def _lance_result_to_docs(
|
400
|
+
def _lance_result_to_docs(
|
401
|
+
self, result: "LanceVectorQueryBuilder"
|
402
|
+
) -> List[Document]:
|
372
403
|
if self.is_from_dataframe:
|
373
404
|
df = result.to_pandas()
|
374
405
|
return dataframe_to_documents(
|
@@ -7,16 +7,21 @@ Note that what we call "Collection" in Langroid is referred to as
|
|
7
7
|
but for uniformity we use the Langroid terminology here.
|
8
8
|
"""
|
9
9
|
|
10
|
+
from __future__ import annotations
|
11
|
+
|
10
12
|
import asyncio
|
11
13
|
import logging
|
12
14
|
import os
|
13
|
-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
15
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple
|
14
16
|
|
15
|
-
import meilisearch_python_sdk as meilisearch
|
16
17
|
from dotenv import load_dotenv
|
17
|
-
from meilisearch_python_sdk.index import AsyncIndex
|
18
|
-
from meilisearch_python_sdk.models.documents import DocumentsInfo
|
19
18
|
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from meilisearch_python_sdk.index import AsyncIndex
|
21
|
+
from meilisearch_python_sdk.models.documents import DocumentsInfo
|
22
|
+
|
23
|
+
|
24
|
+
from langroid.exceptions import LangroidImportError
|
20
25
|
from langroid.mytypes import DocMetaData, Document
|
21
26
|
from langroid.utils.configuration import settings
|
22
27
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
@@ -34,6 +39,11 @@ class MeiliSearchConfig(VectorStoreConfig):
|
|
34
39
|
class MeiliSearch(VectorStore):
|
35
40
|
def __init__(self, config: MeiliSearchConfig = MeiliSearchConfig()):
|
36
41
|
super().__init__(config)
|
42
|
+
try:
|
43
|
+
import meilisearch_python_sdk as meilisearch
|
44
|
+
except ImportError:
|
45
|
+
raise LangroidImportError("meilisearch", "meilisearch")
|
46
|
+
|
37
47
|
self.config: MeiliSearchConfig = config
|
38
48
|
self.host = config.host
|
39
49
|
self.port = config.port
|
@@ -98,12 +108,12 @@ class MeiliSearch(VectorStore):
|
|
98
108
|
async def _async_get_indexes(self) -> List[AsyncIndex]:
|
99
109
|
async with self.client() as client:
|
100
110
|
indexes = await client.get_indexes(limit=10_000)
|
101
|
-
return [] if indexes is None else indexes
|
111
|
+
return [] if indexes is None else indexes # type: ignore
|
102
112
|
|
103
|
-
async def _async_get_index(self, index_uid: str) -> AsyncIndex:
|
113
|
+
async def _async_get_index(self, index_uid: str) -> "AsyncIndex":
|
104
114
|
async with self.client() as client:
|
105
115
|
index = await client.get_index(index_uid)
|
106
|
-
return index
|
116
|
+
return index # type: ignore
|
107
117
|
|
108
118
|
def list_collections(self, empty: bool = False) -> List[str]:
|
109
119
|
"""
|
@@ -116,7 +126,7 @@ class MeiliSearch(VectorStore):
|
|
116
126
|
else:
|
117
127
|
return [ind.uid for ind in indexes]
|
118
128
|
|
119
|
-
async def _async_create_index(self, collection_name: str) -> AsyncIndex:
|
129
|
+
async def _async_create_index(self, collection_name: str) -> "AsyncIndex":
|
120
130
|
async with self.client() as client:
|
121
131
|
index = await client.create_index(
|
122
132
|
uid=collection_name,
|
@@ -128,7 +138,7 @@ class MeiliSearch(VectorStore):
|
|
128
138
|
"""Delete index if it exists. Returns True iff index was deleted"""
|
129
139
|
async with self.client() as client:
|
130
140
|
result = await client.delete_index_if_exists(uid=collection_name)
|
131
|
-
return result
|
141
|
+
return result # type: ignore
|
132
142
|
|
133
143
|
def create_collection(self, collection_name: str, replace: bool = False) -> None:
|
134
144
|
"""
|
@@ -198,7 +208,7 @@ class MeiliSearch(VectorStore):
|
|
198
208
|
except ValueError:
|
199
209
|
return id
|
200
210
|
|
201
|
-
async def _async_get_documents(self, where: str = "") -> DocumentsInfo:
|
211
|
+
async def _async_get_documents(self, where: str = "") -> "DocumentsInfo":
|
202
212
|
if self.config.collection_name is None:
|
203
213
|
raise ValueError("No collection name set, cannot retrieve docs")
|
204
214
|
filter = [] if where is None else where
|
@@ -258,7 +268,7 @@ class MeiliSearch(VectorStore):
|
|
258
268
|
show_ranking_score=True,
|
259
269
|
filter=filter,
|
260
270
|
)
|
261
|
-
return results.hits
|
271
|
+
return results.hits # type: ignore
|
262
272
|
|
263
273
|
def similar_texts_with_scores(
|
264
274
|
self,
|