langroid 0.1.263__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +15 -1
- langroid/agent/chat_agent.py +68 -16
- langroid/agent/chat_document.py +57 -3
- langroid/agent/special/doc_chat_agent.py +8 -26
- langroid/agent/task.py +163 -32
- langroid/agent/tools/__init__.py +4 -0
- langroid/agent/tools/rewind_tool.py +136 -0
- langroid/cachedb/redis_cachedb.py +8 -4
- langroid/language_models/__init__.py +3 -0
- langroid/language_models/base.py +23 -4
- langroid/language_models/mock_lm.py +96 -0
- langroid/language_models/utils.py +2 -1
- langroid/mytypes.py +4 -35
- langroid/parsing/document_parser.py +5 -0
- langroid/parsing/parser.py +17 -2
- langroid/utils/__init__.py +2 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/system.py +11 -2
- langroid/vector_store/base.py +3 -2
- langroid/vector_store/lancedb.py +32 -23
- {langroid-0.1.263.dist-info → langroid-0.2.0.dist-info}/METADATA +5 -8
- {langroid-0.1.263.dist-info → langroid-0.2.0.dist-info}/RECORD +25 -23
- pyproject.toml +3 -6
- langroid/language_models/openai_assistants.py +0 -3
- {langroid-0.1.263.dist-info → langroid-0.2.0.dist-info}/LICENSE +0 -0
- {langroid-0.1.263.dist-info → langroid-0.2.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py
CHANGED
@@ -40,9 +40,10 @@ from langroid.mytypes import Entity
|
|
40
40
|
from langroid.parsing.parse_json import extract_top_level_json
|
41
41
|
from langroid.parsing.parser import Parser, ParsingConfig
|
42
42
|
from langroid.prompts.prompts_config import PromptsConfig
|
43
|
-
from langroid.pydantic_v1 import BaseSettings, ValidationError, validator
|
43
|
+
from langroid.pydantic_v1 import BaseSettings, Field, ValidationError, validator
|
44
44
|
from langroid.utils.configuration import settings
|
45
45
|
from langroid.utils.constants import NO_ANSWER
|
46
|
+
from langroid.utils.object_registry import ObjectRegistry
|
46
47
|
from langroid.utils.output import status
|
47
48
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
48
49
|
|
@@ -64,6 +65,7 @@ class AgentConfig(BaseSettings):
|
|
64
65
|
parsing: Optional[ParsingConfig] = ParsingConfig()
|
65
66
|
prompts: Optional[PromptsConfig] = PromptsConfig()
|
66
67
|
show_stats: bool = True # show token usage/cost stats?
|
68
|
+
add_to_registry: bool = True # register agent in ObjectRegistry?
|
67
69
|
|
68
70
|
@validator("name")
|
69
71
|
def check_name_alphanum(cls, v: str) -> str:
|
@@ -90,6 +92,8 @@ class Agent(ABC):
|
|
90
92
|
information about any tool/function-calling messages that have been defined.
|
91
93
|
"""
|
92
94
|
|
95
|
+
id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
|
96
|
+
|
93
97
|
def __init__(self, config: AgentConfig = AgentConfig()):
|
94
98
|
self.config = config
|
95
99
|
self.lock = asyncio.Lock() # for async access to update self.llm.usage_cost
|
@@ -114,6 +118,8 @@ class Agent(ABC):
|
|
114
118
|
self.parser: Optional[Parser] = (
|
115
119
|
Parser(config.parsing) if config.parsing else None
|
116
120
|
)
|
121
|
+
if config.add_to_registry:
|
122
|
+
ObjectRegistry.register_object(self)
|
117
123
|
|
118
124
|
self.callbacks = SimpleNamespace(
|
119
125
|
start_llm_stream=lambda: noop_fn,
|
@@ -128,6 +134,14 @@ class Agent(ABC):
|
|
128
134
|
show_start_response=noop_fn,
|
129
135
|
)
|
130
136
|
|
137
|
+
@staticmethod
|
138
|
+
def from_id(id: str) -> "Agent":
|
139
|
+
return cast(Agent, ObjectRegistry.get(id))
|
140
|
+
|
141
|
+
@staticmethod
|
142
|
+
def delete_id(id: str) -> None:
|
143
|
+
ObjectRegistry.remove(id)
|
144
|
+
|
131
145
|
def entity_responders(
|
132
146
|
self,
|
133
147
|
) -> List[
|
langroid/agent/chat_agent.py
CHANGED
@@ -21,6 +21,7 @@ from langroid.language_models.base import (
|
|
21
21
|
)
|
22
22
|
from langroid.language_models.openai_gpt import OpenAIGPT
|
23
23
|
from langroid.utils.configuration import settings
|
24
|
+
from langroid.utils.object_registry import ObjectRegistry
|
24
25
|
from langroid.utils.output import status
|
25
26
|
|
26
27
|
console = Console()
|
@@ -137,11 +138,22 @@ class ChatAgent(Agent):
|
|
137
138
|
self.llm_functions_usable: Set[str] = set()
|
138
139
|
self.llm_function_force: Optional[Dict[str, str]] = None
|
139
140
|
|
141
|
+
@staticmethod
|
142
|
+
def from_id(id: str) -> "ChatAgent":
|
143
|
+
"""
|
144
|
+
Get an agent from its ID
|
145
|
+
Args:
|
146
|
+
agent_id (str): ID of the agent
|
147
|
+
Returns:
|
148
|
+
ChatAgent: The agent with the given ID
|
149
|
+
"""
|
150
|
+
return cast(ChatAgent, Agent.from_id(id))
|
151
|
+
|
140
152
|
def clone(self, i: int = 0) -> "ChatAgent":
|
141
153
|
"""Create i'th clone of this agent, ensuring tool use/handling is cloned.
|
142
154
|
Important: We assume all member variables are in the __init__ method here
|
143
155
|
and in the Agent class.
|
144
|
-
TODO: We are attempting to
|
156
|
+
TODO: We are attempting to clone an agent after its state has been
|
145
157
|
changed in possibly many ways. Below is an imperfect solution. Caution advised.
|
146
158
|
Revisit later.
|
147
159
|
"""
|
@@ -158,6 +170,9 @@ class ChatAgent(Agent):
|
|
158
170
|
new_agent.llm_function_force = self.llm_function_force
|
159
171
|
# Caution - we are copying the vector-db, maybe we don't always want this?
|
160
172
|
new_agent.vecdb = self.vecdb
|
173
|
+
new_agent.id = ObjectRegistry.new_id()
|
174
|
+
if self.config.add_to_registry:
|
175
|
+
ObjectRegistry.register_object(new_agent)
|
161
176
|
return new_agent
|
162
177
|
|
163
178
|
def _fn_call_available(self) -> bool:
|
@@ -202,6 +217,10 @@ class ChatAgent(Agent):
|
|
202
217
|
if start < 0:
|
203
218
|
n = len(self.message_history)
|
204
219
|
start = max(0, n + start)
|
220
|
+
dropped = self.message_history[start:]
|
221
|
+
for msg in dropped:
|
222
|
+
# clear out the chat document from the ObjectRegistry
|
223
|
+
ChatDocument.delete_id(msg.chat_document_id)
|
205
224
|
self.message_history = self.message_history[:start]
|
206
225
|
|
207
226
|
def update_history(self, message: str, response: str) -> None:
|
@@ -310,10 +329,24 @@ class ChatAgent(Agent):
|
|
310
329
|
|
311
330
|
def last_message_with_role(self, role: Role) -> LLMMessage | None:
|
312
331
|
"""from `message_history`, return the last message with role `role`"""
|
313
|
-
for
|
314
|
-
|
315
|
-
|
316
|
-
|
332
|
+
n_role_msgs = len([m for m in self.message_history if m.role == role])
|
333
|
+
if n_role_msgs == 0:
|
334
|
+
return None
|
335
|
+
idx = self.nth_message_idx_with_role(role, n_role_msgs)
|
336
|
+
return self.message_history[idx]
|
337
|
+
|
338
|
+
def nth_message_idx_with_role(self, role: Role, n: int) -> int:
|
339
|
+
"""Index of `n`th message in message_history, with specified role.
|
340
|
+
(n is assumed to be 1-based, i.e. 1 is the first message with that role).
|
341
|
+
Return -1 if not found. Index = 0 is the first message in the history.
|
342
|
+
"""
|
343
|
+
indices_with_role = [
|
344
|
+
i for i, m in enumerate(self.message_history) if m.role == role
|
345
|
+
]
|
346
|
+
|
347
|
+
if len(indices_with_role) < n:
|
348
|
+
return -1
|
349
|
+
return indices_with_role[n - 1]
|
317
350
|
|
318
351
|
def update_last_message(self, message: str, role: str = Role.USER) -> None:
|
319
352
|
"""
|
@@ -488,9 +521,9 @@ class ChatAgent(Agent):
|
|
488
521
|
return None
|
489
522
|
with StreamingIfAllowed(self.llm, self.llm.get_stream()):
|
490
523
|
response = self.llm_response_messages(hist, output_len)
|
491
|
-
# TODO - when response contains function_call we should include
|
492
|
-
# that (and related fields) in the message_history
|
493
524
|
self.message_history.append(ChatDocument.to_LLMMessage(response))
|
525
|
+
response.metadata.msg_idx = len(self.message_history) - 1
|
526
|
+
response.metadata.agent_id = self.id
|
494
527
|
# Preserve trail of tool_ids for OpenAI Assistant fn-calls
|
495
528
|
response.metadata.tool_ids = (
|
496
529
|
[]
|
@@ -511,9 +544,9 @@ class ChatAgent(Agent):
|
|
511
544
|
hist, output_len = self._prep_llm_messages(message)
|
512
545
|
with StreamingIfAllowed(self.llm, self.llm.get_stream()):
|
513
546
|
response = await self.llm_response_messages_async(hist, output_len)
|
514
|
-
# TODO - when response contains function_call we should include
|
515
|
-
# that (and related fields) in the message_history
|
516
547
|
self.message_history.append(ChatDocument.to_LLMMessage(response))
|
548
|
+
response.metadata.msg_idx = len(self.message_history) - 1
|
549
|
+
response.metadata.agent_id = self.id
|
517
550
|
# Preserve trail of tool_ids for OpenAI Assistant fn-calls
|
518
551
|
response.metadata.tool_ids = (
|
519
552
|
[]
|
@@ -522,6 +555,16 @@ class ChatAgent(Agent):
|
|
522
555
|
)
|
523
556
|
return response
|
524
557
|
|
558
|
+
def init_message_history(self) -> None:
|
559
|
+
"""
|
560
|
+
Initialize the message history with the system message and user message
|
561
|
+
"""
|
562
|
+
self.message_history = [self._create_system_and_tools_message()]
|
563
|
+
if self.user_message:
|
564
|
+
self.message_history.append(
|
565
|
+
LLMMessage(role=Role.USER, content=self.user_message)
|
566
|
+
)
|
567
|
+
|
525
568
|
def _prep_llm_messages(
|
526
569
|
self,
|
527
570
|
message: Optional[str | ChatDocument] = None,
|
@@ -555,11 +598,7 @@ class ChatAgent(Agent):
|
|
555
598
|
|
556
599
|
if len(self.message_history) == 0:
|
557
600
|
# initial messages have not yet been loaded, so load them
|
558
|
-
self.
|
559
|
-
if self.user_message:
|
560
|
-
self.message_history.append(
|
561
|
-
LLMMessage(role=Role.USER, content=self.user_message)
|
562
|
-
)
|
601
|
+
self.init_message_history()
|
563
602
|
|
564
603
|
# for debugging, show the initial message history
|
565
604
|
if settings.debug:
|
@@ -576,8 +615,14 @@ class ChatAgent(Agent):
|
|
576
615
|
self.message_history[0] = self._create_system_and_tools_message()
|
577
616
|
|
578
617
|
if message is not None:
|
579
|
-
|
580
|
-
|
618
|
+
if (
|
619
|
+
isinstance(message, str)
|
620
|
+
or message.id() != self.message_history[-1].chat_document_id
|
621
|
+
):
|
622
|
+
# either the message is a str, or it is a fresh ChatDocument
|
623
|
+
# different from the last message in the history
|
624
|
+
llm_msg = ChatDocument.to_LLMMessage(message)
|
625
|
+
self.message_history.append(llm_msg)
|
581
626
|
|
582
627
|
hist = self.message_history
|
583
628
|
output_len = self.config.llm.max_output_tokens
|
@@ -614,6 +659,7 @@ class ChatAgent(Agent):
|
|
614
659
|
)
|
615
660
|
# drop the second message, i.e. first msg after the sys msg
|
616
661
|
# (typically user msg).
|
662
|
+
ChatDocument.delete_id(hist[1].chat_document_id)
|
617
663
|
hist = hist[:1] + hist[2:]
|
618
664
|
|
619
665
|
if len(hist) < len(self.message_history):
|
@@ -650,6 +696,12 @@ class ChatAgent(Agent):
|
|
650
696
|
and the response may be truncated.
|
651
697
|
"""
|
652
698
|
)
|
699
|
+
if isinstance(message, ChatDocument):
|
700
|
+
# record the position of the corresponding LLMMessage in
|
701
|
+
# the message_history
|
702
|
+
message.metadata.msg_idx = len(hist) - 1
|
703
|
+
message.metadata.agent_id = self.id
|
704
|
+
|
653
705
|
return hist, output_len
|
654
706
|
|
655
707
|
def _function_args(
|
langroid/agent/chat_document.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import copy
|
1
4
|
import json
|
2
5
|
from enum import Enum
|
3
|
-
from typing import List, Optional, Union
|
6
|
+
from typing import Any, List, Optional, Union, cast
|
4
7
|
|
5
8
|
from langroid.agent.tool_message import ToolMessage
|
6
9
|
from langroid.language_models.base import (
|
@@ -14,6 +17,7 @@ from langroid.mytypes import DocMetaData, Document, Entity
|
|
14
17
|
from langroid.parsing.agent_chats import parse_message
|
15
18
|
from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
|
16
19
|
from langroid.pydantic_v1 import BaseModel, Extra
|
20
|
+
from langroid.utils.object_registry import ObjectRegistry
|
17
21
|
from langroid.utils.output.printing import shorten_text
|
18
22
|
|
19
23
|
|
@@ -41,8 +45,11 @@ class StatusCode(str, Enum):
|
|
41
45
|
|
42
46
|
|
43
47
|
class ChatDocMetaData(DocMetaData):
|
44
|
-
|
45
|
-
|
48
|
+
parent_id: str = "" # msg (ChatDocument) to which this is a response
|
49
|
+
child_id: str = "" # ChatDocument that has response to this message
|
50
|
+
agent_id: str = "" # ChatAgent that generated this message
|
51
|
+
msg_idx: int = -1 # index of this message in the agent `message_history`
|
52
|
+
sender: Entity # sender of the message
|
46
53
|
tool_ids: List[str] = [] # stack of tool_ids; used by OpenAIAssistant
|
47
54
|
block: None | Entity = None
|
48
55
|
sender_name: str = ""
|
@@ -53,6 +60,14 @@ class ChatDocMetaData(DocMetaData):
|
|
53
60
|
has_citation: bool = False
|
54
61
|
status: Optional[StatusCode] = None
|
55
62
|
|
63
|
+
@property
|
64
|
+
def parent(self) -> Optional["ChatDocument"]:
|
65
|
+
return ChatDocument.from_id(self.parent_id)
|
66
|
+
|
67
|
+
@property
|
68
|
+
def child(self) -> Optional["ChatDocument"]:
|
69
|
+
return ChatDocument.from_id(self.child_id)
|
70
|
+
|
56
71
|
|
57
72
|
class ChatDocLoggerFields(BaseModel):
|
58
73
|
sender_entity: Entity = Entity.USER
|
@@ -75,6 +90,41 @@ class ChatDocument(Document):
|
|
75
90
|
metadata: ChatDocMetaData
|
76
91
|
attachment: None | ChatDocAttachment = None
|
77
92
|
|
93
|
+
def __init__(self, **data: Any):
|
94
|
+
super().__init__(**data)
|
95
|
+
ObjectRegistry.register_object(self)
|
96
|
+
|
97
|
+
@property
|
98
|
+
def parent(self) -> Optional["ChatDocument"]:
|
99
|
+
return ChatDocument.from_id(self.metadata.parent_id)
|
100
|
+
|
101
|
+
@property
|
102
|
+
def child(self) -> Optional["ChatDocument"]:
|
103
|
+
return ChatDocument.from_id(self.metadata.child_id)
|
104
|
+
|
105
|
+
@staticmethod
|
106
|
+
def deepcopy(doc: ChatDocument) -> ChatDocument:
|
107
|
+
new_doc = copy.deepcopy(doc)
|
108
|
+
new_doc.metadata.id = ObjectRegistry.new_id()
|
109
|
+
ObjectRegistry.register_object(new_doc)
|
110
|
+
return new_doc
|
111
|
+
|
112
|
+
@staticmethod
|
113
|
+
def from_id(id: str) -> Optional["ChatDocument"]:
|
114
|
+
return cast(ChatDocument, ObjectRegistry.get(id))
|
115
|
+
|
116
|
+
@staticmethod
|
117
|
+
def delete_id(id: str) -> None:
|
118
|
+
"""Remove ChatDocument with given id from ObjectRegistry,
|
119
|
+
and all its descendants.
|
120
|
+
"""
|
121
|
+
chat_doc = ChatDocument.from_id(id)
|
122
|
+
# first delete all descendants
|
123
|
+
while chat_doc is not None:
|
124
|
+
next_chat_doc = chat_doc.child
|
125
|
+
ObjectRegistry.remove(chat_doc.id())
|
126
|
+
chat_doc = next_chat_doc
|
127
|
+
|
78
128
|
def __str__(self) -> str:
|
79
129
|
fields = self.log_fields()
|
80
130
|
tool_str = ""
|
@@ -224,6 +274,7 @@ class ChatDocument(Document):
|
|
224
274
|
sender_role = Role.USER
|
225
275
|
fun_call = None
|
226
276
|
tool_id = ""
|
277
|
+
chat_document_id: str = ""
|
227
278
|
if isinstance(message, ChatDocument):
|
228
279
|
content = message.content
|
229
280
|
fun_call = message.function_call
|
@@ -240,6 +291,7 @@ class ChatDocument(Document):
|
|
240
291
|
sender_name = message.metadata.sender_name
|
241
292
|
tool_ids = message.metadata.tool_ids
|
242
293
|
tool_id = tool_ids[-1] if len(tool_ids) > 0 else ""
|
294
|
+
chat_document_id = message.id()
|
243
295
|
if message.metadata.sender == Entity.SYSTEM:
|
244
296
|
sender_role = Role.SYSTEM
|
245
297
|
if (
|
@@ -260,7 +312,9 @@ class ChatDocument(Document):
|
|
260
312
|
content=content,
|
261
313
|
function_call=fun_call,
|
262
314
|
name=sender_name,
|
315
|
+
chat_document_id=chat_document_id,
|
263
316
|
)
|
264
317
|
|
265
318
|
|
319
|
+
LLMMessage.update_forward_refs()
|
266
320
|
ChatDocMetaData.update_forward_refs()
|
langroid/agent/special/doc_chat_agent.py
CHANGED
@@ -35,7 +35,6 @@ from langroid.embedding_models.models import (
|
|
35
35
|
OpenAIEmbeddingsConfig,
|
36
36
|
SentenceTransformerEmbeddingsConfig,
|
37
37
|
)
|
38
|
-
from langroid.exceptions import LangroidImportError
|
39
38
|
from langroid.language_models.base import StreamingIfAllowed
|
40
39
|
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
|
41
40
|
from langroid.mytypes import DocMetaData, Document, Entity
|
@@ -54,6 +53,7 @@ from langroid.parsing.utils import batched
|
|
54
53
|
from langroid.prompts.prompts_config import PromptsConfig
|
55
54
|
from langroid.prompts.templates import SUMMARY_ANSWER_PROMPT_GPT4
|
56
55
|
from langroid.utils.constants import NO_ANSWER
|
56
|
+
from langroid.utils.object_registry import ObjectRegistry
|
57
57
|
from langroid.utils.output import show_if_debug, status
|
58
58
|
from langroid.utils.output.citations import (
|
59
59
|
extract_markdown_references,
|
@@ -101,29 +101,6 @@ oai_embed_config = OpenAIEmbeddingsConfig(
|
|
101
101
|
dims=1536,
|
102
102
|
)
|
103
103
|
|
104
|
-
vecdb_config: VectorStoreConfig = QdrantDBConfig(
|
105
|
-
collection_name="doc-chat-qdrantdb",
|
106
|
-
replace_collection=True,
|
107
|
-
storage_path=".qdrantdb/data/",
|
108
|
-
embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
|
109
|
-
)
|
110
|
-
|
111
|
-
try:
|
112
|
-
import lancedb
|
113
|
-
|
114
|
-
lancedb # appease mypy
|
115
|
-
from langroid.vector_store.lancedb import LanceDBConfig
|
116
|
-
|
117
|
-
vecdb_config = LanceDBConfig(
|
118
|
-
collection_name="doc-chat-lancedb",
|
119
|
-
replace_collection=True,
|
120
|
-
storage_path=".lancedb/data/",
|
121
|
-
embedding=(hf_embed_config if has_sentence_transformers else oai_embed_config),
|
122
|
-
)
|
123
|
-
|
124
|
-
except (ImportError, LangroidImportError):
|
125
|
-
pass
|
126
|
-
|
127
104
|
|
128
105
|
class DocChatAgentConfig(ChatAgentConfig):
|
129
106
|
system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
|
@@ -201,7 +178,12 @@ class DocChatAgentConfig(ChatAgentConfig):
|
|
201
178
|
)
|
202
179
|
|
203
180
|
# Allow vecdb to be None in case we want to explicitly set it later
|
204
|
-
vecdb: Optional[VectorStoreConfig] =
|
181
|
+
vecdb: Optional[VectorStoreConfig] = QdrantDBConfig(
|
182
|
+
collection_name="doc-chat-qdrantdb",
|
183
|
+
replace_collection=True,
|
184
|
+
storage_path=".qdrantdb/data/",
|
185
|
+
embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
|
186
|
+
)
|
205
187
|
|
206
188
|
llm: OpenAIGPTConfig = OpenAIGPTConfig(
|
207
189
|
type="openai",
|
@@ -414,7 +396,7 @@ class DocChatAgent(ChatAgent):
|
|
414
396
|
raise ValueError("Parser not set")
|
415
397
|
for d in docs:
|
416
398
|
if d.metadata.id in [None, ""]:
|
417
|
-
d.metadata.id =
|
399
|
+
d.metadata.id = ObjectRegistry.new_id()
|
418
400
|
if split:
|
419
401
|
docs = self.parser.split(docs)
|
420
402
|
else:
|