langroid 0.1.251__py3-none-any.whl → 0.1.253__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +16 -15
- langroid/agent/__init__.py +1 -0
- langroid/agent/base.py +11 -1
- langroid/agent/callbacks/chainlit.py +5 -12
- langroid/agent/special/__init__.py +3 -2
- langroid/agent/special/doc_chat_agent.py +36 -56
- langroid/agent/special/neo4j/csv_kg_chat.py +2 -2
- langroid/agent/special/sql/__init__.py +1 -2
- langroid/agent/special/sql/sql_chat_agent.py +10 -4
- langroid/agent/special/sql/utils/__init__.py +4 -5
- langroid/agent/special/sql/utils/description_extractors.py +7 -2
- langroid/agent/special/sql/utils/populate_metadata.py +6 -1
- langroid/agent/special/table_chat_agent.py +2 -2
- langroid/agent/task.py +25 -8
- langroid/agent/tool_message.py +14 -3
- langroid/agent/tools/__init__.py +2 -3
- langroid/agent/tools/duckduckgo_search_tool.py +2 -2
- langroid/agent/tools/google_search_tool.py +2 -2
- langroid/agent/tools/metaphor_search_tool.py +2 -2
- langroid/agent/tools/retrieval_tool.py +2 -2
- langroid/agent/tools/run_python_code.py +2 -2
- langroid/agent/tools/segment_extract_tool.py +2 -2
- langroid/cachedb/base.py +10 -2
- langroid/cachedb/momento_cachedb.py +10 -4
- langroid/cachedb/redis_cachedb.py +2 -3
- langroid/embedding_models/__init__.py +1 -0
- langroid/exceptions.py +57 -0
- langroid/language_models/__init__.py +1 -0
- langroid/language_models/base.py +2 -3
- langroid/language_models/openai_gpt.py +15 -14
- langroid/language_models/prompt_formatter/__init__.py +4 -3
- langroid/parsing/document_parser.py +20 -4
- langroid/parsing/parser.pyi +56 -0
- langroid/utils/logging.py +7 -3
- langroid/utils/output/__init__.py +1 -2
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +7 -2
- langroid/vector_store/__init__.py +33 -17
- langroid/vector_store/chromadb.py +2 -8
- langroid/vector_store/lancedb.py +36 -5
- langroid/vector_store/meilisearch.py +21 -11
- langroid/vector_store/momento.py +31 -14
- {langroid-0.1.251.dist-info → langroid-0.1.253.dist-info}/METADATA +31 -29
- {langroid-0.1.251.dist-info → langroid-0.1.253.dist-info}/RECORD +46 -44
- {langroid-0.1.251.dist-info → langroid-0.1.253.dist-info}/LICENSE +0 -0
- {langroid-0.1.251.dist-info → langroid-0.1.253.dist-info}/WHEEL +0 -0
@@ -4,17 +4,23 @@ import os
|
|
4
4
|
from datetime import timedelta
|
5
5
|
from typing import Any, Dict, List
|
6
6
|
|
7
|
-
import
|
7
|
+
from langroid.cachedb.base import CacheDBConfig
|
8
|
+
from langroid.exceptions import LangroidImportError
|
9
|
+
|
10
|
+
try:
|
11
|
+
import momento
|
12
|
+
from momento.responses import CacheGet
|
13
|
+
except ImportError:
|
14
|
+
raise LangroidImportError(package="momento", extra="momento")
|
15
|
+
|
8
16
|
from dotenv import load_dotenv
|
9
|
-
from momento.responses import CacheGet
|
10
|
-
from pydantic import BaseModel
|
11
17
|
|
12
18
|
from langroid.cachedb.base import CacheDB
|
13
19
|
|
14
20
|
logger = logging.getLogger(__name__)
|
15
21
|
|
16
22
|
|
17
|
-
class MomentoCacheConfig(
|
23
|
+
class MomentoCacheConfig(CacheDBConfig):
|
18
24
|
"""Configuration model for Momento Cache."""
|
19
25
|
|
20
26
|
ttl: int = 60 * 60 * 24 * 7 # 1 week
|
@@ -7,15 +7,14 @@ from typing import Any, Dict, List, TypeVar
|
|
7
7
|
import fakeredis
|
8
8
|
import redis
|
9
9
|
from dotenv import load_dotenv
|
10
|
-
from pydantic import BaseModel
|
11
10
|
|
12
|
-
from langroid.cachedb.base import CacheDB
|
11
|
+
from langroid.cachedb.base import CacheDB, CacheDBConfig
|
13
12
|
|
14
13
|
T = TypeVar("T", bound="RedisCache")
|
15
14
|
logger = logging.getLogger(__name__)
|
16
15
|
|
17
16
|
|
18
|
-
class RedisCacheConfig(
|
17
|
+
class RedisCacheConfig(CacheDBConfig):
|
19
18
|
"""Configuration model for RedisCache."""
|
20
19
|
|
21
20
|
fake: bool = False
|
langroid/exceptions.py
CHANGED
@@ -1,3 +1,60 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
|
1
4
|
class InfiniteLoopException(Exception):
    """Exception signalling that an infinite loop was detected.

    The first positional argument is the message (default
    "Infinite loop detected"); any further positional arguments are passed
    through to ``Exception`` unchanged.
    """

    def __init__(
        self,
        message: str = "Infinite loop detected",
        *args: object,
    ) -> None:
        exc_args = (message, *args)
        super().__init__(*exc_args)
|
7
|
+
|
8
|
+
|
9
|
+
class LangroidImportError(ImportError):
    """ImportError for optional dependencies, carrying install instructions.

    The message states which package is missing and how to install it,
    either through a langroid "extra" or directly into the same environment.
    The missing package name is also exposed as the standard
    ``ImportError.name`` attribute.
    """

    def __init__(
        self,
        package: Optional[str] = None,
        extra: Optional[str] = None,
        error: str = "",
        *args: object,
    ) -> None:
        """
        Generate a helpful error when attempting to import a package or module.

        Args:
            package (str): The name of the package to import. Also stored as
                the standard ``ImportError.name`` attribute.
            extra (str): The name of the langroid extra required for this
                import. If omitted, generic install advice is given instead.
            error (str): The error message to display. Depending on context,
                we can set this by capturing the ImportError message. If empty
                and `package` is given, a default message is generated.
            *args: Additional positional arguments forwarded to ImportError.
        """
        import textwrap

        if error == "" and package is not None:
            error = f"{package} is not installed by default with Langroid.\n"

        # dedent() so the source-code indentation of these literals does not
        # leak into the user-facing message.
        if extra:
            install_help = textwrap.dedent(
                f"""
                If you want to use it, please install langroid
                with the `{extra}` extra, for example:

                If you are using pip:
                pip install "langroid[{extra}]"

                For multiple extras, you can separate them with commas:
                pip install "langroid[{extra},another-extra]"

                If you are using Poetry:
                poetry add langroid --extras "{extra}"

                For multiple extras with Poetry, list them with spaces:
                poetry add langroid --extras "{extra} another-extra"

                If you are working within the langroid dev env (which uses Poetry),
                you can do:
                poetry install -E "{extra}"
                or if you want to include multiple extras:
                poetry install -E "{extra} another-extra"
                """
            )
        else:
            install_help = textwrap.dedent(
                """
                If you want to use it, please install it in the same
                virtual environment as langroid.
                """
            )
        msg = error + install_help

        # `name` is the standard ImportError attribute identifying the module
        # that failed to import, so handlers can inspect `exc.name`.
        super().__init__(msg, *args, name=package)
|
langroid/language_models/base.py
CHANGED
@@ -10,8 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
|
10
10
|
import aiohttp
|
11
11
|
from pydantic import BaseModel, BaseSettings, Field
|
12
12
|
|
13
|
-
from langroid.cachedb.
|
14
|
-
from langroid.cachedb.redis_cachedb import RedisCacheConfig
|
13
|
+
from langroid.cachedb.base import CacheDBConfig
|
15
14
|
from langroid.mytypes import Document
|
16
15
|
from langroid.parsing.agent_chats import parse_message
|
17
16
|
from langroid.parsing.parse_json import top_level_json_field
|
@@ -49,7 +48,7 @@ class LLMConfig(BaseSettings):
|
|
49
48
|
# use chat model for completion? For OpenAI models, this MUST be set to True!
|
50
49
|
use_chat_for_completion: bool = True
|
51
50
|
stream: bool = True # stream output from API?
|
52
|
-
cache_config: None |
|
51
|
+
cache_config: None | CacheDBConfig = None
|
53
52
|
|
54
53
|
# Dict of model -> (input/prompt cost, output/completion cost)
|
55
54
|
chat_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
|
@@ -28,8 +28,9 @@ from pydantic import BaseModel
|
|
28
28
|
from rich import print
|
29
29
|
from rich.markup import escape
|
30
30
|
|
31
|
-
from langroid.cachedb.
|
31
|
+
from langroid.cachedb.base import CacheDB
|
32
32
|
from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
|
33
|
+
from langroid.exceptions import LangroidImportError
|
33
34
|
from langroid.language_models.base import (
|
34
35
|
LanguageModel,
|
35
36
|
LLMConfig,
|
@@ -280,14 +281,7 @@ class OpenAIGPTConfig(LLMConfig):
|
|
280
281
|
try:
|
281
282
|
import litellm
|
282
283
|
except ImportError:
|
283
|
-
raise
|
284
|
-
"""
|
285
|
-
litellm not installed. Please install it via:
|
286
|
-
pip install litellm.
|
287
|
-
Or when installing langroid, install it with the `litellm` extra:
|
288
|
-
pip install langroid[litellm]
|
289
|
-
"""
|
290
|
-
)
|
284
|
+
raise LangroidImportError("litellm", "litellm")
|
291
285
|
litellm.telemetry = False
|
292
286
|
litellm.drop_params = True # drop un-supported params without crashing
|
293
287
|
self.seed = None # some local mdls don't support seed
|
@@ -482,17 +476,24 @@ class OpenAIGPT(LanguageModel):
|
|
482
476
|
timeout=Timeout(self.config.timeout),
|
483
477
|
)
|
484
478
|
|
485
|
-
self.cache:
|
479
|
+
self.cache: CacheDB
|
486
480
|
if settings.cache_type == "momento":
|
487
|
-
|
488
|
-
|
481
|
+
from langroid.cachedb.momento_cachedb import (
|
482
|
+
MomentoCache,
|
483
|
+
MomentoCacheConfig,
|
484
|
+
)
|
485
|
+
|
486
|
+
if config.cache_config is None or not isinstance(
|
487
|
+
config.cache_config,
|
488
|
+
MomentoCacheConfig,
|
489
489
|
):
|
490
490
|
# switch to fresh momento config if needed
|
491
491
|
config.cache_config = MomentoCacheConfig()
|
492
492
|
self.cache = MomentoCache(config.cache_config)
|
493
493
|
elif "redis" in settings.cache_type:
|
494
|
-
if config.cache_config is None or isinstance(
|
495
|
-
config.cache_config,
|
494
|
+
if config.cache_config is None or not isinstance(
|
495
|
+
config.cache_config,
|
496
|
+
RedisCacheConfig,
|
496
497
|
):
|
497
498
|
# switch to fresh redis config if needed
|
498
499
|
config.cache_config = RedisCacheConfig(
|
@@ -1,9 +1,10 @@
|
|
1
|
+
from . import base
|
2
|
+
from . import llama2_formatter
|
1
3
|
from .base import PromptFormatter
|
2
4
|
from .llama2_formatter import Llama2Formatter
|
3
|
-
from ..config import PromptFormatterConfig
|
5
|
+
from ..config import PromptFormatterConfig
|
6
|
+
from ..config import Llama2FormatterConfig
|
4
7
|
|
5
|
-
from . import base
|
6
|
-
from . import llama2_formatter
|
7
8
|
|
8
9
|
__all__ = [
|
9
10
|
"PromptFormatter",
|
@@ -5,9 +5,19 @@ from enum import Enum
|
|
5
5
|
from io import BytesIO
|
6
6
|
from typing import Any, Generator, List, Tuple
|
7
7
|
|
8
|
-
import
|
8
|
+
from langroid.exceptions import LangroidImportError
|
9
|
+
|
10
|
+
try:
|
11
|
+
import fitz
|
12
|
+
except ImportError:
|
13
|
+
raise LangroidImportError("PyMuPDF", "pdf-parsers")
|
14
|
+
|
15
|
+
try:
|
16
|
+
import pypdf
|
17
|
+
except ImportError:
|
18
|
+
raise LangroidImportError("pypdf", "pdf-parsers")
|
19
|
+
|
9
20
|
import pdfplumber
|
10
|
-
import pypdf
|
11
21
|
import requests
|
12
22
|
from bs4 import BeautifulSoup
|
13
23
|
from PIL import Image
|
@@ -456,7 +466,10 @@ class ImagePdfParser(DocumentParser):
|
|
456
466
|
def iterate_pages(
|
457
467
|
self,
|
458
468
|
) -> Generator[Tuple[int, "Image"], None, None]: # type: ignore
|
459
|
-
|
469
|
+
try:
|
470
|
+
from pdf2image import convert_from_bytes
|
471
|
+
except ImportError:
|
472
|
+
raise LangroidImportError("pdf2image", "pdf-parsers")
|
460
473
|
|
461
474
|
images = convert_from_bytes(self.doc_bytes.getvalue())
|
462
475
|
for i, image in enumerate(images):
|
@@ -472,7 +485,10 @@ class ImagePdfParser(DocumentParser):
|
|
472
485
|
Returns:
|
473
486
|
str: Extracted text from the image.
|
474
487
|
"""
|
475
|
-
|
488
|
+
try:
|
489
|
+
import pytesseract
|
490
|
+
except ImportError:
|
491
|
+
raise LangroidImportError("pytesseract", "pdf-parsers")
|
476
492
|
|
477
493
|
text = pytesseract.image_to_string(page)
|
478
494
|
return self.fix_text(text)
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Literal
|
3
|
+
|
4
|
+
from _typeshed import Incomplete
|
5
|
+
from pydantic import BaseSettings
|
6
|
+
|
7
|
+
from langroid.mytypes import Document as Document
|
8
|
+
from langroid.parsing.para_sentence_split import (
|
9
|
+
create_chunks as create_chunks,
|
10
|
+
)
|
11
|
+
from langroid.parsing.para_sentence_split import (
|
12
|
+
remove_extra_whitespace as remove_extra_whitespace,
|
13
|
+
)
|
14
|
+
|
15
|
+
# Module-level logger; its concrete type is left as `Incomplete` in this stub.
logger: Incomplete
|
16
|
+
|
17
|
+
class Splitter(str, Enum):
    # Text-splitting strategies.
    # Per the typing spec, an annotation without an assigned value inside an
    # Enum body declares a NON-member attribute, so `TOKENS: str` would give
    # this stub no members at all. In stub files `...` is the sanctioned
    # placeholder for a member whose runtime value is not spelled out here.
    TOKENS = ...
    PARA_SENTENCE = ...
    SIMPLE = ...
|
21
|
+
|
22
|
+
class PdfParsingConfig(BaseSettings):
    # Stub declaration: `library` selects the PDF backend, restricted to the
    # literal choices below.
    # NOTE(review): declared without a default here; if the runtime class
    # defines one, add `= ...` so type checkers do not treat it as a required
    # constructor argument -- confirm against parser.py.
    library: Literal["fitz", "pdfplumber", "pypdf", "unstructured", "pdf2image"]
|
24
|
+
|
25
|
+
class DocxParsingConfig(BaseSettings):
    # Stub declaration: `library` selects the .docx backend from the literal
    # choices below.
    # NOTE(review): no default declared; add `= ...` if the runtime class has
    # one -- confirm against parser.py.
    library: Literal["python-docx", "unstructured"]
|
27
|
+
|
28
|
+
class DocParsingConfig(BaseSettings):
    # Stub declaration: `library` selects the .doc backend; only
    # "unstructured" is allowed.
    # NOTE(review): no default declared; add `= ...` if the runtime class has
    # one -- confirm against parser.py.
    library: Literal["unstructured"]
|
30
|
+
|
31
|
+
class ParsingConfig(BaseSettings):
    # Stub declaration of the document parsing / chunking configuration.
    # NOTE(review): no field below has a default in this stub. If the runtime
    # class in parser.py defines defaults, declare them here as `= ...`;
    # otherwise type checkers that model pydantic constructors may report
    # `ParsingConfig()` as missing required arguments. Confirm against parser.py.
    # Typed as plain `str`; presumably holds a `Splitter` value -- verify.
    splitter: str
    chunk_size: int
    overlap: int
    max_chunks: int
    min_chunk_chars: int
    discard_chunk_chars: int
    n_similar_docs: int
    n_neighbor_ids: int
    separators: list[str]
    token_encoding_model: str
    # Per-format parser sub-configurations.
    pdf: PdfParsingConfig
    docx: DocxParsingConfig
    doc: DocParsingConfig
|
45
|
+
|
46
|
+
class Parser:
    # Stub for the text parser/splitter. Attribute types are left as
    # `Incomplete` in this stub.
    config: Incomplete
    tokenizer: Incomplete
    def __init__(self, config: ParsingConfig) -> None: ...
    # Returns an int token count for `text` (counting scheme lives in parser.py).
    def num_tokens(self, text: str) -> int: ...
    # Returns None; presumably annotates `chunks` in place -- confirm in parser.py.
    def add_window_ids(self, chunks: list[Document]) -> None: ...
    # The split_* methods each map a list of Documents to a list of Documents,
    # presumably one per splitting strategy -- see parser.py for details.
    def split_simple(self, docs: list[Document]) -> list[Document]: ...
    def split_para_sentence(self, docs: list[Document]) -> list[Document]: ...
    def split_chunk_tokens(self, docs: list[Document]) -> list[Document]: ...
    # Splits a single string into a list of string chunks.
    def chunk_tokens(self, text: str) -> list[str]: ...
    def split(self, docs: list[Document]) -> list[Document]: ...
|
langroid/utils/logging.py
CHANGED
@@ -31,7 +31,11 @@ def setup_colored_logging() -> None:
|
|
31
31
|
# logger.setLevel(logging.DEBUG)
|
32
32
|
|
33
33
|
|
34
|
-
def setup_logger(
|
34
|
+
def setup_logger(
|
35
|
+
name: str,
|
36
|
+
level: int = logging.INFO,
|
37
|
+
terminal: bool = False,
|
38
|
+
) -> logging.Logger:
|
35
39
|
"""
|
36
40
|
Set up a logger of module `name` at a desired level.
|
37
41
|
Args:
|
@@ -42,7 +46,7 @@ def setup_logger(name: str, level: int = logging.INFO) -> logging.Logger:
|
|
42
46
|
"""
|
43
47
|
logger = logging.getLogger(name)
|
44
48
|
logger.setLevel(level)
|
45
|
-
if not logger.hasHandlers():
|
49
|
+
if not logger.hasHandlers() and terminal:
|
46
50
|
handler = logging.StreamHandler()
|
47
51
|
formatter = logging.Formatter(
|
48
52
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
@@ -73,7 +77,7 @@ def setup_file_logger(
|
|
73
77
|
) -> logging.Logger:
|
74
78
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
75
79
|
file_mode = "a" if append else "w"
|
76
|
-
logger = setup_logger(name)
|
80
|
+
logger = setup_logger(name, terminal=False)
|
77
81
|
handler = logging.FileHandler(filename, mode=file_mode)
|
78
82
|
handler.setLevel(logging.INFO)
|
79
83
|
if log_format:
|
@@ -1,5 +1,4 @@
|
|
1
1
|
from . import printing
|
2
|
-
|
3
2
|
from .printing import (
|
4
3
|
shorten_text,
|
5
4
|
print_long_text,
|
@@ -7,9 +6,9 @@ from .printing import (
|
|
7
6
|
SuppressLoggerWarnings,
|
8
7
|
PrintColored,
|
9
8
|
)
|
10
|
-
|
11
9
|
from .status import status
|
12
10
|
|
11
|
+
|
13
12
|
__all__ = [
|
14
13
|
"printing",
|
15
14
|
"shorten_text",
|
@@ -0,0 +1,41 @@
|
|
1
|
+
def extract_markdown_references(md_string: str) -> list[int]:
    """
    Collect the numeric markdown footnote references (e.g. [^1], [^2]) found
    in a string.

    Args:
        md_string (str): The markdown text to scan.

    Returns:
        list[int]: The distinct reference numbers, in ascending order.
    """
    import re

    # Each match is the run of digits captured between "[^" and "]";
    # collecting into a set drops repeated references.
    numbers = {int(digits) for digits in re.findall(r"\[\^(\d+)\]", md_string)}
    return sorted(numbers)
|
18
|
+
|
19
|
+
|
20
|
+
def format_footnote_text(content: str, width: int = 80) -> str:
    """
    Formats the content part of a footnote (i.e. not the first line that
    appears right after the reference [^4]).

    The text is wrapped so that no line -- indentation included -- is longer
    than `width`, and every line is indented so markdown treats it as a
    continuation of the footnote.

    Args:
        content (str): The text of the footnote to be formatted.
        width (int): Maximum width of the text lines, indentation included.

    Returns:
        str: Properly formatted markdown footnote text, or "" when `content`
            is empty or whitespace-only.
    """
    import textwrap

    # Markdown footnote continuation lines are indented by 4 spaces.
    indent = "    "
    # Let textwrap add the indent itself so it is counted against `width`;
    # wrapping first and prefixing afterwards overshot the documented maximum
    # line width by len(indent).
    wrapped_lines = textwrap.wrap(
        content,
        width,
        initial_indent=indent,
        subsequent_indent=indent,
    )
    if not wrapped_lines:
        return ""
    return "\n".join(wrapped_lines)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
import sys
|
3
3
|
from contextlib import contextmanager
|
4
|
-
from typing import Any, Iterator, Optional
|
4
|
+
from typing import Any, Iterator, Optional, Type
|
5
5
|
|
6
6
|
from rich import print as rprint
|
7
7
|
from rich.text import Text
|
@@ -89,6 +89,11 @@ class SuppressLoggerWarnings:
|
|
89
89
|
# Set the logging level to 'ERROR' to suppress warnings
|
90
90
|
self.logger.setLevel(logging.ERROR)
|
91
91
|
|
92
|
-
def __exit__(
|
92
|
+
def __exit__(
|
93
|
+
self,
|
94
|
+
exc_type: Optional[Type[BaseException]],
|
95
|
+
exc_value: Optional[BaseException],
|
96
|
+
traceback: Any,
|
97
|
+
) -> None:
|
93
98
|
# Reset the logging level to its original value
|
94
99
|
self.logger.setLevel(self.original_level)
|
@@ -1,25 +1,9 @@
|
|
1
1
|
from . import base
|
2
2
|
|
3
3
|
from . import qdrantdb
|
4
|
-
from . import meilisearch
|
5
|
-
from . import lancedb
|
6
4
|
|
7
5
|
from .base import VectorStoreConfig, VectorStore
|
8
6
|
from .qdrantdb import QdrantDBConfig, QdrantDB
|
9
|
-
from .meilisearch import MeiliSearch, MeiliSearchConfig
|
10
|
-
from .lancedb import LanceDB, LanceDBConfig
|
11
|
-
|
12
|
-
has_chromadb = False
|
13
|
-
try:
|
14
|
-
from . import chromadb
|
15
|
-
from .chromadb import ChromaDBConfig, ChromaDB
|
16
|
-
|
17
|
-
chromadb # silence linters
|
18
|
-
ChromaDB
|
19
|
-
ChromaDBConfig
|
20
|
-
has_chromadb = True
|
21
|
-
except ImportError:
|
22
|
-
pass
|
23
7
|
|
24
8
|
__all__ = [
|
25
9
|
"base",
|
@@ -36,5 +20,37 @@ __all__ = [
|
|
36
20
|
"LanceDBConfig",
|
37
21
|
]
|
38
22
|
|
39
|
-
|
23
|
+
|
24
|
+
try:
|
25
|
+
from . import meilisearch
|
26
|
+
from .meilisearch import MeiliSearch, MeiliSearchConfig
|
27
|
+
|
28
|
+
meilisearch
|
29
|
+
MeiliSearch
|
30
|
+
MeiliSearchConfig
|
31
|
+
__all__.extend(["meilisearch", "MeiliSearch", "MeiliSearchConfig"])
|
32
|
+
except ImportError:
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
try:
|
37
|
+
from . import lancedb
|
38
|
+
from .lancedb import LanceDB, LanceDBConfig
|
39
|
+
|
40
|
+
lancedb
|
41
|
+
LanceDB
|
42
|
+
LanceDBConfig
|
43
|
+
__all__.extend(["lancedb", "LanceDB", "LanceDBConfig"])
|
44
|
+
except ImportError:
|
45
|
+
pass
|
46
|
+
|
47
|
+
try:
|
48
|
+
from . import chromadb
|
49
|
+
from .chromadb import ChromaDBConfig, ChromaDB
|
50
|
+
|
51
|
+
chromadb # silence linters
|
52
|
+
ChromaDB
|
53
|
+
ChromaDBConfig
|
40
54
|
__all__.extend(["chromadb", "ChromaDBConfig", "ChromaDB"])
|
55
|
+
except ImportError:
|
56
|
+
pass
|
@@ -7,6 +7,7 @@ from langroid.embedding_models.base import (
|
|
7
7
|
EmbeddingModelsConfig,
|
8
8
|
)
|
9
9
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
10
|
+
from langroid.exceptions import LangroidImportError
|
10
11
|
from langroid.mytypes import DocMetaData, Document
|
11
12
|
from langroid.utils.configuration import settings
|
12
13
|
from langroid.utils.output.printing import print_long_text
|
@@ -29,14 +30,7 @@ class ChromaDB(VectorStore):
|
|
29
30
|
try:
|
30
31
|
import chromadb
|
31
32
|
except ImportError:
|
32
|
-
raise
|
33
|
-
"""
|
34
|
-
ChromaDB is not installed by default with Langroid.
|
35
|
-
If you want to use it, please install it with the `chromadb` extra, e.g.
|
36
|
-
pip install "langroid[chromadb]"
|
37
|
-
or an equivalent command.
|
38
|
-
"""
|
39
|
-
)
|
33
|
+
raise LangroidImportError("chromadb", "chromadb")
|
40
34
|
self.config = config
|
41
35
|
emb_model = EmbeddingModel.create(config.embedding)
|
42
36
|
self.embedding_fn = emb_model.embedding_fn()
|
langroid/vector_store/lancedb.py
CHANGED
@@ -1,18 +1,31 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import logging
|
2
|
-
from typing import
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Any,
|
7
|
+
Dict,
|
8
|
+
Generator,
|
9
|
+
List,
|
10
|
+
Optional,
|
11
|
+
Sequence,
|
12
|
+
Tuple,
|
13
|
+
Type,
|
14
|
+
)
|
3
15
|
|
4
|
-
import lancedb
|
5
16
|
import pandas as pd
|
6
17
|
from dotenv import load_dotenv
|
7
|
-
from lancedb.pydantic import LanceModel, Vector
|
8
|
-
from lancedb.query import LanceVectorQueryBuilder
|
9
18
|
from pydantic import BaseModel, ValidationError, create_model
|
10
19
|
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
from lancedb.query import LanceVectorQueryBuilder
|
22
|
+
|
11
23
|
from langroid.embedding_models.base import (
|
12
24
|
EmbeddingModel,
|
13
25
|
EmbeddingModelsConfig,
|
14
26
|
)
|
15
27
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
28
|
+
from langroid.exceptions import LangroidImportError
|
16
29
|
from langroid.mytypes import Document, EmbeddingFunction
|
17
30
|
from langroid.utils.configuration import settings
|
18
31
|
from langroid.utils.pydantic_utils import (
|
@@ -26,6 +39,14 @@ from langroid.utils.pydantic_utils import (
|
|
26
39
|
)
|
27
40
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
28
41
|
|
42
|
+
try:
|
43
|
+
import lancedb
|
44
|
+
from lancedb.pydantic import LanceModel, Vector
|
45
|
+
|
46
|
+
has_lancedb = True
|
47
|
+
except ImportError:
|
48
|
+
has_lancedb = False
|
49
|
+
|
29
50
|
logger = logging.getLogger(__name__)
|
30
51
|
|
31
52
|
|
@@ -44,6 +65,9 @@ class LanceDBConfig(VectorStoreConfig):
|
|
44
65
|
class LanceDB(VectorStore):
|
45
66
|
def __init__(self, config: LanceDBConfig = LanceDBConfig()):
|
46
67
|
super().__init__(config)
|
68
|
+
if not has_lancedb:
|
69
|
+
raise LangroidImportError("lancedb", "lancedb")
|
70
|
+
|
47
71
|
self.config: LanceDBConfig = config
|
48
72
|
emb_model = EmbeddingModel.create(config.embedding)
|
49
73
|
self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
|
@@ -170,6 +194,9 @@ class LanceDB(VectorStore):
|
|
170
194
|
if not issubclass(doc_cls, Document):
|
171
195
|
raise ValueError("DocClass must be a subclass of Document")
|
172
196
|
|
197
|
+
if not has_lancedb:
|
198
|
+
raise LangroidImportError("lancedb", "lancedb")
|
199
|
+
|
173
200
|
n = self.embedding_dim
|
174
201
|
|
175
202
|
# Prepare fields for the new model
|
@@ -193,6 +220,8 @@ class LanceDB(VectorStore):
|
|
193
220
|
Flat version of the lance_schema, as nested Pydantic schemas are not yet
|
194
221
|
supported by LanceDB.
|
195
222
|
"""
|
223
|
+
if not has_lancedb:
|
224
|
+
raise LangroidImportError("lancedb", "lancedb")
|
196
225
|
lance_model = self._create_lance_schema(doc_cls)
|
197
226
|
FlatModel = flatten_pydantic_model(lance_model, base_model=LanceModel)
|
198
227
|
return FlatModel
|
@@ -368,7 +397,9 @@ class LanceDB(VectorStore):
|
|
368
397
|
def delete_collection(self, collection_name: str) -> None:
|
369
398
|
self.client.drop_table(collection_name, ignore_missing=True)
|
370
399
|
|
371
|
-
def _lance_result_to_docs(
|
400
|
+
def _lance_result_to_docs(
|
401
|
+
self, result: "LanceVectorQueryBuilder"
|
402
|
+
) -> List[Document]:
|
372
403
|
if self.is_from_dataframe:
|
373
404
|
df = result.to_pandas()
|
374
405
|
return dataframe_to_documents(
|