langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported registries. It is provided for informational purposes only.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/__init__.py
CHANGED
```diff
@@ -1,3 +1,98 @@
 """
 Main langroid package
 """
+
+from . import mytypes
+from . import utils
+
+from . import parsing
+from . import prompts
+from . import cachedb
+
+from . import language_models
+from . import embedding_models
+
+from . import vector_store
+from . import agent
+
+from .agent.base import (
+    Agent,
+    AgentConfig,
+)
+
+from .agent.batch import (
+    run_batch_tasks,
+    llm_response_batch,
+    agent_response_batch,
+)
+
+from .agent.chat_document import (
+    ChatDocument,
+    ChatDocMetaData,
+)
+
+from .agent.tool_message import (
+    ToolMessage,
+)
+
+from .agent.chat_agent import (
+    ChatAgent,
+    ChatAgentConfig,
+)
+
+from .agent.task import Task
+
+try:
+    from .agent.callbacks.chainlit import (
+        ChainlitAgentCallbacks,
+        ChainlitTaskCallbacks,
+        ChainlitCallbackConfig,
+    )
+
+    chainlit_available = True
+    ChainlitAgentCallbacks
+    ChainlitTaskCallbacks
+    ChainlitCallbackConfig
+except ImportError:
+    chainlit_available = False
+
+
+from .mytypes import (
+    DocMetaData,
+    Document,
+    Entity,
+)
+
+__all__ = [
+    "mytypes",
+    "utils",
+    "parsing",
+    "prompts",
+    "cachedb",
+    "language_models",
+    "embedding_models",
+    "vector_store",
+    "agent",
+    "Agent",
+    "AgentConfig",
+    "ChatAgent",
+    "ChatAgentConfig",
+    "ChatDocument",
+    "ChatDocMetaData",
+    "Task",
+    "DocMetaData",
+    "Document",
+    "Entity",
+    "ToolMessage",
+    "run_batch_tasks",
+    "llm_response_batch",
+    "agent_response_batch",
+]
+if chainlit_available:
+    __all__.extend(
+        [
+            "ChainlitAgentCallbacks",
+            "ChainlitTaskCallbacks",
+            "ChainlitCallbackConfig",
+        ]
+    )
```
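The rewritten top-level `__init__.py` re-exports the core classes, so downstream code needs only one import. A minimal sketch of what this enables (the agent name is illustrative, and OpenAI credentials are assumed to be configured, since `AgentConfig.llm` now defaults to `OpenAIGPTConfig()`):

```python
import langroid as lr

# core classes are now importable from the package root
agent = lr.ChatAgent(lr.ChatAgentConfig(name="Assistant"))
task = lr.Task(agent)

# the chainlit callbacks are exported only when chainlit is installed
if lr.chainlit_available:
    print("Chainlit callback classes are available at the top level")
```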
langroid/agent/__init__.py
CHANGED
```diff
@@ -0,0 +1,40 @@
+from .base import Agent, AgentConfig
+from .chat_document import (
+    ChatDocAttachment,
+    ChatDocMetaData,
+    ChatDocLoggerFields,
+    ChatDocument,
+)
+from .chat_agent import ChatAgentConfig, ChatAgent
+from .tool_message import ToolMessage
+from .task import Task
+
+from . import base
+from . import chat_document
+from . import chat_agent
+from . import task
+from . import batch
+from . import tool_message
+from . import tools
+from . import special
+
+__all__ = [
+    "Agent",
+    "AgentConfig",
+    "ChatDocAttachment",
+    "ChatDocMetaData",
+    "ChatDocLoggerFields",
+    "ChatDocument",
+    "ChatAgent",
+    "ChatAgentConfig",
+    "ToolMessage",
+    "Task",
+    "base",
+    "chat_document",
+    "chat_agent",
+    "task",
+    "batch",
+    "tool_message",
+    "tools",
+    "special",
+]
```
langroid/agent/base.py
CHANGED
```diff
@@ -1,8 +1,10 @@
+import asyncio
 import inspect
 import json
 import logging
 from abc import ABC
 from contextlib import ExitStack
+from types import SimpleNamespace
 from typing import (
     Any,
     Callable,
@@ -20,6 +22,7 @@ from typing import (
 from pydantic import BaseSettings, ValidationError
 from rich import print
 from rich.console import Console
+from rich.markup import escape
 from rich.prompt import Prompt
 
 from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
@@ -32,15 +35,17 @@ from langroid.language_models.base import (
     LLMTokenUsage,
     StreamingIfAllowed,
 )
-from langroid.mytypes import Entity
-from langroid.parsing.json import extract_top_level_json
+from langroid.language_models.openai_gpt import OpenAIGPTConfig
+from langroid.mytypes import Entity
+from langroid.parsing.parse_json import extract_top_level_json
 from langroid.parsing.parser import Parser, ParsingConfig
 from langroid.prompts.prompts_config import PromptsConfig
 from langroid.utils.configuration import settings
 from langroid.utils.constants import NO_ANSWER
+from langroid.utils.output import status
 from langroid.vector_store.base import VectorStore, VectorStoreConfig
 
-console = Console()
+console = Console(quiet=settings.quiet)
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -53,10 +58,15 @@ class AgentConfig(BaseSettings):
 
     name: str = "LLM-Agent"
     debug: bool = False
-    vecdb: Optional[VectorStoreConfig] =
-    llm: Optional[LLMConfig] =
+    vecdb: Optional[VectorStoreConfig] = None
+    llm: Optional[LLMConfig] = OpenAIGPTConfig()
     parsing: Optional[ParsingConfig] = ParsingConfig()
     prompts: Optional[PromptsConfig] = PromptsConfig()
+    show_stats: bool = True  # show token usage/cost stats?
+
+
+def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+    pass
 
 
 class Agent(ABC):
```
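With `vecdb` defaulting to `None` and `llm` to `OpenAIGPTConfig()`, a bare `Agent()` is now constructible, and the new `show_stats` flag controls the token-usage/cost line printed after each LLM response. A sketch of the resulting construction patterns (names illustrative; assumes OpenAI credentials in the environment):

```python
from langroid.agent.base import Agent, AgentConfig

agent = Agent()  # all defaults: no vector store, OpenAI LLM config
quiet_agent = Agent(AgentConfig(name="QuietAgent", show_stats=False))
```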
```diff
@@ -70,8 +80,9 @@ class Agent(ABC):
     information about any tool/function-calling messages that have been defined.
     """
 
-    def __init__(self, config: AgentConfig):
+    def __init__(self, config: AgentConfig = AgentConfig()):
         self.config = config
+        self.lock = asyncio.Lock()  # for async access to update self.llm.usage_cost
         self.dialog: List[Tuple[str, str]] = []  # seq of LLM (prompt, response) tuples
         self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
         self.llm_tools_handled: Set[str] = set()
@@ -86,6 +97,18 @@ class Agent(ABC):
         self.parser: Optional[Parser] = (
             Parser(config.parsing) if config.parsing else None
         )
+        self.callbacks = SimpleNamespace(
+            start_llm_stream=lambda: noop_fn,
+            cancel_llm_stream=noop_fn,
+            finish_llm_stream=noop_fn,
+            show_llm_response=noop_fn,
+            show_agent_response=noop_fn,
+            get_user_response=None,
+            get_last_step=noop_fn,
+            set_parent_agent=noop_fn,
+            show_error_message=noop_fn,
+            show_start_response=noop_fn,
+        )
 
     def entity_responders(
         self,
```
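The `callbacks` namespace is the integration point used by the new Chainlit layer (`langroid/agent/callbacks/chainlit.py` in the file list above): every hook is a no-op until a UI overwrites it. A sketch of how a custom front end might plug in; the handler bodies here are illustrative:

```python
from langroid.agent.base import Agent

agent = Agent()

def show_agent_response(content: str, language: str = "text") -> None:
    # called with language="json" when the content parses as top-level JSON
    print(f"[agent:{language}] {content}")

agent.callbacks.show_agent_response = show_agent_response

# get_user_response is None by default; when set, user_response() routes
# through it instead of the console Prompt.ask(...)
agent.callbacks.get_user_response = lambda prompt: "looks good, continue"
```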
```diff
@@ -139,6 +162,9 @@ class Agent(ABC):
     def get_dialog(self) -> List[Tuple[str, str]]:
         return self.dialog
 
+    def clear_dialog(self) -> None:
+        self.dialog = []
+
     def _get_tool_list(
         self, message_class: Optional[Type[ToolMessage]] = None
     ) -> List[str]:
@@ -246,6 +272,10 @@ class Agent(ABC):
         ]
         return "\n\n".join(sample_convo)
 
+    def agent_response_template(self) -> ChatDocument:
+        """Template for agent_response."""
+        return self._response_template(Entity.AGENT)
+
     async def agent_response_async(
         self,
         msg: Optional[str | ChatDocument] = None,
```
```diff
@@ -275,9 +305,19 @@ class Agent(ABC):
         if results is None:
             return None
         if isinstance(results, ChatDocument):
+            # Preserve trail of tool_ids for OpenAI Assistant fn-calls
+            results.metadata.tool_ids = (
+                [] if isinstance(msg, str) else msg.metadata.tool_ids
+            )
             return results
-        console.print(f"[red]{self.indent}", end="")
-        print(f"[red]Agent: {results}")
+        if not settings.quiet:
+            console.print(f"[red]{self.indent}", end="")
+            print(f"[red]Agent: {results}")
+        maybe_json = len(extract_top_level_json(results)) > 0
+        self.callbacks.show_agent_response(
+            content=results,
+            language="json" if maybe_json else "text",
+        )
         sender_name = self.config.name
         if isinstance(msg, ChatDocument) and msg.function_call is not None:
             # if result was from handling an LLM `function_call`,
@@ -290,9 +330,25 @@ class Agent(ABC):
                 source=Entity.AGENT,
                 sender=Entity.AGENT,
                 sender_name=sender_name,
+                # preserve trail of tool_ids for OpenAI Assistant fn-calls
+                tool_ids=[] if isinstance(msg, str) else msg.metadata.tool_ids,
+            ),
+        )
+
+    def _response_template(self, e: Entity) -> ChatDocument:
+        """Template for response from entity `e`."""
+        return ChatDocument(
+            content="",
+            tool_messages=[],
+            metadata=ChatDocMetaData(
+                source=e, sender=e, sender_name=self.config.name, tool_ids=[]
             ),
         )
 
+    def user_response_template(self) -> ChatDocument:
+        """Template for user_response."""
+        return self._response_template(Entity.USER)
+
     async def user_response_async(
         self,
         msg: Optional[str | ChatDocument] = None,
```
```diff
@@ -320,12 +376,22 @@
         elif not settings.interactive:
             user_msg = ""
         else:
-
-
-
-
-
-
+            if self.callbacks.get_user_response is not None:
+                # ask user with empty prompt: no need for prompt
+                # since user has seen the conversation so far.
+                # But non-empty prompt can be useful when Agent
+                # uses a tool that requires user input, or in other scenarios.
+                user_msg = self.callbacks.get_user_response(prompt="")
+            else:
+                user_msg = Prompt.ask(
+                    f"[blue]{self.indent}Human "
+                    "(respond or q, x to exit current level, "
+                    f"or hit enter to continue)\n{self.indent}",
+                ).strip()
+
+        tool_ids = []
+        if msg is not None and isinstance(msg, ChatDocument):
+            tool_ids = msg.metadata.tool_ids
         # only return non-None result if user_msg not empty
         if not user_msg:
             return None
@@ -339,9 +405,11 @@
             sender = Entity.USER
         return ChatDocument(
             content=user_msg,
-            metadata=
+            metadata=ChatDocMetaData(
                 source=source,
                 sender=sender,
+                # preserve trail of tool_ids for OpenAI Assistant fn-calls
+                tool_ids=tool_ids,
             ),
         )
 
```
|
|
358
426
|
if self.llm is None:
|
359
427
|
return False
|
360
428
|
|
361
|
-
if isinstance(message, ChatDocument) and message.function_call is not None:
|
362
|
-
# LLM should not handle `function_call` messages,
|
363
|
-
# EVEN if message.function_call is not a legit function_call
|
364
|
-
# The OpenAI API raises error if there is a message in history
|
365
|
-
# from a non-Assistant role, with a `function_call` in it
|
366
|
-
return False
|
367
|
-
|
368
429
|
if message is not None and len(self.get_tool_messages(message)) > 0:
|
369
430
|
# if there is a valid "tool" message (either JSON or via `function_call`)
|
370
431
|
# then LLM cannot respond to it
|
@@ -372,6 +433,10 @@ class Agent(ABC):
|
|
372
433
|
|
373
434
|
return True
|
374
435
|
|
436
|
+
def llm_response_template(self) -> ChatDocument:
|
437
|
+
"""Template for llm_response."""
|
438
|
+
return self._response_template(Entity.LLM)
|
439
|
+
|
375
440
|
@no_type_check
|
376
441
|
async def llm_response_async(
|
377
442
|
self,
|
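The three `*_response_template()` helpers all delegate to `_response_template`, returning an empty `ChatDocument` already stamped with the right sender entity, for callers that need a correctly attributed blank response. For example:

```python
from langroid.agent.base import Agent
from langroid.mytypes import Entity

# per _response_template above: empty content, sender stamped as LLM
doc = Agent().llm_response_template()
assert doc.content == "" and doc.metadata.sender == Entity.LLM
```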
```diff
@@ -410,18 +475,24 @@
             with StreamingIfAllowed(self.llm, self.llm.get_stream()):
                 response = await self.llm.agenerate(prompt, output_len)
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if not self.llm.get_stream() or response.cached and not settings.quiet:
+            # We would have already displayed the msg "live" ONLY if
+            # streaming was enabled, AND we did not find a cached response.
+            # If we are here, it means the response has not yet been displayed.
+            cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
+            print(cached + "[green]" + escape(response.message))
+        async with self.lock:
+            self.update_token_usage(
+                response,
+                prompt,
+                self.llm.get_stream(),
+                chat=False,  # i.e. it's a completion model not chat model
+                print_response_stats=self.config.show_stats and not settings.quiet,
+            )
+        cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
+        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
+        cdoc.metadata.tool_ids = [] if isinstance(msg, str) else msg.metadata.tool_ids
+        return cdoc
 
     @no_type_check
     def llm_response(
```
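Note the `async with self.lock:` around the usage update: together with the `asyncio.Lock` created in `__init__`, it keeps the agent's cumulative token/cost counters consistent when many async responses run concurrently, e.g. when fanning requests out as the new `batch.py` utilities do. A sketch, assuming a configured LLM:

```python
import asyncio

async def fan_out(agent, prompts):
    # concurrent completions; the per-agent lock serializes the
    # update_token_usage() bookkeeping across coroutines
    return await asyncio.gather(*(agent.llm_response_async(p) for p in prompts))

# results = asyncio.run(fan_out(agent, ["2+2=?", "capital of France?"]))
```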
```diff
@@ -447,7 +518,7 @@
         with ExitStack() as stack:  # for conditionally using rich spinner
             if not self.llm.get_stream():
                 # show rich spinner only if not streaming!
-                cm =
+                cm = status("LLM responding to message...")
                 stack.enter_context(cm)
             output_len = self.config.llm.max_output_tokens
             if (
```
```diff
@@ -472,36 +543,61 @@
                 the completion context length of the LLM.
                 """
             )
-        if self.llm.get_stream():
+        if self.llm.get_stream() and not settings.quiet:
             console.print(f"[green]{self.indent}", end="")
         response = self.llm.generate(prompt, output_len)
 
-
-        if not self.llm.get_stream() or response.cached:
+        if not self.llm.get_stream() or response.cached and not settings.quiet:
             # we would have already displayed the msg "live" ONLY if
             # streaming was enabled, AND we did not find a cached response
+            # If we are here, it means the response has not yet been displayed.
+            cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
             console.print(f"[green]{self.indent}", end="")
-            print("[green]" + response.message)
-            displayed = True
+            print(cached + "[green]" + escape(response.message))
         self.update_token_usage(
             response,
             prompt,
             self.llm.get_stream(),
-
+            chat=False,  # i.e. it's a completion model not chat model
+            print_response_stats=self.config.show_stats and not settings.quiet,
         )
-
+        cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
+        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
+        cdoc.metadata.tool_ids = [] if isinstance(msg, str) else msg.metadata.tool_ids
+        return cdoc
+
+    def has_tool_message_attempt(self, msg: str | ChatDocument | None) -> bool:
+        """Check whether msg contains a Tool/fn-call attempt (by the LLM)"""
+        if msg is None:
+            return False
+        try:
+            tools = self.get_tool_messages(msg)
+            return len(tools) > 0
+        except ValidationError:
+            # there is a tool/fn-call attempt but had a validation error,
+            # so we still consider this a tool message "attempt"
+            return True
+        return False
 
     def get_tool_messages(self, msg: str | ChatDocument) -> List[ToolMessage]:
         if isinstance(msg, str):
             return self.get_json_tool_messages(msg)
+        if len(msg.tool_messages) > 0:
+            # We've already found tool_messages
+            # (either via OpenAI Fn-call or Langroid-native ToolMessage)
+            return msg.tool_messages
         assert isinstance(msg, ChatDocument)
         # when `content` is non-empty, we assume there will be no `function_call`
         if msg.content != "":
-
+            tools = self.get_json_tool_messages(msg.content)
+            msg.tool_messages = tools
+            return tools
 
         # otherwise, we look for a `function_call`
         fun_call_cls = self.get_function_call_class(msg)
-
+        tools = [fun_call_cls] if fun_call_cls is not None else []
+        msg.tool_messages = tools
+        return tools
 
     def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
         """
```
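`has_tool_message_attempt` deliberately counts a tool call that fails validation as an attempt, and `get_tool_messages` now caches its findings on `msg.tool_messages`. A sketch with a hypothetical tool (`SquareTool` is illustrative; `enable_message` is the existing `ChatAgent` registration method, and OpenAI credentials are assumed):

```python
import langroid as lr

class SquareTool(lr.ToolMessage):
    request: str = "square"
    purpose: str = "To compute the square of a <number>"
    number: float

agent = lr.ChatAgent(lr.ChatAgentConfig(name="Calc"))
agent.enable_message(SquareTool)

good = 'calling: {"request": "square", "number": 5}'
bad = 'calling: {"request": "square", "number": "five"}'
print(agent.has_tool_message_attempt(good))  # True: valid tool message
print(agent.has_tool_message_attempt(bad))   # True: attempt, despite bad field
```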
```diff
@@ -525,7 +621,17 @@
         tool_name = msg.function_call.name
         tool_msg = msg.function_call.arguments or {}
         if tool_name not in self.llm_tools_handled:
-            return None
+            logger.warning(
+                f"""
+                The function_call '{tool_name}' is not handled
+                by the agent named '{self.config.name}'!
+                If you intended this agent to handle this function_call,
+                either the fn-call name is incorrectly generated by the LLM,
+                (in which case you may need to adjust your LLM instructions),
+                or you need to enable this agent to handle this fn-call.
+                """
+            )
+            return None
         tool_class = self.llm_tools_map[tool_name]
         tool_msg.update(dict(request=tool_name))
         tool = tool_class.parse_obj(tool_msg)
```
```diff
@@ -544,7 +650,7 @@
         """
         tool_name = cast(ToolMessage, ve.model).default_value("request")
         bad_field_errors = "\n".join(
-            [f"{e['loc']
+            [f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e]
         )
         return f"""
         There were one or more errors in your attempt to use the
```
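The repaired f-string emits one line per bad field from pydantic's `errors()` list. A standalone illustration of its output (pydantic v1, as the `BaseSettings` import implies):

```python
from pydantic import BaseModel, ValidationError

class Square(BaseModel):
    number: float

try:
    Square(number="five")
except ValidationError as ve:
    print("\n".join(f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e))
    # ('number',): value is not a valid float
```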
```diff
@@ -588,7 +694,7 @@
 
         results_list = [r for r in results if r is not None]
         if len(results_list) == 0:
-            return self.handle_message_fallback(msg)
+            return None  # self.handle_message_fallback(msg)
         # there was a non-None result
         chat_doc_results = [r for r in results_list if isinstance(r, ChatDocument)]
         if len(chat_doc_results) > 1:
@@ -603,19 +709,13 @@
 
         str_doc_results = [r for r in results_list if isinstance(r, str)]
         final = "\n".join(str_doc_results)
-        if final == "":
-            logger.warning(
-                """final result from a tool handler should not be empty str, since
-                it would be considered an invalid result and other responders
-                will be tried, and we may not necessarily want that"""
-            )
         return final
 
     def handle_message_fallback(
         self, msg: str | ChatDocument
     ) -> str | ChatDocument | None:
         """
-        Fallback method to handle possible "tool" msg if
+        Fallback method to handle possible "tool" msg if no other method applies
         or if an error is thrown.
         This method can be overridden by subclasses.
```
```diff
@@ -630,7 +730,11 @@
     def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
         json_data = json.loads(json_str)
         request = json_data.get("request")
-        if
+        if (
+            request is None
+            or not (isinstance(request, str))
+            or request not in self.llm_tools_handled
+        ):
             return None
 
         message_class = self.llm_tools_map.get(request)
```
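The expanded guard treats a JSON blob as a tool message only when `request` is present, is a string, and names a tool this agent handles. A standalone re-implementation of just the check (not library code; the `handled` set is a stand-in for `agent.llm_tools_handled`):

```python
import json

handled = {"square"}

def passes_guard(json_str: str) -> bool:
    request = json.loads(json_str).get("request")
    return request is not None and isinstance(request, str) and request in handled

print(passes_guard('{"request": "square", "number": 5}'))  # True
print(passes_guard('{"request": 42}'))                     # False: not a str
print(passes_guard('{"number": 5}'))                       # False: no request
print(passes_guard('{"request": "cube"}'))                 # False: not handled
```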
```diff
@@ -661,8 +765,10 @@
         try:
             result = handler_method(tool)
         except Exception as e:
-            #
-
+            # raise the error here since we are sure it's
+            # not a pydantic validation error,
+            # which we check in `handle_message`
+            raise e
         return result  # type: ignore
 
     def num_tokens(self, prompt: str | List[LLMMessage]) -> int:
@@ -671,7 +777,13 @@
         if isinstance(prompt, str):
             return self.parser.num_tokens(prompt)
         else:
-            return sum(
+            return sum(
+                [
+                    self.parser.num_tokens(m.content)
+                    + self.parser.num_tokens(str(m.function_call or ""))
+                    for m in prompt
+                ]
+            )
 
     def _get_response_stats(
         self, chat_length: int, tot_cost: float, response: LLMResponse
```
|
@@ -696,11 +808,17 @@ class Agent(ABC):
|
|
696
808
|
assert isinstance(self.llm, LanguageModel)
|
697
809
|
context_length = self.llm.chat_context_length()
|
698
810
|
max_out = self.config.llm.max_output_tokens
|
811
|
+
|
812
|
+
llm_model = (
|
813
|
+
"no-LLM" if self.config.llm is None else self.llm.config.chat_model
|
814
|
+
)
|
815
|
+
|
699
816
|
return (
|
700
|
-
f"[bold]Stats:[/bold] [magenta]
|
817
|
+
f"[bold]Stats:[/bold] [magenta]N_MSG={chat_length}, "
|
701
818
|
f"TOKENS: in={in_tokens}, out={out_tokens}, "
|
702
819
|
f"max={max_out}, ctx={context_length}, "
|
703
|
-
f"COST: now=${llm_response_cost}, cumul=${cumul_cost}
|
820
|
+
f"COST: now=${llm_response_cost}, cumul=${cumul_cost} "
|
821
|
+
f"[bold]({llm_model})[/bold][/magenta]"
|
704
822
|
)
|
705
823
|
return ""
|
706
824
|
|
@@ -709,6 +827,7 @@ class Agent(ABC):
|
|
709
827
|
response: LLMResponse,
|
710
828
|
prompt: str | List[LLMMessage],
|
711
829
|
stream: bool,
|
830
|
+
chat: bool = True,
|
712
831
|
print_response_stats: bool = True,
|
713
832
|
) -> None:
|
714
833
|
"""
|
```diff
@@ -722,36 +841,48 @@
             prompt (str | List[LLMMessage]): prompt or list of LLMMessage objects
             stream (bool): whether to update the usage in the response object
                 if the response is not cached.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            chat (bool): whether this is a chat model or a completion model
+            print_response_stats (bool): whether to print the response stats
+        """
+        if response is None or self.llm is None:
+            return
+
+        # Note: If response was not streamed, then
+        # `response.usage` would already have been set by the API,
+        # so we only need to update in the stream case.
+        if stream:
+            # usage, cost = 0 when response is from cache
+            prompt_tokens = 0
+            completion_tokens = 0
+            cost = 0.0
+            if not response.cached:
+                prompt_tokens = self.num_tokens(prompt)
+                completion_tokens = self.num_tokens(response.message)
+                if response.function_call is not None:
+                    completion_tokens += self.num_tokens(str(response.function_call))
+                cost = self.compute_token_cost(prompt_tokens, completion_tokens)
+            response.usage = LLMTokenUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                cost=cost,
+            )
 
-
-
-
-
-
-
-
-
-
-
+        # update total counters
+        if response.usage is not None:
+            self.total_llm_token_cost += response.usage.cost
+            self.total_llm_token_usage += response.usage.total_tokens
+            self.llm.update_usage_cost(
+                chat,
+                response.usage.prompt_tokens,
+                response.usage.completion_tokens,
+                response.usage.cost,
+            )
+            chat_length = 1 if isinstance(prompt, str) else len(prompt)
+            self.token_stats_str = self._get_response_stats(
+                chat_length, self.total_llm_token_cost, response
+            )
+            if print_response_stats:
+                print(self.indent + self.token_stats_str)
 
     def compute_token_cost(self, prompt: int, completion: int) -> float:
         price = cast(LanguageModel, self.llm).chat_cost()
```
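In the streaming case the API returns no usage numbers, so the new body estimates them: token counts via `num_tokens()` (including any `function_call` payload), cost via `compute_token_cost()`, and zeros for cache hits. A standalone sketch of the arithmetic, with hypothetical per-1K-token prices:

```python
def stream_usage(prompt_tokens: int, completion_tokens: int, cached: bool):
    if cached:
        return 0, 0, 0.0  # cached responses incur no new usage or cost
    in_price, out_price = 0.001, 0.002  # hypothetical $ per 1K tokens
    cost = (prompt_tokens * in_price + completion_tokens * out_price) / 1000
    return prompt_tokens, completion_tokens, cost

print(stream_usage(1200, 300, cached=False))  # (1200, 300, 0.0018)
print(stream_usage(1200, 300, cached=True))   # (0, 0, 0.0)
```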
```diff
@@ -773,8 +904,8 @@
         Args:
             agent (Agent): agent to ask
             request (str): request to send
-            no_answer: expected response when agent does not know the answer
-
+            no_answer (str): expected response when agent does not know the answer
+            user_confirm (bool): whether to gate the request with a human confirmation
 
         Returns:
             str: response from agent
```