langroid 0.1.265__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- langroid/agent/base.py +21 -9
- langroid/agent/chat_agent.py +69 -17
- langroid/agent/chat_document.py +59 -4
- langroid/agent/special/doc_chat_agent.py +8 -26
- langroid/agent/task.py +299 -103
- langroid/agent/tools/__init__.py +4 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/language_models/__init__.py +3 -0
- langroid/language_models/base.py +23 -4
- langroid/language_models/mock_lm.py +91 -0
- langroid/language_models/utils.py +2 -1
- langroid/mytypes.py +4 -35
- langroid/parsing/document_parser.py +5 -0
- langroid/parsing/parser.py +17 -2
- langroid/utils/__init__.py +2 -0
- langroid/utils/constants.py +2 -1
- langroid/utils/object_registry.py +66 -0
- langroid/utils/system.py +1 -2
- langroid/vector_store/base.py +3 -2
- {langroid-0.1.265.dist-info → langroid-0.2.2.dist-info}/METADATA +10 -6
- {langroid-0.1.265.dist-info → langroid-0.2.2.dist-info}/RECORD +24 -22
- pyproject.toml +2 -2
- langroid/language_models/openai_assistants.py +0 -3
- {langroid-0.1.265.dist-info → langroid-0.2.2.dist-info}/LICENSE +0 -0
- {langroid-0.1.265.dist-info → langroid-0.2.2.dist-info}/WHEEL +0 -0
langroid/agent/tools/__init__.py
CHANGED
```diff
@@ -1,7 +1,9 @@
 from . import google_search_tool
 from . import recipient_tool
+from . import rewind_tool
 from .google_search_tool import GoogleSearchTool
 from .recipient_tool import AddRecipientTool, RecipientTool
+from .rewind_tool import RewindTool
 
 __all__ = [
     "GoogleSearchTool",
@@ -9,4 +11,6 @@ __all__ = [
     "RecipientTool",
     "google_search_tool",
     "recipient_tool",
+    "rewind_tool",
+    "RewindTool",
 ]
```
langroid/agent/tools/rewind_tool.py
ADDED
```diff
@@ -0,0 +1,137 @@
+"""
+The `rewind_tool` is used to rewind to the `n`th previous Assistant message
+and replace it with a new `content`. This is useful in several scenarios, as it:
+- saves token-cost + inference time,
+- reduces distracting clutter in chat history, which helps improve response quality.
+
+This is intended to mimic how a human user might use a chat interface, where they
+go down a conversation path, and want to go back in history to "edit and re-submit"
+a previous message, to get a better response.
+
+See usage examples in `tests/main/test_rewind_tool.py`.
+"""
+
+from typing import List, Tuple
+
+import langroid.language_models as lm
+from langroid.agent.chat_agent import ChatAgent
+from langroid.agent.chat_document import ChatDocument
+from langroid.agent.tool_message import ToolMessage
+
+
+def prune_messages(agent: ChatAgent, idx: int) -> ChatDocument | None:
+    """
+    Clear the message history of agent, starting at index `idx`,
+    taking care to first clear all dependent messages (possibly from other agents'
+    message histories) that are linked to the message at `idx`, via the `child_id` field
+    of the `metadata` field of the ChatDocument linked from the message at `idx`.
+
+    Args:
+        agent (ChatAgent): The agent whose message history is to be pruned.
+        idx (int): The index from which to start clearing the message history.
+
+    Returns:
+        The parent ChatDocument of the ChatDocument linked from the message at `idx`,
+        if it exists, else None.
+
+    """
+    assert idx >= 0, "Invalid index for message history!"
+    chat_doc_id = agent.message_history[idx].chat_document_id
+    chat_doc = ChatDocument.from_id(chat_doc_id)
+    assert chat_doc is not None, "ChatDocument not found in registry!"
+
+    parent = ChatDocument.from_id(chat_doc.metadata.parent_id)  # may be None
+    # We're invalidating the msg at idx,
+    # so starting with chat_doc, go down the child links
+    # and clear history of each agent, to the msg_idx
+    curr_doc = chat_doc
+    while child_doc := curr_doc.metadata.child:
+        if child_doc.metadata.msg_idx >= 0:
+            child_agent = ChatAgent.from_id(child_doc.metadata.agent_id)
+            if child_agent is not None:
+                child_agent.clear_history(child_doc.metadata.msg_idx)
+        curr_doc = child_doc
+
+    # Clear out ObjectRegistry entries for this ChatDocument
+    # and all descendants (in case they weren't already cleared above)
+    ChatDocument.delete_id(chat_doc.id())
+
+    # Finally, clear this agent's history back to idx,
+    # and replace the msg at idx with the new content
+    agent.clear_history(idx)
+    return parent
+
+
+class RewindTool(ToolMessage):
+    """
+    Used by LLM to rewind (i.e. backtrack) to the `n`th Assistant message
+    and replace it with a new msg.
+    """
+
+    request: str = "rewind_tool"
+    purpose: str = """
+        To rewind the conversation and replace the
+        <n>'th Assistant message with <content>
+        """
+    n: int
+    content: str
+
+    @classmethod
+    def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
+        return [
+            cls(n=1, content="What are the 3 major causes of heart disease?"),
+            (
+                """
+                Based on the conversation so far, I realize I would get a better
+                response from Bob if I rephrase my 2nd message to him to:
+                'Who wrote the book Grime and Banishment?'
+                """,
+                cls(n=2, content="who wrote the book 'Grime and Banishment'?"),
+            ),
+        ]
+
+    def response(self, agent: ChatAgent) -> str | ChatDocument:
+        """
+        Define the tool-handler method for this tool here itself,
+        since it is a generic tool whose functionality should be the
+        same for any agent.
+
+        When the LLM has correctly used this tool, rewind this agent's
+        `message_history` to the `n`th assistant msg, and replace it with `content`.
+        We need to mock it as if the LLM is sending this message.
+
+        Within a multi-agent scenario, this also means that any other messages dependent
+        on this message will need to be invalidated --
+        so go down the chain of child messages and clear each agent's history
+        back to the `msg_idx` corresponding to the child message.
+
+        Returns:
+            (ChatDocument): with content set to self.content.
+        """
+        idx = agent.nth_message_idx_with_role(lm.Role.ASSISTANT, self.n)
+        if idx < 0:
+            # set up a corrective message from AGENT
+            msg = f"""
+            Could not rewind to {self.n}th Assistant message!
+            Please check the value of `n` and try again.
+            Or it may be too early to use the `rewind_tool`.
+            """
+            return agent.create_agent_response(msg)
+
+        parent = prune_messages(agent, idx)
+
+        # create ChatDocument with new content, to be returned as result of this tool
+        result_doc = agent.create_llm_response(self.content)
+        result_doc.metadata.parent_id = "" if parent is None else parent.id()
+        result_doc.metadata.agent_id = agent.id
+        result_doc.metadata.msg_idx = idx
+
+        # replace the message at idx with this new message
+        agent.message_history.append(ChatDocument.to_LLMMessage(result_doc))
+
+        # set the replaced doc's parent's child to this result_doc
+        if parent is not None:
+            # first remove this parent's child from registry
+            ChatDocument.delete_id(parent.metadata.child_id)
+            parent.metadata.child_id = result_doc.id()
+        return result_doc
```
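A hedged usage sketch (not part of this diff): wiring `RewindTool` into an agent follows the standard langroid `enable_message` pattern; the tool-call shown in the comments is illustrative only.

```python
import langroid as lr
from langroid.agent.tools.rewind_tool import RewindTool

# Enable the tool so this agent's LLM may emit `rewind_tool` JSON messages.
agent = lr.ChatAgent(lr.ChatAgentConfig(name="Alice"))
agent.enable_message(RewindTool)

# If the LLM later emits, e.g.:
#   {"request": "rewind_tool", "n": 2,
#    "content": "who wrote the book 'Grime and Banishment'?"}
# then RewindTool.response() rewinds `agent.message_history` to the 2nd
# Assistant message, prunes dependent child messages (possibly in other
# agents' histories), and proceeds as if that message had been `content`.
```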
langroid/language_models/__init__.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from .openai_gpt import (
     OpenAIGPTConfig,
     OpenAIGPT,
 )
+from .mock_lm import MockLM, MockLMConfig
 from .azure_openai import AzureConfig, AzureGPT
 
 
@@ -43,4 +44,6 @@ __all__ = [
     "OpenAIGPT",
     "AzureConfig",
     "AzureGPT",
+    "MockLM",
+    "MockLMConfig",
 ]
```
langroid/language_models/base.py
CHANGED
```diff
@@ -4,7 +4,17 @@ import logging
 from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 from langroid.cachedb.base import CacheDBConfig
 from langroid.parsing.agent_chats import parse_message
@@ -134,12 +144,15 @@ class LLMMessage(BaseModel):
     content: str
     function_call: Optional[LLMFunctionCall] = None
     timestamp: datetime = Field(default_factory=datetime.utcnow)
+    # link to corresponding chat document, for provenance/rewind purposes
+    chat_document_id: str = ""
 
     def api_dict(self) -> Dict[str, Any]:
         """
-        Convert to dictionary for API request
-
-
+        Convert to dictionary for API request, keeping ONLY
+        the fields that are expected in an API call!
+        E.g., DROP the tool_id, since it is only for use in the Assistant API,
+        not the completion API.
         Returns:
             dict: dictionary representation of LLM message
         """
@@ -155,8 +168,10 @@ class LLMMessage(BaseModel):
             dict_no_none["function_call"]["arguments"] = json.dumps(
                 dict_no_none["function_call"]["arguments"]
             )
+        # IMPORTANT! drop fields that are not expected in API call
         dict_no_none.pop("tool_id", None)
         dict_no_none.pop("timestamp", None)
+        dict_no_none.pop("chat_document_id", None)
         return dict_no_none
 
     def __str__(self) -> str:
@@ -268,11 +283,15 @@ class LanguageModel(ABC):
             """
         )
         from langroid.language_models.azure_openai import AzureGPT
+        from langroid.language_models.mock_lm import MockLM, MockLMConfig
         from langroid.language_models.openai_gpt import OpenAIGPT
 
         if config is None or config.type is None:
             return None
 
+        if config.type == "mock":
+            return MockLM(cast(MockLMConfig, config))
+
         openai: Union[Type[AzureGPT], Type[OpenAIGPT]]
 
         if config.type == "azure":
```
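The extra `pop` in `api_dict()` matters because the new `chat_document_id` field would otherwise leak into LLM API payloads. A quick sketch of the intended behavior (field names as in the diff; the `Role` enum is assumed to be the one defined in this module):

```python
from langroid.language_models.base import LLMMessage, Role

msg = LLMMessage(role=Role.USER, content="hi", chat_document_id="some-doc-id")
d = msg.api_dict()
# internal bookkeeping fields are stripped before the API call
assert "chat_document_id" not in d
assert "timestamp" not in d and "tool_id" not in d
```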
langroid/language_models/mock_lm.py
ADDED
```diff
@@ -0,0 +1,91 @@
+"""Mock Language Model for testing"""
+
+from typing import Callable, Dict, List, Optional, Union
+
+import langroid.language_models as lm
+from langroid.language_models import LLMResponse
+from langroid.language_models.base import LanguageModel, LLMConfig
+
+
+def none_fn(x: str) -> None | str:
+    return None
+
+
+class MockLMConfig(LLMConfig):
+    """
+    Mock Language Model Configuration.
+
+    Attributes:
+        response_dict (Dict[str, str]): A "response rule-book", in the form of a
+            dictionary; if last msg in dialog is x, then respond with response_dict[x]
+    """
+
+    response_dict: Dict[str, str] = {}
+    response_fn: Callable[[str], None | str] = none_fn
+    default_response: str = "Mock response"
+
+    type: str = "mock"
+
+
+class MockLM(LanguageModel):
+
+    def __init__(self, config: MockLMConfig = MockLMConfig()):
+        super().__init__(config)
+        self.config: MockLMConfig = config
+
+    def _response(self, msg: str) -> LLMResponse:
+        # response is based on this fallback order:
+        # - response_dict
+        # - response_fn
+        # - default_response
+        return lm.LLMResponse(
+            message=self.config.response_dict.get(
+                msg,
+                self.config.response_fn(msg) or self.config.default_response,
+            ),
+            cached=False,
+        )
+
+    def chat(
+        self,
+        messages: Union[str, List[lm.LLMMessage]],
+        max_tokens: int = 200,
+        functions: Optional[List[lm.LLMFunctionSpec]] = None,
+        function_call: str | Dict[str, str] = "auto",
+    ) -> lm.LLMResponse:
+        """
+        Mock chat function for testing
+        """
+        last_msg = messages[-1].content if isinstance(messages, list) else messages
+        return self._response(last_msg)
+
+    async def achat(
+        self,
+        messages: Union[str, List[lm.LLMMessage]],
+        max_tokens: int = 200,
+        functions: Optional[List[lm.LLMFunctionSpec]] = None,
+        function_call: str | Dict[str, str] = "auto",
+    ) -> lm.LLMResponse:
+        """
+        Mock chat function for testing
+        """
+        last_msg = messages[-1].content if isinstance(messages, list) else messages
+        return self._response(last_msg)
+
+    def generate(self, prompt: str, max_tokens: int = 200) -> lm.LLMResponse:
+        """
+        Mock generate function for testing
+        """
+        return self._response(prompt)
+
+    async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        """
+        Mock generate function for testing
+        """
+        return self._response(prompt)
+
+    def get_stream(self) -> bool:
+        return False
+
+    def set_stream(self, stream: bool) -> bool:
+        return False
```
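A short sketch of the fallback behavior, using only what the diff defines; the `LanguageModel.create` call assumes that is the factory classmethod containing the `config.type == "mock"` branch added in `base.py`:

```python
from langroid.language_models.base import LanguageModel
from langroid.language_models.mock_lm import MockLM, MockLMConfig

cfg = MockLMConfig(
    response_dict={"2+2=?": "4"},      # exact-match rule-book
    default_response="Mock response",  # used when no rule or response_fn applies
)
mock = MockLM(cfg)
print(mock.generate("2+2=?").message)     # -> "4"
print(mock.generate("anything").message)  # -> "Mock response"

# the factory in base.py dispatches on config.type == "mock"
assert isinstance(LanguageModel.create(cfg), MockLM)
```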
langroid/language_models/utils.py
CHANGED
```diff
@@ -62,7 +62,7 @@ def retry_with_exponential_backoff(
             if num_retries > max_retries:
                 raise Exception(
                     f"Maximum number of retries ({max_retries}) exceeded."
-                    f" Last error: {e}."
+                    f" Last error: {str(e)}."
                 )
 
             # Increment the delay
@@ -128,6 +128,7 @@ def async_retry_with_exponential_backoff(
             if num_retries > max_retries:
                 raise Exception(
                     f"Maximum number of retries ({max_retries}) exceeded."
+                    f" Last error: {str(e)}."
                 )
 
             # Increment the delay
```
langroid/mytypes.py
CHANGED
```diff
@@ -1,10 +1,9 @@
-import hashlib
-import uuid
 from enum import Enum
 from textwrap import dedent
 from typing import Any, Callable, Dict, List, Union
+from uuid import uuid4
 
-from langroid.pydantic_v1 import BaseModel, Extra
+from langroid.pydantic_v1 import BaseModel, Extra, Field
 
 Number = Union[int, float]
 Embedding = List[Number]
@@ -40,7 +39,7 @@ class DocMetaData(BaseModel):
 
     source: str = "context"
     is_chunk: bool = False  # if it is a chunk, don't split
-    id: str =
+    id: str = Field(default_factory=lambda: str(uuid4()))
     window_ids: List[str] = []  # for RAG: ids of chunks around this one
 
     def dict_bool_int(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
@@ -67,41 +66,11 @@ class Document(BaseModel):
     content: str
     metadata: DocMetaData
 
-    @staticmethod
-    def hash_id(doc: str) -> str:
-        # Encode the document as UTF-8
-        doc_utf8 = str(doc).encode("utf-8")
-
-        # Create a SHA256 hash object
-        sha256_hash = hashlib.sha256()
-
-        # Update the hash object with the bytes of the document
-        sha256_hash.update(doc_utf8)
-
-        # Get the hexadecimal representation of the hash
-        hash_hex = sha256_hash.hexdigest()
-
-        # Convert the first part of the hash to a UUID
-        hash_uuid = uuid.UUID(hash_hex[:32])
-
-        return str(hash_uuid)
-
-    def _unique_hash_id(self) -> str:
-        return self.hash_id(str(self))
-
     def id(self) -> str:
-        if (
-            hasattr(self.metadata, "id")
-            and self.metadata.id is not None
-            and self.metadata.id != ""
-        ):
-            return self.metadata.id
-        else:
-            return self._unique_hash_id()
+        return self.metadata.id
 
     def __str__(self) -> str:
         # TODO: make metadata a pydantic model to enforce "source"
-        self.metadata.json()
         return dedent(
             f"""
             CONTENT: {self.content}
```
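Net effect of the `mytypes.py` changes: ids are now assigned eagerly at construction via a UUID `default_factory`, instead of being derived on demand by hashing, and `id()` becomes a plain accessor. A minimal sketch of the new behavior:

```python
from langroid.mytypes import DocMetaData, Document

d1 = Document(content="hello", metadata=DocMetaData(source="test"))
d2 = Document(content="hello", metadata=DocMetaData(source="test"))

# fresh UUIDs even for identical content; the old hash_id()
# would have given both documents the same id
assert d1.id() != d2.id()
assert d1.id() == d1.metadata.id  # id() now just returns metadata.id
```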
langroid/parsing/document_parser.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Any, Generator, List, Tuple
 
 from langroid.exceptions import LangroidImportError
+from langroid.utils.object_registry import ObjectRegistry
 
 try:
     import fitz
@@ -341,6 +342,8 @@ class DocumentParser(Parser):
         split = []  # tokens in curr split
         pages: List[str] = []
         docs: List[Document] = []
+        # metadata.id to be shared by ALL chunks of this document
+        common_id = ObjectRegistry.new_id()
         for i, page in self.iterate_pages():
             page_text = self.extract_text_from_page(page)
             split += self.tokenizer.encode(page_text)
@@ -358,6 +361,7 @@ class DocumentParser(Parser):
                     metadata=DocMetaData(
                         source=f"{self.source} pages {pg}",
                         is_chunk=True,
+                        id=common_id,
                     ),
                 )
             )
@@ -372,6 +376,7 @@ class DocumentParser(Parser):
                     metadata=DocMetaData(
                         source=f"{self.source} pages {pg}",
                         is_chunk=True,
+                        id=common_id,
                     ),
                 )
             )
```
langroid/parsing/parser.py
CHANGED
```diff
@@ -7,6 +7,7 @@ import tiktoken
 from langroid.mytypes import Document
 from langroid.parsing.para_sentence_split import create_chunks, remove_extra_whitespace
 from langroid.pydantic_v1 import BaseSettings
+from langroid.utils.object_registry import ObjectRegistry
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -75,11 +76,13 @@ class Parser:
             return
         # The original metadata.id (if any) is ignored since it will be same for all
         # chunks and is useless. We want a distinct id for each chunk.
+        # ASSUMPTION: all chunks c of a doc have same c.metadata.id !
         orig_ids = [c.metadata.id for c in chunks]
-        ids = [
+        ids = [ObjectRegistry.new_id() for c in chunks]
         id2chunk = {id: c for id, c in zip(ids, chunks)}
 
         # group the ids by orig_id
+        # (each distinct orig_id refers to a different document)
         orig_id_to_ids: Dict[str, List[str]] = {}
         for orig_id, id in zip(orig_ids, ids):
             if orig_id not in orig_id_to_ids:
@@ -108,6 +111,10 @@ class Parser:
             if d.content.strip() == "":
                 continue
             chunks = remove_extra_whitespace(d.content).split(self.config.separators[0])
+            # note we are ensuring we COPY the document metadata into each chunk,
+            # which ensures all chunks of a given doc have same metadata
+            # (and in particular same metadata.id, which is important later for
+            # add_window_ids)
             chunk_docs = [
                 Document(
                     content=c, metadata=d.metadata.copy(update=dict(is_chunk=True))
@@ -156,6 +163,10 @@ class Parser:
             if d.content.strip() == "":
                 continue
             chunks = create_chunks(d.content, self.config.chunk_size, self.num_tokens)
+            # note we are ensuring we COPY the document metadata into each chunk,
+            # which ensures all chunks of a given doc have same metadata
+            # (and in particular same metadata.id, which is important later for
+            # add_window_ids)
             chunk_docs = [
                 Document(
                     content=c, metadata=d.metadata.copy(update=dict(is_chunk=True))
@@ -171,6 +182,10 @@ class Parser:
         final_docs = []
         for d in docs:
             chunks = self.chunk_tokens(d.content)
+            # note we are ensuring we COPY the document metadata into each chunk,
+            # which ensures all chunks of a given doc have same metadata
+            # (and in particular same metadata.id, which is important later for
+            # add_window_ids)
             chunk_docs = [
                 Document(
                     content=c, metadata=d.metadata.copy(update=dict(is_chunk=True))
@@ -274,7 +289,7 @@ class Parser:
         # we need this to distinguish docs later in add_window_ids
         for d in docs:
             if d.metadata.id in [None, ""]:
-                d.metadata.id =
+                d.metadata.id = ObjectRegistry.new_id()
         # some docs are already splits, so don't split them further!
        chunked_docs = [d for d in docs if d.metadata.is_chunk]
        big_docs = [d for d in docs if not d.metadata.is_chunk]
```
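The repeated comments above encode one invariant: chunking copies each source document's metadata (including `metadata.id`) into every chunk, and `add_window_ids` then remaps those shared ids to fresh per-chunk ids via `ObjectRegistry.new_id()`. A hedged sketch, assuming `Parser.split` is the public chunking entry point and `ParsingConfig` accepts `chunk_size`:

```python
from langroid.mytypes import DocMetaData, Document
from langroid.parsing.parser import Parser, ParsingConfig

parser = Parser(ParsingConfig(chunk_size=30))
doc = Document(content="some long text " * 200, metadata=DocMetaData(source="test"))

chunks = parser.split([doc])
# after add_window_ids, every chunk carries a distinct fresh id ...
assert len({c.metadata.id for c in chunks}) == len(chunks)
# ... and window_ids link each chunk to its neighbors in the same doc
assert all(isinstance(c.metadata.window_ids, list) for c in chunks)
```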
langroid/utils/__init__.py
CHANGED
```diff
@@ -5,6 +5,7 @@ from . import logging
 from . import pydantic_utils
 from . import system
 from . import output
+from . import object_registry
 
 __all__ = [
     "configuration",
@@ -14,4 +15,5 @@ __all__ = [
     "pydantic_utils",
     "system",
     "output",
+    "object_registry",
 ]
```
langroid/utils/constants.py
CHANGED
```diff
@@ -13,10 +13,11 @@ class Colors(BaseModel):
     RESET: str = "\033[0m"
 
 
-USER_QUIT_STRINGS = ["q", "x", "quit", "exit", "bye"]
 NO_ANSWER = "DO-NOT-KNOW"
 DONE = "DONE"
+USER_QUIT_STRINGS = ["q", "x", "quit", "exit", "bye", DONE]
 PASS = "__PASS__"
 PASS_TO = PASS + ":"
 SEND_TO = "SEND:"
 TOOL = "TOOL"
+AT = "@"
```
langroid/utils/object_registry.py
ADDED
```diff
@@ -0,0 +1,66 @@
+import time
+from typing import TYPE_CHECKING, Dict, Optional, TypeAlias, TypeVar
+from uuid import uuid4
+
+from langroid.pydantic_v1 import BaseModel
+
+if TYPE_CHECKING:
+    from langroid.agent.base import Agent
+    from langroid.agent.chat_agent import ChatAgent
+    from langroid.agent.chat_document import ChatDocument
+
+    # any derivative of BaseModel that has an id() method or an id attribute
+    ObjWithId: TypeAlias = ChatDocument | ChatAgent | Agent
+else:
+    ObjWithId = BaseModel
+
+# Define a type variable that can be any subclass of BaseModel
+T = TypeVar("T", bound=BaseModel)
+
+
+class ObjectRegistry:
+    """A global registry to hold id -> object mappings."""
+
+    registry: Dict[str, ObjWithId] = {}
+
+    @classmethod
+    def add(cls, obj: ObjWithId) -> str:
+        """Adds an object to the registry, returning the object's ID."""
+        object_id = obj.id() if callable(obj.id) else obj.id
+        cls.registry[object_id] = obj
+        return object_id
+
+    @classmethod
+    def get(cls, obj_id: str) -> Optional[ObjWithId]:
+        """Retrieves an object by ID if it still exists."""
+        return cls.registry.get(obj_id)
+
+    @classmethod
+    def register_object(cls, obj: ObjWithId) -> str:
+        """Registers an object in the registry, returning the object's ID."""
+        return cls.add(obj)
+
+    @classmethod
+    def remove(cls, obj_id: str) -> None:
+        """Removes an object from the registry."""
+        if obj_id in cls.registry:
+            del cls.registry[obj_id]
+
+    @classmethod
+    def cleanup(cls) -> None:
+        """Cleans up the registry by removing entries where the object is None."""
+        to_remove = [key for key, value in cls.registry.items() if value is None]
+        for key in to_remove:
+            del cls.registry[key]
+
+    @staticmethod
+    def new_id() -> str:
+        """Generates a new unique ID."""
+        return str(uuid4())
+
+
+def scheduled_cleanup(interval: int = 600) -> None:
+    """Periodically cleans up the global registry every `interval` seconds."""
+    while True:
+        ObjectRegistry.cleanup()
+        time.sleep(interval)
```
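A minimal usage sketch of the registry; `MyObj` is a hypothetical stand-in for the `ChatDocument`/`ChatAgent`/`Agent` types that actually get registered:

```python
from langroid.pydantic_v1 import BaseModel
from langroid.utils.object_registry import ObjectRegistry

class MyObj(BaseModel):
    id: str = ""  # add() accepts an `id` attribute or an id() method

obj = MyObj(id=ObjectRegistry.new_id())
obj_id = ObjectRegistry.add(obj)          # returns the object's id
assert ObjectRegistry.get(obj_id) is obj  # global id -> object lookup

ObjectRegistry.remove(obj_id)
assert ObjectRegistry.get(obj_id) is None
```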
langroid/utils/system.py
CHANGED
langroid/vector_store/base.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from langroid.mytypes import Document
 from langroid.pydantic_v1 import BaseSettings
 from langroid.utils.algorithms.graph import components, topological_sort
 from langroid.utils.configuration import settings
+from langroid.utils.object_registry import ObjectRegistry
 from langroid.utils.output.printing import print_long_text
 from langroid.utils.pandas_utils import stringify
 
@@ -163,7 +164,7 @@ class VectorStore(ABC):
         vecdbs don't like having blank ids."""
         for d in documents:
             if d.metadata.id in [None, ""]:
-                d.metadata.id =
+                d.metadata.id = ObjectRegistry.new_id()
 
     @abstractmethod
     def similar_texts_with_scores(
@@ -254,7 +255,7 @@ class VectorStore(ABC):
                 metadata=metadata,
             )
             # make a fresh id since content is in general different
-            document.metadata.id =
+            document.metadata.id = ObjectRegistry.new_id()
             final_docs += [document]
             final_scores += [max(id2max_score[id] for id in w)]
         return list(zip(final_docs, final_scores))
```
|