langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/__init__.py CHANGED
@@ -1,3 +1,98 @@
1
1
  """
2
2
  Main langroid package
3
3
  """
4
+
5
+ from . import mytypes
6
+ from . import utils
7
+
8
+ from . import parsing
9
+ from . import prompts
10
+ from . import cachedb
11
+
12
+ from . import language_models
13
+ from . import embedding_models
14
+
15
+ from . import vector_store
16
+ from . import agent
17
+
18
+ from .agent.base import (
19
+ Agent,
20
+ AgentConfig,
21
+ )
22
+
23
+ from .agent.batch import (
24
+ run_batch_tasks,
25
+ llm_response_batch,
26
+ agent_response_batch,
27
+ )
28
+
29
+ from .agent.chat_document import (
30
+ ChatDocument,
31
+ ChatDocMetaData,
32
+ )
33
+
34
+ from .agent.tool_message import (
35
+ ToolMessage,
36
+ )
37
+
38
+ from .agent.chat_agent import (
39
+ ChatAgent,
40
+ ChatAgentConfig,
41
+ )
42
+
43
+ from .agent.task import Task
44
+
45
+ try:
46
+ from .agent.callbacks.chainlit import (
47
+ ChainlitAgentCallbacks,
48
+ ChainlitTaskCallbacks,
49
+ ChainlitCallbackConfig,
50
+ )
51
+
52
+ chainlit_available = True
53
+ ChainlitAgentCallbacks
54
+ ChainlitTaskCallbacks
55
+ ChainlitCallbackConfig
56
+ except ImportError:
57
+ chainlit_available = False
58
+
59
+
60
+ from .mytypes import (
61
+ DocMetaData,
62
+ Document,
63
+ Entity,
64
+ )
65
+
66
+ __all__ = [
67
+ "mytypes",
68
+ "utils",
69
+ "parsing",
70
+ "prompts",
71
+ "cachedb",
72
+ "language_models",
73
+ "embedding_models",
74
+ "vector_store",
75
+ "agent",
76
+ "Agent",
77
+ "AgentConfig",
78
+ "ChatAgent",
79
+ "ChatAgentConfig",
80
+ "ChatDocument",
81
+ "ChatDocMetaData",
82
+ "Task",
83
+ "DocMetaData",
84
+ "Document",
85
+ "Entity",
86
+ "ToolMessage",
87
+ "run_batch_tasks",
88
+ "llm_response_batch",
89
+ "agent_response_batch",
90
+ ]
91
+ if chainlit_available:
92
+ __all__.extend(
93
+ [
94
+ "ChainlitAgentCallbacks",
95
+ "ChainlitTaskCallbacks",
96
+ "ChainlitCallbackConfig",
97
+ ]
98
+ )
@@ -0,0 +1,40 @@
1
+ from .base import Agent, AgentConfig
2
+ from .chat_document import (
3
+ ChatDocAttachment,
4
+ ChatDocMetaData,
5
+ ChatDocLoggerFields,
6
+ ChatDocument,
7
+ )
8
+ from .chat_agent import ChatAgentConfig, ChatAgent
9
+ from .tool_message import ToolMessage
10
+ from .task import Task
11
+
12
+ from . import base
13
+ from . import chat_document
14
+ from . import chat_agent
15
+ from . import task
16
+ from . import batch
17
+ from . import tool_message
18
+ from . import tools
19
+ from . import special
20
+
21
+ __all__ = [
22
+ "Agent",
23
+ "AgentConfig",
24
+ "ChatDocAttachment",
25
+ "ChatDocMetaData",
26
+ "ChatDocLoggerFields",
27
+ "ChatDocument",
28
+ "ChatAgent",
29
+ "ChatAgentConfig",
30
+ "ToolMessage",
31
+ "Task",
32
+ "base",
33
+ "chat_document",
34
+ "chat_agent",
35
+ "task",
36
+ "batch",
37
+ "tool_message",
38
+ "tools",
39
+ "special",
40
+ ]
langroid/agent/base.py CHANGED
@@ -1,8 +1,10 @@
1
+ import asyncio
1
2
  import inspect
2
3
  import json
3
4
  import logging
4
5
  from abc import ABC
5
6
  from contextlib import ExitStack
7
+ from types import SimpleNamespace
6
8
  from typing import (
7
9
  Any,
8
10
  Callable,
@@ -20,6 +22,7 @@ from typing import (
20
22
  from pydantic import BaseSettings, ValidationError
21
23
  from rich import print
22
24
  from rich.console import Console
25
+ from rich.markup import escape
23
26
  from rich.prompt import Prompt
24
27
 
25
28
  from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
@@ -32,15 +35,17 @@ from langroid.language_models.base import (
32
35
  LLMTokenUsage,
33
36
  StreamingIfAllowed,
34
37
  )
35
- from langroid.mytypes import DocMetaData, Entity
36
- from langroid.parsing.json import extract_top_level_json
38
+ from langroid.language_models.openai_gpt import OpenAIGPTConfig
39
+ from langroid.mytypes import Entity
40
+ from langroid.parsing.parse_json import extract_top_level_json
37
41
  from langroid.parsing.parser import Parser, ParsingConfig
38
42
  from langroid.prompts.prompts_config import PromptsConfig
39
43
  from langroid.utils.configuration import settings
40
44
  from langroid.utils.constants import NO_ANSWER
45
+ from langroid.utils.output import status
41
46
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
42
47
 
43
- console = Console()
48
+ console = Console(quiet=settings.quiet)
44
49
 
45
50
  logger = logging.getLogger(__name__)
46
51
 
@@ -53,10 +58,15 @@ class AgentConfig(BaseSettings):
53
58
 
54
59
  name: str = "LLM-Agent"
55
60
  debug: bool = False
56
- vecdb: Optional[VectorStoreConfig] = VectorStoreConfig()
57
- llm: Optional[LLMConfig] = LLMConfig()
61
+ vecdb: Optional[VectorStoreConfig] = None
62
+ llm: Optional[LLMConfig] = OpenAIGPTConfig()
58
63
  parsing: Optional[ParsingConfig] = ParsingConfig()
59
64
  prompts: Optional[PromptsConfig] = PromptsConfig()
65
+ show_stats: bool = True # show token usage/cost stats?
66
+
67
+
68
+ def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
69
+ pass
60
70
 
61
71
 
62
72
  class Agent(ABC):
@@ -70,8 +80,9 @@ class Agent(ABC):
70
80
  information about any tool/function-calling messages that have been defined.
71
81
  """
72
82
 
73
- def __init__(self, config: AgentConfig):
83
+ def __init__(self, config: AgentConfig = AgentConfig()):
74
84
  self.config = config
85
+ self.lock = asyncio.Lock() # for async access to update self.llm.usage_cost
75
86
  self.dialog: List[Tuple[str, str]] = [] # seq of LLM (prompt, response) tuples
76
87
  self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
77
88
  self.llm_tools_handled: Set[str] = set()
@@ -86,6 +97,18 @@ class Agent(ABC):
86
97
  self.parser: Optional[Parser] = (
87
98
  Parser(config.parsing) if config.parsing else None
88
99
  )
100
+ self.callbacks = SimpleNamespace(
101
+ start_llm_stream=lambda: noop_fn,
102
+ cancel_llm_stream=noop_fn,
103
+ finish_llm_stream=noop_fn,
104
+ show_llm_response=noop_fn,
105
+ show_agent_response=noop_fn,
106
+ get_user_response=None,
107
+ get_last_step=noop_fn,
108
+ set_parent_agent=noop_fn,
109
+ show_error_message=noop_fn,
110
+ show_start_response=noop_fn,
111
+ )
89
112
 
90
113
  def entity_responders(
91
114
  self,
@@ -139,6 +162,9 @@ class Agent(ABC):
139
162
  def get_dialog(self) -> List[Tuple[str, str]]:
140
163
  return self.dialog
141
164
 
165
+ def clear_dialog(self) -> None:
166
+ self.dialog = []
167
+
142
168
  def _get_tool_list(
143
169
  self, message_class: Optional[Type[ToolMessage]] = None
144
170
  ) -> List[str]:
@@ -246,6 +272,10 @@ class Agent(ABC):
246
272
  ]
247
273
  return "\n\n".join(sample_convo)
248
274
 
275
+ def agent_response_template(self) -> ChatDocument:
276
+ """Template for agent_response."""
277
+ return self._response_template(Entity.AGENT)
278
+
249
279
  async def agent_response_async(
250
280
  self,
251
281
  msg: Optional[str | ChatDocument] = None,
@@ -275,9 +305,19 @@ class Agent(ABC):
275
305
  if results is None:
276
306
  return None
277
307
  if isinstance(results, ChatDocument):
308
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
309
+ results.metadata.tool_ids = (
310
+ [] if isinstance(msg, str) else msg.metadata.tool_ids
311
+ )
278
312
  return results
279
- console.print(f"[red]{self.indent}", end="")
280
- print(f"[red]Agent: {results}")
313
+ if not settings.quiet:
314
+ console.print(f"[red]{self.indent}", end="")
315
+ print(f"[red]Agent: {results}")
316
+ maybe_json = len(extract_top_level_json(results)) > 0
317
+ self.callbacks.show_agent_response(
318
+ content=results,
319
+ language="json" if maybe_json else "text",
320
+ )
281
321
  sender_name = self.config.name
282
322
  if isinstance(msg, ChatDocument) and msg.function_call is not None:
283
323
  # if result was from handling an LLM `function_call`,
@@ -290,9 +330,25 @@ class Agent(ABC):
290
330
  source=Entity.AGENT,
291
331
  sender=Entity.AGENT,
292
332
  sender_name=sender_name,
333
+ # preserve trail of tool_ids for OpenAI Assistant fn-calls
334
+ tool_ids=[] if isinstance(msg, str) else msg.metadata.tool_ids,
335
+ ),
336
+ )
337
+
338
+ def _response_template(self, e: Entity) -> ChatDocument:
339
+ """Template for response from entity `e`."""
340
+ return ChatDocument(
341
+ content="",
342
+ tool_messages=[],
343
+ metadata=ChatDocMetaData(
344
+ source=e, sender=e, sender_name=self.config.name, tool_ids=[]
293
345
  ),
294
346
  )
295
347
 
348
+ def user_response_template(self) -> ChatDocument:
349
+ """Template for user_response."""
350
+ return self._response_template(Entity.USER)
351
+
296
352
  async def user_response_async(
297
353
  self,
298
354
  msg: Optional[str | ChatDocument] = None,
@@ -320,12 +376,22 @@ class Agent(ABC):
320
376
  elif not settings.interactive:
321
377
  user_msg = ""
322
378
  else:
323
- user_msg = Prompt.ask(
324
- f"[blue]{self.indent}Human "
325
- "(respond or q, x to exit current level, "
326
- f"or hit enter to continue)\n{self.indent}",
327
- ).strip()
328
-
379
+ if self.callbacks.get_user_response is not None:
380
+ # ask user with empty prompt: no need for prompt
381
+ # since user has seen the conversation so far.
382
+ # But non-empty prompt can be useful when Agent
383
+ # uses a tool that requires user input, or in other scenarios.
384
+ user_msg = self.callbacks.get_user_response(prompt="")
385
+ else:
386
+ user_msg = Prompt.ask(
387
+ f"[blue]{self.indent}Human "
388
+ "(respond or q, x to exit current level, "
389
+ f"or hit enter to continue)\n{self.indent}",
390
+ ).strip()
391
+
392
+ tool_ids = []
393
+ if msg is not None and isinstance(msg, ChatDocument):
394
+ tool_ids = msg.metadata.tool_ids
329
395
  # only return non-None result if user_msg not empty
330
396
  if not user_msg:
331
397
  return None
@@ -339,9 +405,11 @@ class Agent(ABC):
339
405
  sender = Entity.USER
340
406
  return ChatDocument(
341
407
  content=user_msg,
342
- metadata=DocMetaData(
408
+ metadata=ChatDocMetaData(
343
409
  source=source,
344
410
  sender=sender,
411
+ # preserve trail of tool_ids for OpenAI Assistant fn-calls
412
+ tool_ids=tool_ids,
345
413
  ),
346
414
  )
347
415
 
@@ -358,13 +426,6 @@ class Agent(ABC):
358
426
  if self.llm is None:
359
427
  return False
360
428
 
361
- if isinstance(message, ChatDocument) and message.function_call is not None:
362
- # LLM should not handle `function_call` messages,
363
- # EVEN if message.function_call is not a legit function_call
364
- # The OpenAI API raises error if there is a message in history
365
- # from a non-Assistant role, with a `function_call` in it
366
- return False
367
-
368
429
  if message is not None and len(self.get_tool_messages(message)) > 0:
369
430
  # if there is a valid "tool" message (either JSON or via `function_call`)
370
431
  # then LLM cannot respond to it
@@ -372,6 +433,10 @@ class Agent(ABC):
372
433
 
373
434
  return True
374
435
 
436
+ def llm_response_template(self) -> ChatDocument:
437
+ """Template for llm_response."""
438
+ return self._response_template(Entity.LLM)
439
+
375
440
  @no_type_check
376
441
  async def llm_response_async(
377
442
  self,
@@ -410,18 +475,24 @@ class Agent(ABC):
410
475
  with StreamingIfAllowed(self.llm, self.llm.get_stream()):
411
476
  response = await self.llm.agenerate(prompt, output_len)
412
477
 
413
- # we would have already displayed the msg "live" ONLY if
414
- # streaming was enabled, AND we did not find a cached response
415
- console.print(f"[green]{self.indent}", end="")
416
- print("[green]" + response.message)
417
- displayed = True
418
- self.update_token_usage(
419
- response,
420
- prompt,
421
- self.llm.get_stream(),
422
- print_response_stats=True,
423
- )
424
- return ChatDocument.from_LLMResponse(response, displayed)
478
+ if not self.llm.get_stream() or response.cached and not settings.quiet:
479
+ # We would have already displayed the msg "live" ONLY if
480
+ # streaming was enabled, AND we did not find a cached response.
481
+ # If we are here, it means the response has not yet been displayed.
482
+ cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
483
+ print(cached + "[green]" + escape(response.message))
484
+ async with self.lock:
485
+ self.update_token_usage(
486
+ response,
487
+ prompt,
488
+ self.llm.get_stream(),
489
+ chat=False, # i.e. it's a completion model not chat model
490
+ print_response_stats=self.config.show_stats and not settings.quiet,
491
+ )
492
+ cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
493
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
494
+ cdoc.metadata.tool_ids = [] if isinstance(msg, str) else msg.metadata.tool_ids
495
+ return cdoc
425
496
 
426
497
  @no_type_check
427
498
  def llm_response(
@@ -447,7 +518,7 @@ class Agent(ABC):
447
518
  with ExitStack() as stack: # for conditionally using rich spinner
448
519
  if not self.llm.get_stream():
449
520
  # show rich spinner only if not streaming!
450
- cm = console.status("LLM responding to message...")
521
+ cm = status("LLM responding to message...")
451
522
  stack.enter_context(cm)
452
523
  output_len = self.config.llm.max_output_tokens
453
524
  if (
@@ -472,36 +543,61 @@ class Agent(ABC):
472
543
  the completion context length of the LLM.
473
544
  """
474
545
  )
475
- if self.llm.get_stream():
546
+ if self.llm.get_stream() and not settings.quiet:
476
547
  console.print(f"[green]{self.indent}", end="")
477
548
  response = self.llm.generate(prompt, output_len)
478
549
 
479
- displayed = False
480
- if not self.llm.get_stream() or response.cached:
550
+ if not self.llm.get_stream() or response.cached and not settings.quiet:
481
551
  # we would have already displayed the msg "live" ONLY if
482
552
  # streaming was enabled, AND we did not find a cached response
553
+ # If we are here, it means the response has not yet been displayed.
554
+ cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
483
555
  console.print(f"[green]{self.indent}", end="")
484
- print("[green]" + response.message)
485
- displayed = True
556
+ print(cached + "[green]" + escape(response.message))
486
557
  self.update_token_usage(
487
558
  response,
488
559
  prompt,
489
560
  self.llm.get_stream(),
490
- print_response_stats=True,
561
+ chat=False, # i.e. it's a completion model not chat model
562
+ print_response_stats=self.config.show_stats and not settings.quiet,
491
563
  )
492
- return ChatDocument.from_LLMResponse(response, displayed)
564
+ cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
565
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
566
+ cdoc.metadata.tool_ids = [] if isinstance(msg, str) else msg.metadata.tool_ids
567
+ return cdoc
568
+
569
+ def has_tool_message_attempt(self, msg: str | ChatDocument | None) -> bool:
570
+ """Check whether msg contains a Tool/fn-call attempt (by the LLM)"""
571
+ if msg is None:
572
+ return False
573
+ try:
574
+ tools = self.get_tool_messages(msg)
575
+ return len(tools) > 0
576
+ except ValidationError:
577
+ # there is a tool/fn-call attempt but had a validation error,
578
+ # so we still consider this a tool message "attempt"
579
+ return True
580
+ return False
493
581
 
494
582
  def get_tool_messages(self, msg: str | ChatDocument) -> List[ToolMessage]:
495
583
  if isinstance(msg, str):
496
584
  return self.get_json_tool_messages(msg)
585
+ if len(msg.tool_messages) > 0:
586
+ # We've already found tool_messages
587
+ # (either via OpenAI Fn-call or Langroid-native ToolMessage)
588
+ return msg.tool_messages
497
589
  assert isinstance(msg, ChatDocument)
498
590
  # when `content` is non-empty, we assume there will be no `function_call`
499
591
  if msg.content != "":
500
- return self.get_json_tool_messages(msg.content)
592
+ tools = self.get_json_tool_messages(msg.content)
593
+ msg.tool_messages = tools
594
+ return tools
501
595
 
502
596
  # otherwise, we look for a `function_call`
503
597
  fun_call_cls = self.get_function_call_class(msg)
504
- return [fun_call_cls] if fun_call_cls is not None else []
598
+ tools = [fun_call_cls] if fun_call_cls is not None else []
599
+ msg.tool_messages = tools
600
+ return tools
505
601
 
506
602
  def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
507
603
  """
@@ -525,7 +621,17 @@ class Agent(ABC):
525
621
  tool_name = msg.function_call.name
526
622
  tool_msg = msg.function_call.arguments or {}
527
623
  if tool_name not in self.llm_tools_handled:
528
- raise ValueError(f"{tool_name} is not a valid function_call!")
624
+ logger.warning(
625
+ f"""
626
+ The function_call '{tool_name}' is not handled
627
+ by the agent named '{self.config.name}'!
628
+ If you intended this agent to handle this function_call,
629
+ either the fn-call name is incorrectly generated by the LLM,
630
+ (in which case you may need to adjust your LLM instructions),
631
+ or you need to enable this agent to handle this fn-call.
632
+ """
633
+ )
634
+ return None
529
635
  tool_class = self.llm_tools_map[tool_name]
530
636
  tool_msg.update(dict(request=tool_name))
531
637
  tool = tool_class.parse_obj(tool_msg)
@@ -544,7 +650,7 @@ class Agent(ABC):
544
650
  """
545
651
  tool_name = cast(ToolMessage, ve.model).default_value("request")
546
652
  bad_field_errors = "\n".join(
547
- [f"{e['loc'][0]}: {e['msg']}" for e in ve.errors() if "loc" in e]
653
+ [f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e]
548
654
  )
549
655
  return f"""
550
656
  There were one or more errors in your attempt to use the
@@ -588,7 +694,7 @@ class Agent(ABC):
588
694
 
589
695
  results_list = [r for r in results if r is not None]
590
696
  if len(results_list) == 0:
591
- return self.handle_message_fallback(msg)
697
+ return None # self.handle_message_fallback(msg)
592
698
  # there was a non-None result
593
699
  chat_doc_results = [r for r in results_list if isinstance(r, ChatDocument)]
594
700
  if len(chat_doc_results) > 1:
@@ -603,19 +709,13 @@ class Agent(ABC):
603
709
 
604
710
  str_doc_results = [r for r in results_list if isinstance(r, str)]
605
711
  final = "\n".join(str_doc_results)
606
- if final == "":
607
- logger.warning(
608
- """final result from a tool handler should not be empty str, since
609
- it would be considered an invalid result and other responders
610
- will be tried, and we may not necessarily want that"""
611
- )
612
712
  return final
613
713
 
614
714
  def handle_message_fallback(
615
715
  self, msg: str | ChatDocument
616
716
  ) -> str | ChatDocument | None:
617
717
  """
618
- Fallback method to handle possible "tool" msg if not other method applies
718
+ Fallback method to handle possible "tool" msg if no other method applies
619
719
  or if an error is thrown.
620
720
  This method can be overridden by subclasses.
621
721
 
@@ -630,7 +730,11 @@ class Agent(ABC):
630
730
  def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
631
731
  json_data = json.loads(json_str)
632
732
  request = json_data.get("request")
633
- if request is None or request not in self.llm_tools_handled:
733
+ if (
734
+ request is None
735
+ or not (isinstance(request, str))
736
+ or request not in self.llm_tools_handled
737
+ ):
634
738
  return None
635
739
 
636
740
  message_class = self.llm_tools_map.get(request)
@@ -661,8 +765,10 @@ class Agent(ABC):
661
765
  try:
662
766
  result = handler_method(tool)
663
767
  except Exception as e:
664
- # return the error message to the LLM so it can try to fix the error
665
- result = f"Error in tool/function-call {tool_name} usage: {type(e)}: {e}"
768
+ # raise the error here since we are sure it's
769
+ # not a pydantic validation error,
770
+ # which we check in `handle_message`
771
+ raise e
666
772
  return result # type: ignore
667
773
 
668
774
  def num_tokens(self, prompt: str | List[LLMMessage]) -> int:
@@ -671,7 +777,13 @@ class Agent(ABC):
671
777
  if isinstance(prompt, str):
672
778
  return self.parser.num_tokens(prompt)
673
779
  else:
674
- return sum([self.parser.num_tokens(m.content) for m in prompt])
780
+ return sum(
781
+ [
782
+ self.parser.num_tokens(m.content)
783
+ + self.parser.num_tokens(str(m.function_call or ""))
784
+ for m in prompt
785
+ ]
786
+ )
675
787
 
676
788
  def _get_response_stats(
677
789
  self, chat_length: int, tot_cost: float, response: LLMResponse
@@ -696,11 +808,17 @@ class Agent(ABC):
696
808
  assert isinstance(self.llm, LanguageModel)
697
809
  context_length = self.llm.chat_context_length()
698
810
  max_out = self.config.llm.max_output_tokens
811
+
812
+ llm_model = (
813
+ "no-LLM" if self.config.llm is None else self.llm.config.chat_model
814
+ )
815
+
699
816
  return (
700
- f"[bold]Stats:[/bold] [magenta] N_MSG={chat_length}, "
817
+ f"[bold]Stats:[/bold] [magenta]N_MSG={chat_length}, "
701
818
  f"TOKENS: in={in_tokens}, out={out_tokens}, "
702
819
  f"max={max_out}, ctx={context_length}, "
703
- f"COST: now=${llm_response_cost}, cumul=${cumul_cost}[/magenta]"
820
+ f"COST: now=${llm_response_cost}, cumul=${cumul_cost} "
821
+ f"[bold]({llm_model})[/bold][/magenta]"
704
822
  )
705
823
  return ""
706
824
 
@@ -709,6 +827,7 @@ class Agent(ABC):
709
827
  response: LLMResponse,
710
828
  prompt: str | List[LLMMessage],
711
829
  stream: bool,
830
+ chat: bool = True,
712
831
  print_response_stats: bool = True,
713
832
  ) -> None:
714
833
  """
@@ -722,36 +841,48 @@ class Agent(ABC):
722
841
  prompt (str | List[LLMMessage]): prompt or list of LLMMessage objects
723
842
  stream (bool): whether to update the usage in the response object
724
843
  if the response is not cached.
725
- """
726
- if response is not None:
727
- # Note: If response was not streamed, then
728
- # `response.usage` would already have been set by the API,
729
- # so we only need to update in the stream case.
730
- if stream:
731
- # usage, cost = 0 when response is from cache
732
- prompt_tokens = 0
733
- completion_tokens = 0
734
- cost = 0.0
735
- if not response.cached:
736
- prompt_tokens = self.num_tokens(prompt)
737
- completion_tokens = self.num_tokens(response.message)
738
- cost = self.compute_token_cost(prompt_tokens, completion_tokens)
739
- response.usage = LLMTokenUsage(
740
- prompt_tokens=prompt_tokens,
741
- completion_tokens=completion_tokens,
742
- cost=cost,
743
- )
844
+ chat (bool): whether this is a chat model or a completion model
845
+ print_response_stats (bool): whether to print the response stats
846
+ """
847
+ if response is None or self.llm is None:
848
+ return
849
+
850
+ # Note: If response was not streamed, then
851
+ # `response.usage` would already have been set by the API,
852
+ # so we only need to update in the stream case.
853
+ if stream:
854
+ # usage, cost = 0 when response is from cache
855
+ prompt_tokens = 0
856
+ completion_tokens = 0
857
+ cost = 0.0
858
+ if not response.cached:
859
+ prompt_tokens = self.num_tokens(prompt)
860
+ completion_tokens = self.num_tokens(response.message)
861
+ if response.function_call is not None:
862
+ completion_tokens += self.num_tokens(str(response.function_call))
863
+ cost = self.compute_token_cost(prompt_tokens, completion_tokens)
864
+ response.usage = LLMTokenUsage(
865
+ prompt_tokens=prompt_tokens,
866
+ completion_tokens=completion_tokens,
867
+ cost=cost,
868
+ )
744
869
 
745
- # update total counters
746
- if response.usage is not None:
747
- self.total_llm_token_cost += response.usage.cost
748
- self.total_llm_token_usage += response.usage.total_tokens
749
- chat_length = 1 if isinstance(prompt, str) else len(prompt)
750
- self.token_stats_str = self._get_response_stats(
751
- chat_length, self.total_llm_token_cost, response
752
- )
753
- if print_response_stats:
754
- print(self.indent + self.token_stats_str)
870
+ # update total counters
871
+ if response.usage is not None:
872
+ self.total_llm_token_cost += response.usage.cost
873
+ self.total_llm_token_usage += response.usage.total_tokens
874
+ self.llm.update_usage_cost(
875
+ chat,
876
+ response.usage.prompt_tokens,
877
+ response.usage.completion_tokens,
878
+ response.usage.cost,
879
+ )
880
+ chat_length = 1 if isinstance(prompt, str) else len(prompt)
881
+ self.token_stats_str = self._get_response_stats(
882
+ chat_length, self.total_llm_token_cost, response
883
+ )
884
+ if print_response_stats:
885
+ print(self.indent + self.token_stats_str)
755
886
 
756
887
  def compute_token_cost(self, prompt: int, completion: int) -> float:
757
888
  price = cast(LanguageModel, self.llm).chat_cost()
@@ -773,8 +904,8 @@ class Agent(ABC):
773
904
  Args:
774
905
  agent (Agent): agent to ask
775
906
  request (str): request to send
776
- no_answer: expected response when agent does not know the answer
777
- gate_human: whether to gate the request with a human confirmation
907
+ no_answer (str): expected response when agent does not know the answer
908
+ user_confirm (bool): whether to gate the request with a human confirmation
778
909
 
779
910
  Returns:
780
911
  str: response from agent