langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,382 @@
|
|
1
|
+
import re
|
2
|
+
from collections.abc import Mapping
|
3
|
+
from typing import Any, Dict, List, Optional, get_args, get_origin
|
4
|
+
|
5
|
+
from lxml import etree
|
6
|
+
|
7
|
+
from langroid.agent.tool_message import ToolMessage
|
8
|
+
from langroid.pydantic_v1 import BaseModel
|
9
|
+
|
10
|
+
|
11
|
+
class XMLToolMessage(ToolMessage):
    """
    Abstract class for tools formatted using XML instead of JSON.

    When a subclass defines a field with the attribute `verbatim=True`,
    instructions are sent to the LLM to ensure the field's content is:
        - preserved as is, including whitespace, indents, quotes, newlines, etc
            with no escaping, and
        - enclosed in a CDATA section in the XML output.
    This is useful for LLMs sending code as part of a tool;
    results can be far superior compared to sending code in JSON-formatted tools,
    where code needs to conform to JSON's strict rules and escaping requirements.
    (see test_xml_tool_message.py for an example).

    A field named `code` is treated as verbatim even without `verbatim=True`;
    this applies consistently when formatting instructions/examples and when
    parsing.
    """

    request: str
    purpose: str

    _allow_llm_use = True

    class Config(ToolMessage.Config):
        # Tag of the XML element that wraps an entire tool message.
        root_element = "tool"

    @classmethod
    def extract_field_values(cls, formatted_string: str) -> Optional[Dict[str, Any]]:
        """
        Extracts field values from an XML-formatted string.

        Verbatim fields (as reported by `find_verbatim_fields`, matched by
        dotted path) keep their text exactly as written, except that one
        enclosing pair of triple-backtick fences is removed. All other leaf
        values are whitespace-stripped. Elements whose tag starts with an
        underscore, XML comments and processing instructions are ignored at
        every nesting level.

        The root element always maps to a dictionary of fields. Nested
        elements become lists when their field is declared as a list. They
        become dictionaries (or nested Pydantic models) when their field is
        declared as a model or mapping. Otherwise a heuristic applies: if all
        children share one tag, the element becomes a list; if not, a
        dictionary.

        Args:
            formatted_string (str): The XML-formatted string to parse.

        Returns:
            Optional[Dict[str, Any]]: A dictionary mapping the root element's
                child tags to their parsed values, or None if the root element
                has no (non-ignored) child elements.

        Raises:
            etree.XMLSyntaxError: If the input string is not valid XML.
        """
        parser = etree.XMLParser(strip_cdata=False)
        root = etree.fromstring(formatted_string.encode("utf-8"), parser=parser)
        # Same notion of "verbatim" as the formatting methods, keyed by dotted
        # path so that nested fields are matched precisely.
        verbatim_paths = set(cls.find_verbatim_fields())

        def is_model_class(tp: Any) -> bool:
            # Field types may be typing constructs (Union, Any, ...) for which
            # issubclass() raises TypeError, so require a real class first.
            return isinstance(tp, type) and issubclass(tp, BaseModel)

        def model_of(field: Any) -> Any:
            """Pydantic model class a field holds directly, else None."""
            if field is not None and is_model_class(field.outer_type_):
                return field.outer_type_
            return None

        def is_list_field(field: Any) -> bool:
            """True if the field is declared as a list."""
            if field is None:
                return False
            outer = field.outer_type_
            return outer is list or get_origin(outer) in (list, List)

        def is_mapping_field(field: Any) -> bool:
            """True if the field is declared as a dict/Mapping."""
            if field is None:
                return False
            outer = field.outer_type_
            container = get_origin(outer) or outer
            return isinstance(container, type) and issubclass(container, Mapping)

        def element_children(element: etree._Element) -> List[etree._Element]:
            """Child elements to parse, skipping comments/PIs (non-str tags)
            and underscore-prefixed tags."""
            return [
                child
                for child in element
                if isinstance(child.tag, str) and not child.tag.startswith("_")
            ]

        def parse_element(
            element: etree._Element, path: str, field: Any, model: Any
        ) -> Any:
            """
            Parse one element located at dotted `path`.

            `field` is the Pydantic field declared for this element (None if
            unknown); `model` is the Pydantic model class describing its
            children (None if unknown).
            """
            if path in verbatim_paths:
                # Preserve the content as is, including whitespace
                content = element.text if element.text else ""
                stripped = content.strip()
                # Strip leading and trailing triple backticks if present,
                # accounting for whitespace
                if stripped.startswith("```") and stripped.endswith("```"):
                    return stripped.removeprefix("```").removesuffix("```").strip()
                return content

            children = element_children(element)
            if not children:
                # Non-verbatim leaf element: strip whitespace
                return element.text.strip() if element.text else ""

            if is_list_field(field):
                item_model = field.type_ if is_model_class(field.type_) else None
                return [
                    parse_element(child, f"{path}.{child.tag}", None, item_model)
                    for child in children
                ]

            if (
                model is None
                and not is_mapping_field(field)
                and all(child.tag == children[0].tag for child in children)
            ):
                # No type information says this is a mapping: if all children
                # share one tag, treat them as list items.
                return [
                    parse_element(child, f"{path}.{child.tag}", None, None)
                    for child in children
                ]

            result = {}
            for child in children:
                child_field = (
                    model.__fields__.get(child.tag) if model is not None else None
                )
                result[child.tag] = parse_element(
                    child, f"{path}.{child.tag}", child_field, model_of(child_field)
                )
            # A field declared as a nested Pydantic model is instantiated here.
            if field is not None and model is not None:
                return model(**result)
            return result

        top_children = element_children(root)
        if not top_children:
            return None
        # The root element represents the tool itself, so its children are
        # always the tool's fields (never list items).
        values: Dict[str, Any] = {}
        for child in top_children:
            field = cls.__fields__.get(child.tag)
            values[child.tag] = parse_element(child, child.tag, field, model_of(field))
        return values

    @classmethod
    def parse(cls, formatted_string: str) -> Optional["XMLToolMessage"]:
        """
        Parses the XML-formatted string and returns an instance of the class.

        Args:
            formatted_string (str): The XML-formatted string to parse.

        Returns:
            Optional["XMLToolMessage"]: An instance of the class if parsing
                succeeds, or None if the root element has no field elements.

        Raises:
            XMLException: If the XML is malformed or the extracted values fail
                validation; the original error is chained as the cause.
        """
        try:
            parsed_data = cls.extract_field_values(formatted_string)
            if parsed_data is None:
                return None

            # Use Pydantic's parse_obj to create and validate the instance
            return cls.parse_obj(parsed_data)
        except Exception as e:
            from langroid.exceptions import XMLException

            raise XMLException(f"Error parsing XML: {str(e)}") from e

    @classmethod
    def find_verbatim_fields(
        cls, prefix: str = "", parent_cls: Optional["BaseModel"] = None
    ) -> List[str]:
        """
        Collect the dotted paths of fields whose content must be sent verbatim.

        A field is verbatim if it was declared with `verbatim=True` or is named
        `code`. Fields whose (inner) type is a Pydantic model class are searched
        recursively, producing paths such as `outer.inner`.

        Args:
            prefix: Dotted path of the model being inspected ("" at top level).
            parent_cls: Model class to inspect; defaults to `cls`.

        Returns:
            List of dotted field paths.
        """
        verbatim_fields = []
        for field_name, field_info in (parent_cls or cls).__fields__.items():
            full_name = f"{prefix}.{field_name}" if prefix else field_name
            if (
                field_info.field_info.extra.get("verbatim", False)
                or field_name == "code"
            ):
                verbatim_fields.append(full_name)
            nested_type = field_info.type_
            # type_ may be a typing construct (e.g. Union) on which issubclass
            # raises TypeError; only recurse into real model classes.
            if isinstance(nested_type, type) and issubclass(nested_type, BaseModel):
                verbatim_fields.extend(
                    cls.find_verbatim_fields(full_name, nested_type)
                )
        return verbatim_fields

    @classmethod
    def format_instructions(cls, tool: bool = False) -> str:
        """
        Build LLM-facing instructions describing how to emit this tool as XML.

        The text lists a placeholder for each non-excluded field, an XML
        skeleton (CDATA-wrapped for verbatim fields), an alert about verbatim
        fields if any exist, and usage examples if the class defines them.

        Args:
            tool: Unused here; kept for interface compatibility.

        Returns:
            str: The formatted instructions.
        """
        fields = [
            f
            for f in cls.__fields__.keys()
            if f not in cls.Config.schema_extra.get("exclude", set())
        ]

        instructions = """
        To use this tool, please provide the required information in an XML-like
        format. Here's how to structure your input:\n\n
        """

        preamble = "Placeholders:\n"
        xml_format = f"Formatting example:\n\n<{cls.Config.root_element}>\n"

        verbatim_fields = cls.find_verbatim_fields()

        def format_field(
            field_name: str,
            field_type: type,
            indent: str = "",
            path: str = "",
        ) -> None:
            """Append the placeholder line and XML skeleton for one field."""
            nonlocal preamble, xml_format
            current_path = f"{path}.{field_name}" if path else field_name

            origin = get_origin(field_type)
            args = get_args(field_type)

            if (
                origin is None
                and isinstance(field_type, type)
                and issubclass(field_type, BaseModel)
            ):
                preamble += (
                    f"{field_name.upper()} = [nested structure for {field_name}]\n"
                )
                xml_format += f"{indent}<{field_name}>\n"
                for sub_field, sub_field_info in field_type.__fields__.items():
                    format_field(
                        sub_field,
                        sub_field_info.outer_type_,
                        indent + "  ",
                        current_path,
                    )
                xml_format += f"{indent}</{field_name}>\n"
            elif origin in (list, List) or (field_type is list):
                item_type = args[0] if args else Any
                if isinstance(item_type, type) and issubclass(item_type, BaseModel):
                    preamble += (
                        f"{field_name.upper()} = "
                        f"[list of nested structures for {field_name}]\n"
                    )
                else:
                    preamble += (
                        f"{field_name.upper()} = "
                        f"[list of {getattr(item_type, '__name__', str(item_type))} "
                        f"for {field_name}]\n"
                    )
                xml_format += f"{indent}<{field_name}>\n"
                xml_format += (
                    f"{indent}  <item>"
                    f"[{getattr(item_type, '__name__', str(item_type))} value]"
                    f"</item>\n"
                )
                xml_format += f"{indent}  ...\n"
                xml_format += f"{indent}</{field_name}>\n"
            elif origin in (dict, Dict) or (
                isinstance(field_type, type) and issubclass(field_type, Mapping)
            ):
                key_type, value_type = args if len(args) == 2 else (Any, Any)
                preamble += (
                    f"{field_name.upper()} = "
                    f"[dictionary with "
                    f"{getattr(key_type, '__name__', str(key_type))} keys and "
                    f"{getattr(value_type, '__name__', str(value_type))} values]\n"
                )
                xml_format += f"{indent}<{field_name}>\n"
                xml_format += (
                    f"{indent}  <{getattr(key_type, '__name__', str(key_type))}>"
                    f"[{getattr(value_type, '__name__', str(value_type))} value]"
                    f"</{getattr(key_type, '__name__', str(key_type))}>\n"
                )
                xml_format += f"{indent}  ...\n"
                xml_format += f"{indent}</{field_name}>\n"
            else:
                preamble += f"{field_name.upper()} = [value for {field_name}]\n"
                if current_path in verbatim_fields:
                    xml_format += (
                        f"{indent}<{field_name}>"
                        f"<![CDATA[{{{field_name.upper()}}}]]></{field_name}>\n"
                    )
                else:
                    xml_format += (
                        f"{indent}<{field_name}>"
                        f"{{{field_name.upper()}}}</{field_name}>\n"
                    )

        for field in fields:
            field_info = cls.__fields__[field]
            # Use outer_type_ to get the actual type including List, etc.
            field_type = field_info.outer_type_
            format_field(field, field_type)

        xml_format += f"</{cls.Config.root_element}>"

        verbatim_alert = ""
        if len(verbatim_fields) > 0:
            verbatim_alert = f"""
            EXTREMELY IMPORTANT: For these fields:
            {', '.join(verbatim_fields)},
            the contents MUST be wrapped in a CDATA section, and the content
            must be written verbatim WITHOUT any modifications or escaping,
            such as spaces, tabs, indents, newlines, quotes, etc.
            """

        examples_str = ""
        if cls.examples():
            examples_str = "EXAMPLES:\n" + cls.usage_examples()

        return f"""
        TOOL: {cls.default_value("request")}
        PURPOSE: {cls.default_value("purpose")}

        {instructions}
        {preamble}
        {xml_format}

        Make sure to replace the placeholders with actual values
        when using the tool.
        {verbatim_alert}
        {examples_str}
        """.lstrip()

    def format_example(self) -> str:
        """
        Format the current instance as an XML example.

        Excluded fields and None values are omitted; list items are emitted as
        `<item>` elements; verbatim leaf values are wrapped in CDATA.

        Returns:
            str: A string representation of the current instance in XML format.

        Raises:
            ValueError: If the result from etree.tostring is not a string.
        """
        # Computed once here rather than once per leaf value.
        verbatim_fields = set(self.__class__.find_verbatim_fields())

        def create_element(
            parent: etree._Element, name: str, value: Any, path: str = ""
        ) -> None:
            """Append `value` under `parent` as element `name` (recursively)."""
            if value is None:
                return

            elem = etree.SubElement(parent, name)
            current_path = f"{path}.{name}" if path else name

            if isinstance(value, list):
                for item in value:
                    create_element(elem, "item", item, current_path)
            elif isinstance(value, dict):
                for k, v in value.items():
                    create_element(elem, k, v, current_path)
            elif isinstance(value, BaseModel):
                # Handle nested Pydantic models
                for field_name, field_value in value.dict().items():
                    create_element(elem, field_name, field_value, current_path)
            elif current_path in verbatim_fields:
                elem.text = etree.CDATA(str(value))
            else:
                elem.text = str(value)

        root = etree.Element(self.Config.root_element)
        exclude_fields = self.Config.schema_extra.get("exclude", set())
        for name, value in self.dict().items():
            if name not in exclude_fields:
                create_element(root, name, value)

        result = etree.tostring(root, encoding="unicode", pretty_print=True)
        if not isinstance(result, str):
            raise ValueError("Unexpected non-string result from etree.tostring")
        return result

    @classmethod
    def find_candidates(cls, text: str) -> List[str]:
        """
        Finds XML-like tool message candidates in text, with relaxed opening tag rules.

        Args:
            text: Input text to search for XML structures.

        Returns:
            List of XML strings, in order of appearance. A fragment can lack the
            root opening tag but still have valid XML structure and a root
            closing tag. For such a fragment, the root opening tag is prepended
            and the text starts at the fragment's first tag. If the last
            opening root tag is never closed, the root closing tag is appended
            to that final fragment.

        Example:
            With root_tag="tool", given:
            "Hello <field1>data</field1> </tool>"
            Returns: ["<tool><field1>data</field1> </tool>"]
        """

        root_tag = cls.Config.root_element
        opening_tag = f"<{root_tag}>"
        closing_tag = f"</{root_tag}>"

        candidates = []
        pos = 0
        while True:
            # Look for either proper opening tag or closing tag
            start_normal = text.find(opening_tag, pos)
            end = text.find(closing_tag, pos)

            if start_normal == -1 and end == -1:
                break

            if start_normal != -1 and (end == -1 or start_normal < end):
                # Normal case: an opening tag comes before the next closing tag,
                # so `end` (if found) is the closing tag matching this opening.
                if end != -1:
                    candidates.append(text[start_normal : end + len(closing_tag)])
                    pos = end + len(closing_tag)
                    continue
                if start_normal == text.rfind(opening_tag):
                    # last fragment - ok to miss closing tag
                    candidates.append(text[start_normal:] + closing_tag)
                    return candidates
                pos = start_normal + 1
                continue

            # A closing tag appears with no opening tag before it: start the
            # candidate at the first XML tag between `pos` and the closing tag.
            text_before = text[pos:end]
            first_tag_match = re.search(r"<\w+>", text_before)
            if first_tag_match:
                start = pos + first_tag_match.start()
                candidates.append(
                    opening_tag + text[start : end + len(closing_tag)]
                )
            pos = end + len(closing_tag)

        return candidates
|
langroid/cachedb/base.py
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Any, Dict, List
|
3
|
+
|
4
|
+
from langroid.pydantic_v1 import BaseSettings
|
5
|
+
|
6
|
+
|
7
|
+
class CacheDBConfig(BaseSettings):
    """
    Base settings model shared by all cache-database configurations.

    It declares no fields of its own; concrete cache backends subclass it
    and add their backend-specific options.
    """
|
11
|
+
|
12
|
+
|
13
|
+
class CacheDB(ABC):
|
14
|
+
"""Abstract base class for a cache database."""
|
15
|
+
|
16
|
+
@abstractmethod
|
17
|
+
def store(self, key: str, value: Any) -> None:
|
18
|
+
"""
|
19
|
+
Abstract method to store a value associated with a key.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
key (str): The key under which to store the value.
|
23
|
+
value (Any): The value to store.
|
24
|
+
"""
|
25
|
+
pass
|
26
|
+
|
27
|
+
@abstractmethod
|
28
|
+
def retrieve(self, key: str) -> Dict[str, Any] | str | None:
|
29
|
+
"""
|
30
|
+
Abstract method to retrieve the value associated with a key.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
key (str): The key to retrieve the value for.
|
34
|
+
|
35
|
+
Returns:
|
36
|
+
dict: The value associated with the key.
|
37
|
+
"""
|
38
|
+
pass
|
39
|
+
|
40
|
+
@abstractmethod
|
41
|
+
def delete_keys(self, keys: List[str]) -> None:
|
42
|
+
"""
|
43
|
+
Delete the keys from the cache.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
keys (List[str]): The keys to delete.
|
47
|
+
"""
|
48
|
+
pass
|
49
|
+
|
50
|
+
@abstractmethod
|
51
|
+
def delete_keys_pattern(self, pattern: str) -> None:
|
52
|
+
"""
|
53
|
+
Delete all keys with the given pattern
|
54
|
+
|
55
|
+
Args:
|
56
|
+
prefix (str): The pattern to match.
|
57
|
+
"""
|
58
|
+
pass
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from datetime import timedelta
|
5
|
+
from typing import Any, Dict, List
|
6
|
+
|
7
|
+
from langroid.cachedb.base import CacheDBConfig
|
8
|
+
from langroid.exceptions import LangroidImportError
|
9
|
+
|
10
|
+
try:
|
11
|
+
import momento
|
12
|
+
from momento.responses import CacheGet
|
13
|
+
except ImportError:
|
14
|
+
raise LangroidImportError(package="momento", extra="momento")
|
15
|
+
|
16
|
+
from dotenv import load_dotenv
|
17
|
+
|
18
|
+
from langroid.cachedb.base import CacheDB
|
19
|
+
|
20
|
+
# Module-level logger for this Momento cache backend.
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class MomentoCacheConfig(CacheDBConfig):
    """Settings for the Momento-backed cache."""

    # Default time-to-live for cached entries, in seconds (7 days).
    ttl: int = 7 * 24 * 60 * 60
    # Name of the Momento cache that entries are written to and read from.
    cachename: str = "langroid_momento_cache"
|
28
|
+
|
29
|
+
|
30
|
+
class MomentoCache(CacheDB):
    """
    Momento implementation of the CacheDB.

    Values are stored as JSON strings (`json.dumps` on write, `json.loads` on
    read), so stored values must be JSON-serializable.
    """

    def __init__(self, config: MomentoCacheConfig) -> None:
        """
        Initialize a MomentoCache with the given config.

        Loads variables from a `.env` file and then builds a Momento client.
        The client authenticates via the `MOMENTO_AUTH_TOKEN` environment
        variable and uses `config.ttl` seconds as the default TTL. Finally it
        creates the cache named `config.cachename`.

        Args:
            config (MomentoCacheConfig): The configuration to use.

        Raises:
            ValueError: If `MOMENTO_AUTH_TOKEN` is not set.
        """
        self.config = config
        load_dotenv()

        if os.getenv("MOMENTO_AUTH_TOKEN") is None:
            raise ValueError("""MOMENTO_AUTH_TOKEN not set in .env file""")

        self.client = momento.CacheClient(
            configuration=momento.Configurations.Laptop.v1(),
            credential_provider=momento.CredentialProvider.from_environment_variable(
                "MOMENTO_AUTH_TOKEN"
            ),
            default_ttl=timedelta(seconds=self.config.ttl),
        )
        # NOTE(review): the response is not inspected; assumed harmless when
        # the cache already exists — confirm against the Momento SDK.
        self.client.create_cache(self.config.cachename)

    def clear(self) -> None:
        """Clear keys from current db."""
        self.client.flush_cache(self.config.cachename)

    def store(self, key: str, value: Any) -> None:
        """
        Store a value associated with a key.

        Failures reported by Momento are logged as warnings rather than raised.

        Args:
            key (str): The key under which to store the value.
            value (Any): The value to store; must be JSON-serializable.
        """
        from momento.responses import CacheSet

        response = self.client.set(self.config.cachename, key, json.dumps(value))
        if isinstance(response, CacheSet.Error):
            # Caching is best-effort: don't raise, but don't fail silently.
            logger.warning("Momento cache set failed for key %r: %s", key, response)

    def retrieve(self, key: str) -> Dict[str, Any] | str | None:
        """
        Retrieve the value associated with a key.

        Args:
            key (str): The key to retrieve the value for.

        Returns:
            Dict[str, Any] | str | None: The JSON-decoded stored value on a
                cache hit; None on a miss or on an error (errors are logged).
        """
        response = self.client.get(self.config.cachename, key)
        if isinstance(response, CacheGet.Hit):
            return json.loads(response.value_string)  # type: ignore
        if isinstance(response, CacheGet.Error):
            # Treated like a miss, but logged so it is not silently swallowed.
            logger.warning("Momento cache get failed for key %r: %s", key, response)
        return None

    def delete_keys(self, keys: List[str]) -> None:
        """
        Delete the keys from the cache.

        Args:
            keys (List[str]): The keys to delete.
        """
        for key in keys:
            self.client.delete(self.config.cachename, key)

    def delete_keys_pattern(self, pattern: str) -> None:
        """
        Delete the keys from the cache with the given pattern (unsupported).

        Args:
            pattern (str): The pattern to match.

        Raises:
            NotImplementedError: Always; Momento has no pattern-based delete.
        """
        raise NotImplementedError(
            """
            MomentoCache does not support delete_keys_pattern.
            Please use RedisCache instead.
            """
        )
|