langroid 0.1.252__py3-none-any.whl → 0.1.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. langroid/__init__.py +16 -15
  2. langroid/agent/__init__.py +1 -0
  3. langroid/agent/callbacks/chainlit.py +5 -12
  4. langroid/agent/special/__init__.py +3 -2
  5. langroid/agent/special/doc_chat_agent.py +36 -56
  6. langroid/agent/special/neo4j/csv_kg_chat.py +2 -2
  7. langroid/agent/special/sql/__init__.py +1 -2
  8. langroid/agent/special/sql/sql_chat_agent.py +10 -4
  9. langroid/agent/special/sql/utils/__init__.py +4 -5
  10. langroid/agent/special/sql/utils/description_extractors.py +7 -2
  11. langroid/agent/special/sql/utils/populate_metadata.py +6 -1
  12. langroid/agent/special/table_chat_agent.py +2 -2
  13. langroid/agent/tool_message.py +14 -3
  14. langroid/agent/tools/__init__.py +2 -3
  15. langroid/agent/tools/duckduckgo_search_tool.py +2 -2
  16. langroid/agent/tools/google_search_tool.py +2 -2
  17. langroid/agent/tools/metaphor_search_tool.py +2 -2
  18. langroid/agent/tools/retrieval_tool.py +2 -2
  19. langroid/agent/tools/run_python_code.py +2 -2
  20. langroid/agent/tools/segment_extract_tool.py +2 -2
  21. langroid/cachedb/base.py +10 -2
  22. langroid/cachedb/momento_cachedb.py +10 -4
  23. langroid/cachedb/redis_cachedb.py +2 -3
  24. langroid/embedding_models/__init__.py +1 -0
  25. langroid/exceptions.py +57 -0
  26. langroid/language_models/__init__.py +1 -0
  27. langroid/language_models/base.py +2 -3
  28. langroid/language_models/openai_gpt.py +15 -14
  29. langroid/language_models/prompt_formatter/__init__.py +4 -3
  30. langroid/parsing/document_parser.py +20 -4
  31. langroid/parsing/parser.pyi +56 -0
  32. langroid/utils/logging.py +7 -3
  33. langroid/utils/output/__init__.py +1 -2
  34. langroid/utils/output/citations.py +41 -0
  35. langroid/utils/output/printing.py +7 -2
  36. langroid/vector_store/__init__.py +33 -17
  37. langroid/vector_store/chromadb.py +2 -8
  38. langroid/vector_store/lancedb.py +36 -5
  39. langroid/vector_store/meilisearch.py +21 -11
  40. langroid/vector_store/momento.py +31 -14
  41. {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/METADATA +31 -29
  42. {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/RECORD +44 -42
  43. {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/LICENSE +0 -0
  44. {langroid-0.1.252.dist-info → langroid-0.1.253.dist-info}/WHEEL +0 -0
langroid/__init__.py CHANGED
@@ -43,20 +43,6 @@ from .agent.chat_agent import (
43
43
 
44
44
  from .agent.task import Task, TaskConfig
45
45
 
46
- try:
47
- from .agent.callbacks.chainlit import (
48
- ChainlitAgentCallbacks,
49
- ChainlitTaskCallbacks,
50
- ChainlitCallbackConfig,
51
- )
52
-
53
- chainlit_available = True
54
- ChainlitAgentCallbacks
55
- ChainlitTaskCallbacks
56
- ChainlitCallbackConfig
57
- except ImportError:
58
- chainlit_available = False
59
-
60
46
 
61
47
  from .mytypes import (
62
48
  DocMetaData,
@@ -65,6 +51,7 @@ from .mytypes import (
65
51
  )
66
52
 
67
53
  from .exceptions import InfiniteLoopException
54
+ from .exceptions import LangroidImportError
68
55
 
69
56
  __all__ = [
70
57
  "mytypes",
@@ -94,8 +81,20 @@ __all__ = [
94
81
  "llm_response_batch",
95
82
  "agent_response_batch",
96
83
  "InfiniteLoopException",
84
+ "LangroidImportError",
97
85
  ]
98
- if chainlit_available:
86
+
87
+
88
+ try:
89
+ from .agent.callbacks.chainlit import (
90
+ ChainlitAgentCallbacks,
91
+ ChainlitTaskCallbacks,
92
+ ChainlitCallbackConfig,
93
+ )
94
+
95
+ ChainlitAgentCallbacks
96
+ ChainlitTaskCallbacks
97
+ ChainlitCallbackConfig
99
98
  __all__.extend(
100
99
  [
101
100
  "ChainlitAgentCallbacks",
@@ -103,3 +102,5 @@ if chainlit_available:
103
102
  "ChainlitCallbackConfig",
104
103
  ]
105
104
  )
105
+ except ImportError:
106
+ pass
@@ -18,6 +18,7 @@ from . import tool_message
18
18
  from . import tools
19
19
  from . import special
20
20
 
21
+
21
22
  __all__ = [
22
23
  "Agent",
23
24
  "AgentConfig",
@@ -9,19 +9,12 @@ from typing import Any, Callable, Dict, List, Literal, Optional, no_type_check
9
9
 
10
10
  from pydantic import BaseSettings
11
11
 
12
+ from langroid.exceptions import LangroidImportError
13
+
12
14
  try:
13
15
  import chainlit as cl
14
16
  except ImportError:
15
- raise ImportError(
16
- """
17
- You are attempting to use `chainlit`, which is not installed
18
- by default with `langroid`.
19
- Please install langroid with the `chainlit` extra using:
20
- `pip install langroid[chainlit]` or
21
- `poetry install -E chainlit`
22
- depending on your scenario
23
- """
24
- )
17
+ raise LangroidImportError("chainlit", "chainlit")
25
18
 
26
19
  from chainlit import run_sync
27
20
  from chainlit.config import config
@@ -83,9 +76,9 @@ async def setup_llm() -> None:
83
76
 
84
77
 
85
78
  @no_type_check
86
- async def update_llm(settings: Dict[str, Any]) -> None:
79
+ async def update_llm(new_settings: Dict[str, Any]) -> None:
87
80
  """Update LLMConfig and LLM from settings, and save in session state."""
88
- cl.user_session.set("llm_settings", settings)
81
+ cl.user_session.set("llm_settings", new_settings)
89
82
  await inform_llm_settings()
90
83
  await setup_llm()
91
84
 
@@ -25,6 +25,7 @@ from . import lance_doc_chat_agent
25
25
  from . import lance_rag
26
26
  from . import table_chat_agent
27
27
 
28
+
28
29
  __all__ = [
29
30
  "RelevanceExtractorAgent",
30
31
  "RelevanceExtractorAgentConfig",
@@ -34,7 +35,6 @@ __all__ = [
34
35
  "RecordDoc",
35
36
  "RetrieverAgentConfig",
36
37
  "RetrieverAgent",
37
- "LanceDocChatAgent",
38
38
  "dataframe_summary",
39
39
  "TableChatAgent",
40
40
  "TableChatAgentConfig",
@@ -43,8 +43,9 @@ __all__ = [
43
43
  "relevance_extractor_agent",
44
44
  "doc_chat_agent",
45
45
  "retriever_agent",
46
+ "table_chat_agent",
47
+ "LanceDocChatAgent",
46
48
  "lance_tools",
47
49
  "lance_doc_chat_agent",
48
50
  "lance_rag",
49
- "table_chat_agent",
50
51
  ]
@@ -31,7 +31,10 @@ from langroid.agent.special.relevance_extractor_agent import (
31
31
  )
32
32
  from langroid.agent.task import Task
33
33
  from langroid.agent.tools.retrieval_tool import RetrievalTool
34
- from langroid.embedding_models.models import OpenAIEmbeddingsConfig
34
+ from langroid.embedding_models.models import (
35
+ OpenAIEmbeddingsConfig,
36
+ SentenceTransformerEmbeddingsConfig,
37
+ )
35
38
  from langroid.language_models.base import StreamingIfAllowed
36
39
  from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
37
40
  from langroid.mytypes import DocMetaData, Document, Entity
@@ -51,9 +54,13 @@ from langroid.prompts.prompts_config import PromptsConfig
51
54
  from langroid.prompts.templates import SUMMARY_ANSWER_PROMPT_GPT4
52
55
  from langroid.utils.constants import NO_ANSWER
53
56
  from langroid.utils.output import show_if_debug, status
57
+ from langroid.utils.output.citations import (
58
+ extract_markdown_references,
59
+ format_footnote_text,
60
+ )
54
61
  from langroid.utils.pydantic_utils import dataframe_to_documents, extract_fields
55
62
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
56
- from langroid.vector_store.lancedb import LanceDBConfig
63
+ from langroid.vector_store.qdrantdb import QdrantDBConfig
57
64
 
58
65
 
59
66
  @cache
@@ -82,47 +89,36 @@ except ImportError:
82
89
  pass
83
90
 
84
91
 
85
- def extract_markdown_references(md_string: str) -> list[int]:
86
- """
87
- Extracts markdown references (e.g., [^1], [^2]) from a string and returns
88
- them as a sorted list of integers.
89
-
90
- Args:
91
- md_string (str): The markdown string containing references.
92
-
93
- Returns:
94
- list[int]: A sorted list of unique integers from the markdown references.
95
- """
96
- import re
97
-
98
- # Regex to find all occurrences of [^<number>]
99
- matches = re.findall(r"\[\^(\d+)\]", md_string)
100
- # Convert matches to integers, remove duplicates with set, and sort
101
- return sorted(set(int(match) for match in matches))
92
+ hf_embed_config = SentenceTransformerEmbeddingsConfig(
93
+ model_type="sentence-transformer",
94
+ model_name="BAAI/bge-large-en-v1.5",
95
+ )
102
96
 
97
+ oai_embed_config = OpenAIEmbeddingsConfig(
98
+ model_type="openai",
99
+ model_name="text-embedding-ada-002",
100
+ dims=1536,
101
+ )
103
102
 
104
- def format_footnote_text(content: str, width: int = 80) -> str:
105
- """
106
- Formats the content part of a footnote (i.e. not the first line that
107
- appears right after the reference [^4])
108
- It wraps the text so that no line is longer than the specified width and indents
109
- lines as necessary for markdown footnotes.
103
+ vecdb_config: VectorStoreConfig = QdrantDBConfig(
104
+ collection_name="doc-chat-qdrantdb",
105
+ replace_collection=True,
106
+ storage_path=".qdrantdb/data/",
107
+ embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
108
+ )
110
109
 
111
- Args:
112
- content (str): The text of the footnote to be formatted.
113
- width (int): Maximum width of the text lines.
110
+ try:
111
+ from langroid.vector_store.lancedb import LanceDBConfig
114
112
 
115
- Returns:
116
- str: Properly formatted markdown footnote text.
117
- """
118
- import textwrap
113
+ vecdb_config = LanceDBConfig(
114
+ collection_name="doc-chat-lancedb",
115
+ replace_collection=True,
116
+ storage_path=".lancedb/data/",
117
+ embedding=(hf_embed_config if has_sentence_transformers else oai_embed_config),
118
+ )
119
119
 
120
- # Wrap the text to the specified width
121
- wrapped_lines = textwrap.wrap(content, width)
122
- if len(wrapped_lines) == 0:
123
- return ""
124
- indent = " " # Indentation for markdown footnotes
125
- return indent + ("\n" + indent).join(wrapped_lines)
120
+ except ImportError:
121
+ pass
126
122
 
127
123
 
128
124
  class DocChatAgentConfig(ChatAgentConfig):
@@ -199,26 +195,10 @@ class DocChatAgentConfig(ChatAgentConfig):
199
195
  library="pdfplumber",
200
196
  ),
201
197
  )
202
- from langroid.embedding_models.models import SentenceTransformerEmbeddingsConfig
203
-
204
- hf_embed_config = SentenceTransformerEmbeddingsConfig(
205
- model_type="sentence-transformer",
206
- model_name="BAAI/bge-large-en-v1.5",
207
- )
208
-
209
- oai_embed_config = OpenAIEmbeddingsConfig(
210
- model_type="openai",
211
- model_name="text-embedding-ada-002",
212
- dims=1536,
213
- )
214
198
 
215
199
  # Allow vecdb to be None in case we want to explicitly set it later
216
- vecdb: Optional[VectorStoreConfig] = LanceDBConfig(
217
- collection_name="doc-chat-lancedb",
218
- replace_collection=True,
219
- storage_path=".lancedb/data/",
220
- embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
221
- )
200
+ vecdb: Optional[VectorStoreConfig] = vecdb_config
201
+
222
202
  llm: OpenAIGPTConfig = OpenAIGPTConfig(
223
203
  type="openai",
224
204
  chat_model=OpenAIChatModel.GPT4,
@@ -1,4 +1,4 @@
1
- from typing import List, Optional
1
+ from typing import List, Optional, Tuple
2
2
 
3
3
  import pandas as pd
4
4
  import typer
@@ -105,7 +105,7 @@ class PandasToKGTool(ToolMessage):
105
105
  args: list[str]
106
106
 
107
107
  @classmethod
108
- def examples(cls) -> List["ToolMessage"]:
108
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
109
109
  return [
110
110
  cls(
111
111
  cypherQuery="""MERGE (employee:Employee {name: $employeeName,
@@ -1,7 +1,6 @@
1
+ from . import sql_chat_agent, utils
1
2
  from .sql_chat_agent import SQLChatAgentConfig, SQLChatAgent
2
3
 
3
- from . import sql_chat_agent
4
- from . import utils
5
4
 
6
5
  __all__ = [
7
6
  "SQLChatAgentConfig",
@@ -12,10 +12,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union
12
12
 
13
13
  from rich import print
14
14
  from rich.console import Console
15
- from sqlalchemy import MetaData, Row, create_engine, inspect, text
16
- from sqlalchemy.engine import Engine
17
- from sqlalchemy.exc import SQLAlchemyError
18
- from sqlalchemy.orm import Session, sessionmaker
15
+
16
+ from langroid.exceptions import LangroidImportError
17
+
18
+ try:
19
+ from sqlalchemy import MetaData, Row, create_engine, inspect, text
20
+ from sqlalchemy.engine import Engine
21
+ from sqlalchemy.exc import SQLAlchemyError
22
+ from sqlalchemy.orm import Session, sessionmaker
23
+ except ImportError as e:
24
+ raise LangroidImportError(extra="sql", error=str(e))
19
25
 
20
26
  from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
21
27
  from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
@@ -1,3 +1,7 @@
1
+ from . import tools
2
+ from . import description_extractors
3
+ from . import populate_metadata
4
+ from . import system_message
1
5
  from .tools import (
2
6
  RunQueryTool,
3
7
  GetTableNamesTool,
@@ -5,11 +9,6 @@ from .tools import (
5
9
  GetColumnDescriptionsTool,
6
10
  )
7
11
 
8
- from . import description_extractors
9
- from . import populate_metadata
10
- from . import system_message
11
- from . import tools
12
-
13
12
  __all__ = [
14
13
  "RunQueryTool",
15
14
  "GetTableNamesTool",
@@ -1,7 +1,12 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from sqlalchemy import inspect, text
4
- from sqlalchemy.engine import Engine
3
+ from langroid.exceptions import LangroidImportError
4
+
5
+ try:
6
+ from sqlalchemy import inspect, text
7
+ from sqlalchemy.engine import Engine
8
+ except ImportError as e:
9
+ raise LangroidImportError(extra="sql", error=str(e))
5
10
 
6
11
 
7
12
  def extract_postgresql_descriptions(
@@ -1,6 +1,11 @@
1
1
  from typing import Dict, List, Union
2
2
 
3
- from sqlalchemy import MetaData
3
+ from langroid.exceptions import LangroidImportError
4
+
5
+ try:
6
+ from sqlalchemy import MetaData
7
+ except ImportError as e:
8
+ raise LangroidImportError(extra="sql", error=str(e))
4
9
 
5
10
 
6
11
  def populate_metadata_with_schema_tools(
@@ -12,7 +12,7 @@ the expression and returns the result as a string.
12
12
  import io
13
13
  import logging
14
14
  import sys
15
- from typing import List, Optional, no_type_check
15
+ from typing import List, Optional, Tuple, no_type_check
16
16
 
17
17
  import numpy as np
18
18
  import pandas as pd
@@ -138,7 +138,7 @@ class PandasEvalTool(ToolMessage):
138
138
  expression: str
139
139
 
140
140
  @classmethod
141
- def examples(cls) -> List["ToolMessage"]:
141
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
142
142
  return [
143
143
  cls(expression="df.head()"),
144
144
  cls(expression="df[(df['gender'] == 'Male')]['income'].mean()"),
@@ -10,7 +10,7 @@ import json
10
10
  import textwrap
11
11
  from abc import ABC
12
12
  from random import choice
13
- from typing import Any, Dict, List, Type
13
+ from typing import Any, Dict, List, Tuple, Type
14
14
 
15
15
  from docstring_parser import parse
16
16
  from pydantic import BaseModel
@@ -65,9 +65,16 @@ class ToolMessage(ABC, BaseModel):
65
65
  return ToolMessageWithRecipient
66
66
 
67
67
  @classmethod
68
- def examples(cls) -> List["ToolMessage"]:
68
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
69
69
  """
70
70
  Examples to use in few-shot demos with JSON formatting instructions.
71
+ Each example can be either:
72
+ - just a ToolMessage instance, e.g. MyTool(param1=1, param2="hello"), or
73
+ - a tuple (description, ToolMessage instance), where the description is
74
+ a natural language "thought" that leads to the tool usage,
75
+ e.g. ("I want to find the square of 5", SquareTool(num=5))
76
+ In some scenarios, including such a description can significantly
77
+ enhance reliability of tool use.
71
78
  Returns:
72
79
  """
73
80
  return []
@@ -83,7 +90,11 @@ class ToolMessage(ABC, BaseModel):
83
90
  if len(cls.examples()) == 0:
84
91
  return ""
85
92
  ex = choice(cls.examples())
86
- return ex.json_example()
93
+ if isinstance(ex, tuple):
94
+ # (description, example_instance)
95
+ return f"{ex[0]} => {ex[1].json_example()}"
96
+ else:
97
+ return ex.json_example()
87
98
 
88
99
  def to_json(self) -> str:
89
100
  return self.json(indent=4, exclude={"result", "purpose"})
@@ -1,8 +1,7 @@
1
- from .google_search_tool import GoogleSearchTool
2
- from .recipient_tool import AddRecipientTool, RecipientTool
3
-
4
1
  from . import google_search_tool
5
2
  from . import recipient_tool
3
+ from .google_search_tool import GoogleSearchTool
4
+ from .recipient_tool import AddRecipientTool, RecipientTool
6
5
 
7
6
  __all__ = [
8
7
  "GoogleSearchTool",
@@ -5,7 +5,7 @@ access to agent state), it can be enabled for any agent, without having to defin
5
5
  special method inside the agent: `agent.enable_message(DuckduckgoSearchTool)`
6
6
  """
7
7
 
8
- from typing import List
8
+ from typing import List, Tuple
9
9
 
10
10
  from langroid.agent.tool_message import ToolMessage
11
11
  from langroid.parsing.web_search import duckduckgo_search
@@ -41,7 +41,7 @@ class DuckduckgoSearchTool(ToolMessage):
41
41
  """
42
42
 
43
43
  @classmethod
44
- def examples(cls) -> List["ToolMessage"]:
44
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
45
45
  return [
46
46
  cls(
47
47
  query="When was the Llama2 Large Language Model (LLM) released?",
@@ -9,7 +9,7 @@ environment variables in your `.env` file, as explained in the
9
9
  [README](https://github.com/langroid/langroid#gear-installation-and-setup).
10
10
  """
11
11
 
12
- from typing import List
12
+ from typing import List, Tuple
13
13
 
14
14
  from langroid.agent.tool_message import ToolMessage
15
15
  from langroid.parsing.web_search import google_search
@@ -30,7 +30,7 @@ class GoogleSearchTool(ToolMessage):
30
30
  return "\n\n".join(str(result) for result in search_results)
31
31
 
32
32
  @classmethod
33
- def examples(cls) -> List["ToolMessage"]:
33
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
34
34
  return [
35
35
  cls(
36
36
  query="When was the Llama2 Large Language Model (LLM) released?",
@@ -21,7 +21,7 @@ For more information, please refer to the official docs:
21
21
  https://metaphor.systems/
22
22
  """
23
23
 
24
- from typing import List
24
+ from typing import List, Tuple
25
25
 
26
26
  from langroid.agent.tool_message import ToolMessage
27
27
  from langroid.parsing.web_search import metaphor_search
@@ -58,7 +58,7 @@ class MetaphorSearchTool(ToolMessage):
58
58
  """
59
59
 
60
60
  @classmethod
61
- def examples(cls) -> List["ToolMessage"]:
61
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
62
62
  return [
63
63
  cls(
64
64
  query="When was the Llama2 Large Language Model (LLM) released?",
@@ -1,4 +1,4 @@
1
- from typing import List
1
+ from typing import List, Tuple
2
2
 
3
3
  from langroid.agent.tool_message import ToolMessage
4
4
 
@@ -16,7 +16,7 @@ class RetrievalTool(ToolMessage):
16
16
  num_results: int
17
17
 
18
18
  @classmethod
19
- def examples(cls) -> List["ToolMessage"]:
19
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
20
20
  return [
21
21
  cls(
22
22
  query="What are the eligibility criteria for the scholarship?",
@@ -1,6 +1,6 @@
1
1
  import io
2
2
  import sys
3
- from typing import List
3
+ from typing import List, Tuple
4
4
 
5
5
  from langroid.agent.tool_message import ToolMessage
6
6
 
@@ -19,7 +19,7 @@ class RunPythonCodeTool(ToolMessage):
19
19
  code: str
20
20
 
21
21
  @classmethod
22
- def examples(cls) -> List["ToolMessage"]:
22
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
23
23
  return [
24
24
  cls(code="import numpy as np\nnp.square(9)"),
25
25
  ]
@@ -10,7 +10,7 @@ This will usually be much cheaper and faster than actually writing out the extra
10
10
  text. The handler of this tool/function will then extract the text and send it back.
11
11
  """
12
12
 
13
- from typing import List
13
+ from typing import List, Tuple
14
14
 
15
15
  from langroid.agent.tool_message import ToolMessage
16
16
 
@@ -25,7 +25,7 @@ class SegmentExtractTool(ToolMessage):
25
25
  segment_list: str
26
26
 
27
27
  @classmethod
28
- def examples(cls) -> List["ToolMessage"]:
28
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
29
29
  return [cls(segment_list="1,3,5-7")]
30
30
 
31
31
  @classmethod
langroid/cachedb/base.py CHANGED
@@ -1,18 +1,26 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from typing import Any, Dict, List
3
3
 
4
+ from pydantic import BaseSettings
5
+
6
+
7
+ class CacheDBConfig(BaseSettings):
8
+ """Configuration model for CacheDB."""
9
+
10
+ pass
11
+
4
12
 
5
13
  class CacheDB(ABC):
6
14
  """Abstract base class for a cache database."""
7
15
 
8
16
  @abstractmethod
9
- def store(self, key: str, value: Dict[str, Any]) -> None:
17
+ def store(self, key: str, value: Any) -> None:
10
18
  """
11
19
  Abstract method to store a value associated with a key.
12
20
 
13
21
  Args:
14
22
  key (str): The key under which to store the value.
15
- value (dict): The value to store.
23
+ value (Any): The value to store.
16
24
  """
17
25
  pass
18
26
 
@@ -4,17 +4,23 @@ import os
4
4
  from datetime import timedelta
5
5
  from typing import Any, Dict, List
6
6
 
7
- import momento
7
+ from langroid.cachedb.base import CacheDBConfig
8
+ from langroid.exceptions import LangroidImportError
9
+
10
+ try:
11
+ import momento
12
+ from momento.responses import CacheGet
13
+ except ImportError:
14
+ raise LangroidImportError(package="momento", extra="momento")
15
+
8
16
  from dotenv import load_dotenv
9
- from momento.responses import CacheGet
10
- from pydantic import BaseModel
11
17
 
12
18
  from langroid.cachedb.base import CacheDB
13
19
 
14
20
  logger = logging.getLogger(__name__)
15
21
 
16
22
 
17
- class MomentoCacheConfig(BaseModel):
23
+ class MomentoCacheConfig(CacheDBConfig):
18
24
  """Configuration model for Momento Cache."""
19
25
 
20
26
  ttl: int = 60 * 60 * 24 * 7 # 1 week
@@ -7,15 +7,14 @@ from typing import Any, Dict, List, TypeVar
7
7
  import fakeredis
8
8
  import redis
9
9
  from dotenv import load_dotenv
10
- from pydantic import BaseModel
11
10
 
12
- from langroid.cachedb.base import CacheDB
11
+ from langroid.cachedb.base import CacheDB, CacheDBConfig
13
12
 
14
13
  T = TypeVar("T", bound="RedisCache")
15
14
  logger = logging.getLogger(__name__)
16
15
 
17
16
 
18
- class RedisCacheConfig(BaseModel):
17
+ class RedisCacheConfig(CacheDBConfig):
19
18
  """Configuration model for RedisCache."""
20
19
 
21
20
  fake: bool = False
@@ -18,6 +18,7 @@ from .remote_embeds import (
18
18
  RemoteEmbeddings,
19
19
  )
20
20
 
21
+
21
22
  __all__ = [
22
23
  "base",
23
24
  "models",