langroid 0.38.0__tar.gz → 0.39.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langroid-0.38.0 → langroid-0.39.1}/PKG-INFO +1 -1
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/base.py +8 -3
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/chat_agent.py +35 -7
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/doc_chat_agent.py +1 -1
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/__init__.py +4 -3
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/base.py +8 -1
- langroid-0.39.1/langroid/language_models/model_info.py +307 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/openai_gpt.py +45 -153
- {langroid-0.38.0 → langroid-0.39.1}/langroid/mytypes.py +9 -0
- {langroid-0.38.0 → langroid-0.39.1}/pyproject.toml +1 -1
- {langroid-0.38.0 → langroid-0.39.1}/.gitignore +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/LICENSE +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/README.md +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/batch.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/callbacks/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/callbacks/chainlit.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/chat_document.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/openai_assistant.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/arangodb/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/arangodb/arangodb_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/arangodb/system_messages.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/arangodb/tools.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/arangodb/utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_doc_chat_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_rag/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_rag/critic_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_rag/lance_rag_task.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_rag/query_planner_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/lance_tools.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/neo4j/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/neo4j/csv_kg_chat.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/neo4j/system_messages.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/neo4j/tools.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/relevance_extractor_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/retriever_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/sql_chat_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/utils/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/utils/description_extractors.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/utils/populate_metadata.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/utils/system_message.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/sql/utils/tools.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/special/table_chat_agent.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/task.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tool_message.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/duckduckgo_search_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/file_tools.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/google_search_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/metaphor_search_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/orchestration.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/recipient_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/retrieval_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/rewind_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/tools/segment_extract_tool.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/agent/xml_tool_message.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/cachedb/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/cachedb/base.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/cachedb/momento_cachedb.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/cachedb/redis_cachedb.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/base.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/models.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/protoc/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/protoc/embeddings.proto +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/protoc/embeddings_pb2.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/embedding_models/remote_embeds.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/exceptions.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/azure_openai.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/config.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/mock_lm.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/prompt_formatter/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/prompt_formatter/base.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/prompt_formatter/hf_formatter.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/prompt_formatter/llama2_formatter.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/language_models/utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/agent_chats.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/code_parser.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/document_parser.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/para_sentence_split.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/parse_json.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/parser.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/pdf_utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/repo_loader.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/routing.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/search.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/spider.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/table_loader.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/url_loader.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/urls.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/parsing/web_search.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/prompts/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/prompts/dialog.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/prompts/prompts_config.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/prompts/templates.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/py.typed +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/pydantic_v1/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/pydantic_v1/main.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/algorithms/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/algorithms/graph.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/configuration.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/constants.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/git_utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/globals.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/logging.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/object_registry.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/output/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/output/citations.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/output/printing.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/output/status.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/pandas_utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/pydantic_utils.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/system.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/utils/types.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/__init__.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/base.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/chromadb.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/lancedb.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/meilisearch.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/momento.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/qdrantdb.py +0 -0
- {langroid-0.38.0 → langroid-0.39.1}/langroid/vector_store/weaviatedb.py +0 -0
langroid/agent/base.py
@@ -333,6 +333,11 @@ class Agent(ABC):
         if hasattr(message_class, "handle_message_fallback") and (
             inspect.isfunction(message_class.handle_message_fallback)
         ):
+            # When a ToolMessage has a `handle_message_fallback` method,
+            # we inject it into the agent as a method, overriding the default
+            # `handle_message_fallback` method (which does nothing).
+            # It's possible multiple tool messages have a `handle_message_fallback`,
+            # in which case, the last one inserted will be used.
             setattr(
                 self,
                 "handle_message_fallback",
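A minimal sketch of what this injection enables (the `ProbeTool` class and its handler body are hypothetical illustrations, and the fallback signature is assumed; `ToolMessage`, `ChatAgent`, and `enable_message` are existing langroid APIs):

from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.tool_message import ToolMessage


class ProbeTool(ToolMessage):  # hypothetical tool, for illustration only
    request: str = "probe"
    purpose: str = "To report a <number> to probe"
    number: int

    @staticmethod
    def handle_message_fallback(agent, msg) -> str:
        # Runs when the LLM's response contains no recognized tool;
        # per the comments above, this function is injected as the
        # agent's own `handle_message_fallback`. (Signature assumed.)
        return "You forgot to use the `probe` tool; please try again."


agent = ChatAgent(ChatAgentConfig(name="Prober"))
agent.enable_message(ProbeTool)  # triggers the setattr(...) injection above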
@@ -912,7 +917,7 @@ class Agent(ABC):
         else:
             prompt = message
 
-        output_len = self.config.llm.max_output_tokens
+        output_len = self.config.llm.model_max_output_tokens
         if self.num_tokens(prompt) + output_len > self.llm.completion_context_length():
             output_len = self.llm.completion_context_length() - self.num_tokens(prompt)
             if output_len < self.config.llm.min_output_tokens:
@@ -981,7 +986,7 @@ class Agent(ABC):
             # show rich spinner only if not streaming!
             cm = status("LLM responding to message...")
             stack.enter_context(cm)
-        output_len = self.config.llm.max_output_tokens
+        output_len = self.config.llm.model_max_output_tokens
         if (
             self.num_tokens(prompt) + output_len
             > self.llm.completion_context_length()
@@ -1866,7 +1871,7 @@ class Agent(ABC):
         cumul_cost = format(tot_cost, ".4f")
         assert isinstance(self.llm, LanguageModel)
         context_length = self.llm.chat_context_length()
-        max_out = self.config.llm.max_output_tokens
+        max_out = self.config.llm.model_max_output_tokens
 
         llm_model = (
             "no-LLM" if self.config.llm is None else self.llm.config.chat_model
langroid/agent/chat_agent.py
@@ -5,7 +5,7 @@ import logging
 import textwrap
 from contextlib import ExitStack
 from inspect import isclass
-from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
 
 import openai
 from rich import print
@@ -31,6 +31,7 @@ from langroid.language_models.base import (
     ToolChoiceTypes,
 )
 from langroid.language_models.openai_gpt import OpenAIGPT
+from langroid.mytypes import Entity, NonToolAction
 from langroid.pydantic_v1 import BaseModel, ValidationError
 from langroid.utils.configuration import settings
 from langroid.utils.object_registry import ObjectRegistry
@@ -52,6 +53,7 @@ class ChatAgentConfig(AgentConfig):
         user_message: user message to include in message sequence.
             Used only if `task` is not specified in the constructor.
         use_tools: whether to use our own ToolMessages mechanism
+        handle_llm_no_tool (NonToolAction|str): routing when LLM generates non-tool msg.
         use_functions_api: whether to use functions/tools native to the LLM API
             (e.g. OpenAI's `function_call` or `tool_call` mechanism)
         use_tools_api: When `use_functions_api` is True, if this is also True,
@@ -84,6 +86,7 @@ class ChatAgentConfig(AgentConfig):
     system_message: str = "You are a helpful assistant."
     user_message: Optional[str] = None
+    handle_llm_no_tool: NonToolAction | None = None
     use_tools: bool = False
     use_functions_api: bool = True
     use_tools_api: bool = False
@@ -579,6 +582,31 @@ class ChatAgent(Agent):
         # remove leading and trailing newlines and other whitespace
         return LLMMessage(role=Role.SYSTEM, content=content.strip())
 
+    def handle_message_fallback(self, msg: str | ChatDocument) -> Any:
+        """
+        Fallback method for the "no-tools" scenario.
+        Uses `self.config.handle_llm_no_tool` to determine the action to take.
+
+        This method can be overridden by subclasses, e.g.,
+        to create a "reminder" message when a tool is expected but the LLM "forgot"
+        to generate one.
+
+        Args:
+            msg (str | ChatDocument): The input msg to handle
+        Returns:
+            Any: The result of the handler method
+        """
+        if self.config.handle_llm_no_tool is None:
+            return None
+        if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
+            from langroid.agent.tools.orchestration import AgentDoneTool, ForwardTool
+
+            match self.config.handle_llm_no_tool:
+                case NonToolAction.FORWARD_USER:
+                    return ForwardTool(agent="User")
+                case NonToolAction.DONE:
+                    return AgentDoneTool(content=msg.content, tools=msg.tool_messages)
+
     def unhandled_tools(self) -> set[str]:
         """The set of tools that are known but not handled.
         Useful in task flow: an agent can refuse to accept an incoming msg
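Together with the `handle_llm_no_tool` config option added above, this default fallback gives declarative routing for plain-text LLM replies. A minimal sketch, assuming `NonToolAction` (added to mytypes.py in this release) exposes the `FORWARD_USER` and `DONE` members used in the `match` statement:

from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.mytypes import NonToolAction

# If the LLM replies with plain text instead of a tool call,
# forward that message to the user rather than leaving it unhandled.
config = ChatAgentConfig(
    name="Assistant",
    handle_llm_no_tool=NonToolAction.FORWARD_USER,
)
agent = ChatAgent(config)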
@@ -1460,11 +1488,11 @@ class ChatAgent(Agent):
         self.message_history.extend(llm_msgs)
 
         hist = self.message_history
-        output_len = self.config.llm.max_output_tokens
+        output_len = self.config.llm.model_max_output_tokens
         if (
             truncate
             and self.chat_num_tokens(hist)
-            > self.llm.chat_context_length() - self.config.llm.max_output_tokens
+            > self.llm.chat_context_length() - self.config.llm.model_max_output_tokens
         ):
             # chat + output > max context length,
             # so first try to shorten requested output len to fit.
@@ -1489,7 +1517,7 @@ class ChatAgent(Agent):
                 The message history is longer than the max chat context
                 length allowed, and we have run out of messages to drop.
                 HINT: In your `OpenAIGPTConfig` object, try increasing
-                `chat_context_length` or decreasing `max_output_tokens`.
+                `chat_context_length` or decreasing `model_max_output_tokens`.
                 """
             )
             # drop the second message, i.e. first msg after the sys msg
@@ -1638,12 +1666,12 @@ class ChatAgent(Agent):
         Args:
             messages: seq of messages (with role, content fields) sent to LLM
             output_len: max number of tokens expected in response.
-                If None, use the LLM's default max_output_tokens.
+                If None, use the LLM's default model_max_output_tokens.
         Returns:
             Document (i.e. with fields "content", "metadata")
         """
         assert self.config.llm is not None and self.llm is not None
-        output_len = output_len or self.config.llm.max_output_tokens
+        output_len = output_len or self.config.llm.model_max_output_tokens
         streamer = noop_fn
         if self.llm.get_stream():
             streamer = self.callbacks.start_llm_stream()
@@ -1713,7 +1741,7 @@ class ChatAgent(Agent):
         Async version of `llm_response_messages`. See there for details.
         """
         assert self.config.llm is not None and self.llm is not None
-        output_len = output_len or self.config.llm.max_output_tokens
+        output_len = output_len or self.config.llm.model_max_output_tokens
         functions, fun_call, tools, force_tool, output_format = self._function_args()
         assert self.llm is not None
 
langroid/agent/special/doc_chat_agent.py
@@ -1565,7 +1565,7 @@ class DocChatAgent(ChatAgent):
         tot_tokens = self.parser.num_tokens(full_text)
         MAX_INPUT_TOKENS = (
             self.llm.completion_context_length()
-            - self.config.llm.max_output_tokens
+            - self.config.llm.model_max_output_tokens
             - 100
         )
         if tot_tokens > MAX_INPUT_TOKENS:
langroid/language_models/__init__.py
@@ -15,14 +15,13 @@ from .base import (
     LLMTokenUsage,
     LLMResponse,
 )
-from .openai_gpt import (
+from .model_info import (
     OpenAIChatModel,
     AnthropicModel,
     GeminiModel,
     OpenAICompletionModel,
-    OpenAIGPTConfig,
-    OpenAIGPT,
 )
+from .openai_gpt import OpenAIGPTConfig, OpenAIGPT, OpenAICallParams
 from .mock_lm import MockLM, MockLMConfig
 from .azure_openai import AzureConfig, AzureGPT
 
@@ -32,6 +31,7 @@ __all__ = [
     "config",
     "base",
     "openai_gpt",
+    "model_info",
     "azure_openai",
     "prompt_formatter",
     "StreamEventType",
@@ -48,6 +48,7 @@ __all__ = [
     "OpenAICompletionModel",
     "OpenAIGPTConfig",
     "OpenAIGPT",
+    "OpenAICallParams",
     "AzureConfig",
     "AzureGPT",
     "MockLM",
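After this reorganization the model enums live in model_info.py while the client classes stay in openai_gpt.py, and both remain importable from the package root; for example:

from langroid.language_models import (
    OpenAIChatModel,   # now defined in model_info.py
    OpenAIGPT,
    OpenAIGPTConfig,
    OpenAICallParams,  # newly re-exported
)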
langroid/language_models/base.py
@@ -19,6 +19,7 @@ from typing import (
 
 from langroid.cachedb.base import CacheDBConfig
 from langroid.cachedb.redis_cachedb import RedisCacheConfig
+from langroid.language_models.model_info import get_model_info
 from langroid.parsing.agent_chats import parse_message
 from langroid.parsing.parse_json import parse_imperfect_json, top_level_json_field
 from langroid.prompts.dialog import collate_chat_history
@@ -60,6 +61,7 @@ class LLMConfig(BaseSettings):
     streamer_async: Optional[Callable[..., Awaitable[None]]] = async_noop_fn
     api_base: str | None = None
     formatter: None | str = None
+    max_output_tokens: int | None = 8192  # specify None to use model_max_output_tokens
     timeout: int = 20  # timeout for API requests
     chat_model: str = ""
     completion_model: str = ""
@@ -67,7 +69,6 @@ class LLMConfig(BaseSettings):
     chat_context_length: int = 8000
     async_stream_quiet: bool = True  # suppress streaming output in async mode?
     completion_context_length: int = 8000
-    max_output_tokens: int = 1024  # generate at most this many tokens
     # if input length + max_output_tokens > context length of model,
     # we will try shortening requested output
     min_output_tokens: int = 64
@@ -84,6 +85,12 @@ class LLMConfig(BaseSettings):
     chat_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
     completion_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
 
+    @property
+    def model_max_output_tokens(self) -> int:
+        return (
+            self.max_output_tokens or get_model_info(self.chat_model).max_output_tokens
+        )
+
 
 class LLMFunctionCall(BaseModel):
     """
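The net effect of these LLMConfig changes: `max_output_tokens` becomes an optional user override, and the new `model_max_output_tokens` property resolves the effective cap, falling back to the per-model registry. A small sketch (using `OpenAIGPTConfig`, which inherits these fields from `LLMConfig`; the registry values are those in model_info.py below):

from langroid.language_models import OpenAIGPTConfig

# An explicit override wins:
cfg = OpenAIGPTConfig(chat_model="gpt-4o", max_output_tokens=2048)
assert cfg.model_max_output_tokens == 2048

# With max_output_tokens=None, the cap comes from MODEL_INFO
# (16_384 for "gpt-4o"):
cfg = OpenAIGPTConfig(chat_model="gpt-4o", max_output_tokens=None)
assert cfg.model_max_output_tokens == 16_384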
langroid-0.39.1/langroid/language_models/model_info.py (new file)
@@ -0,0 +1,307 @@
+from enum import Enum
+from typing import Dict, List, Optional
+
+from langroid.pydantic_v1 import BaseModel
+
+
+class ModelProvider(str, Enum):
+    """Enum for model providers"""
+
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    DEEPSEEK = "deepseek"
+    GOOGLE = "google"
+    UNKNOWN = "unknown"
+
+
+class ModelName(str, Enum):
+    """Parent class for all model name enums"""
+
+    pass
+
+
+class OpenAIChatModel(ModelName):
+    """Enum for OpenAI Chat models"""
+
+    GPT3_5_TURBO = "gpt-3.5-turbo-1106"
+    GPT4 = "gpt-4"
+    GPT4_TURBO = "gpt-4-turbo"
+    GPT4o = "gpt-4o"
+    GPT4o_MINI = "gpt-4o-mini"
+    O1 = "o1"
+    O1_MINI = "o1-mini"
+    O3_MINI = "o3-mini"
+
+
+class OpenAICompletionModel(str, Enum):
+    """Enum for OpenAI Completion models"""
+
+    DAVINCI = "davinci-002"
+    BABBAGE = "babbage-002"
+
+
+class AnthropicModel(ModelName):
+    """Enum for Anthropic models"""
+
+    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
+    CLAUDE_3_OPUS = "claude-3-opus-latest"
+    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+
+
+class DeepSeekModel(ModelName):
+    """Enum for DeepSeek models direct from DeepSeek API"""
+
+    DEEPSEEK = "deepseek/deepseek-chat"
+    DEEPSEEK_R1 = "deepseek/deepseek-reasoner"
+
+
+class GeminiModel(ModelName):
+    """Enum for Gemini models"""
+
+    GEMINI_1_5_FLASH = "gemini/gemini-1.5-flash"
+    GEMINI_1_5_FLASH_8B = "gemini/gemini-1.5-flash-8b"
+    GEMINI_1_5_PRO = "gemini/gemini-1.5-pro"
+    GEMINI_2_FLASH = "gemini/gemini-2.0-flash-exp"
+    GEMINI_2_FLASH_THINKING = "gemini/gemini-2.0-flash-thinking-exp"
+
+
+class ModelInfo(BaseModel):
+    """
+    Consolidated information about LLM, related to capacity, cost and API
+    idiosyncrasies. Reasonable defaults for all params in case there's no
+    specific info available.
+    """
+
+    name: str = "unknown"
+    provider: ModelProvider = ModelProvider.UNKNOWN
+    context_length: int = 16_000
+    max_cot_tokens: int = 0  # max chain of thought (thinking) tokens where applicable
+    max_output_tokens: int = 8192  # Maximum number of output tokens - model dependent
+    input_cost_per_million: float = 0.0  # Cost in USD per million input tokens
+    output_cost_per_million: float = 0.0  # Cost in USD per million output tokens
+    allows_streaming: bool = True  # Whether model supports streaming output
+    allows_system_message: bool = True  # Whether model supports system messages
+    rename_params: Dict[str, str] = {}  # Rename parameters for OpenAI API
+    unsupported_params: List[str] = []
+    has_structured_output: bool = False  # Does model API support structured output?
+    has_tools: bool = True  # Does model API support tools/function-calling?
+    needs_first_user_message: bool = False  # Does API need first msg to be from user?
+    description: Optional[str] = None
+
+
+# Model information registry
+MODEL_INFO: Dict[str, ModelInfo] = {
+    # OpenAI Models
+    OpenAICompletionModel.DAVINCI.value: ModelInfo(
+        name=OpenAICompletionModel.DAVINCI.value,
+        provider=ModelProvider.OPENAI,
+        context_length=4096,
+        max_output_tokens=4096,
+        input_cost_per_million=2.0,
+        output_cost_per_million=2.0,
+        description="Davinci-002",
+    ),
+    OpenAICompletionModel.BABBAGE.value: ModelInfo(
+        name=OpenAICompletionModel.BABBAGE.value,
+        provider=ModelProvider.OPENAI,
+        context_length=4096,
+        max_output_tokens=4096,
+        input_cost_per_million=0.40,
+        output_cost_per_million=0.40,
+        description="Babbage-002",
+    ),
+    OpenAIChatModel.GPT3_5_TURBO.value: ModelInfo(
+        name=OpenAIChatModel.GPT3_5_TURBO.value,
+        provider=ModelProvider.OPENAI,
+        context_length=16_385,
+        max_output_tokens=4096,
+        input_cost_per_million=0.50,
+        output_cost_per_million=1.50,
+        description="GPT-3.5 Turbo",
+    ),
+    OpenAIChatModel.GPT4.value: ModelInfo(
+        name=OpenAIChatModel.GPT4.value,
+        provider=ModelProvider.OPENAI,
+        context_length=8192,
+        max_output_tokens=8192,
+        input_cost_per_million=30.0,
+        output_cost_per_million=60.0,
+        description="GPT-4 (8K context)",
+    ),
+    OpenAIChatModel.GPT4_TURBO.value: ModelInfo(
+        name=OpenAIChatModel.GPT4_TURBO.value,
+        provider=ModelProvider.OPENAI,
+        context_length=128_000,
+        max_output_tokens=4096,
+        input_cost_per_million=10.0,
+        output_cost_per_million=30.0,
+        description="GPT-4 Turbo",
+    ),
+    OpenAIChatModel.GPT4o.value: ModelInfo(
+        name=OpenAIChatModel.GPT4o.value,
+        provider=ModelProvider.OPENAI,
+        context_length=128_000,
+        max_output_tokens=16_384,
+        input_cost_per_million=2.5,
+        output_cost_per_million=10.0,
+        has_structured_output=True,
+        description="GPT-4o (128K context)",
+    ),
+    OpenAIChatModel.GPT4o_MINI.value: ModelInfo(
+        name=OpenAIChatModel.GPT4o_MINI.value,
+        provider=ModelProvider.OPENAI,
+        context_length=128_000,
+        max_output_tokens=16_384,
+        input_cost_per_million=0.15,
+        output_cost_per_million=0.60,
+        has_structured_output=True,
+        description="GPT-4o Mini",
+    ),
+    OpenAIChatModel.O1.value: ModelInfo(
+        name=OpenAIChatModel.O1.value,
+        provider=ModelProvider.OPENAI,
+        context_length=200_000,
+        max_output_tokens=100_000,
+        input_cost_per_million=15.0,
+        output_cost_per_million=60.0,
+        allows_streaming=False,
+        allows_system_message=False,
+        unsupported_params=["temperature", "stream"],
+        rename_params={"max_tokens": "max_completion_tokens"},
+        has_tools=False,
+        description="O1 Reasoning LM",
+    ),
+    OpenAIChatModel.O1_MINI.value: ModelInfo(
+        name=OpenAIChatModel.O1_MINI.value,
+        provider=ModelProvider.OPENAI,
+        context_length=128_000,
+        max_output_tokens=65_536,
+        input_cost_per_million=1.1,
+        output_cost_per_million=4.4,
+        allows_streaming=False,
+        allows_system_message=False,
+        unsupported_params=["temperature", "stream"],
+        rename_params={"max_tokens": "max_completion_tokens"},
+        has_tools=False,
+        description="O1 Mini Reasoning LM",
+    ),
+    OpenAIChatModel.O3_MINI.value: ModelInfo(
+        name=OpenAIChatModel.O3_MINI.value,
+        provider=ModelProvider.OPENAI,
+        context_length=200_000,
+        max_output_tokens=100_000,
+        input_cost_per_million=1.1,
+        output_cost_per_million=4.4,
+        allows_streaming=False,
+        allows_system_message=False,
+        unsupported_params=["temperature", "stream"],
+        rename_params={"max_tokens": "max_completion_tokens"},
+        has_tools=False,
+        description="O3 Mini Reasoning LM",
+    ),
+    # Anthropic Models
+    AnthropicModel.CLAUDE_3_5_SONNET.value: ModelInfo(
+        name=AnthropicModel.CLAUDE_3_5_SONNET.value,
+        provider=ModelProvider.ANTHROPIC,
+        context_length=200_000,
+        max_output_tokens=8192,
+        input_cost_per_million=3.0,
+        output_cost_per_million=15.0,
+        description="Claude 3.5 Sonnet",
+    ),
+    AnthropicModel.CLAUDE_3_OPUS.value: ModelInfo(
+        name=AnthropicModel.CLAUDE_3_OPUS.value,
+        provider=ModelProvider.ANTHROPIC,
+        context_length=200_000,
+        max_output_tokens=4096,
+        input_cost_per_million=15.0,
+        output_cost_per_million=75.0,
+        description="Claude 3 Opus",
+    ),
+    AnthropicModel.CLAUDE_3_SONNET.value: ModelInfo(
+        name=AnthropicModel.CLAUDE_3_SONNET.value,
+        provider=ModelProvider.ANTHROPIC,
+        context_length=200_000,
+        max_output_tokens=4096,
+        input_cost_per_million=3.0,
+        output_cost_per_million=15.0,
+        description="Claude 3 Sonnet",
+    ),
+    AnthropicModel.CLAUDE_3_HAIKU.value: ModelInfo(
+        name=AnthropicModel.CLAUDE_3_HAIKU.value,
+        provider=ModelProvider.ANTHROPIC,
+        context_length=200_000,
+        max_output_tokens=4096,
+        input_cost_per_million=0.25,
+        output_cost_per_million=1.25,
+        description="Claude 3 Haiku",
+    ),
+    # DeepSeek Models
+    DeepSeekModel.DEEPSEEK.value: ModelInfo(
+        name=DeepSeekModel.DEEPSEEK.value,
+        provider=ModelProvider.DEEPSEEK,
+        context_length=64_000,
+        max_output_tokens=8_000,
+        input_cost_per_million=0.27,
+        output_cost_per_million=1.10,
+        description="DeepSeek Chat",
+    ),
+    DeepSeekModel.DEEPSEEK_R1.value: ModelInfo(
+        name=DeepSeekModel.DEEPSEEK_R1.value,
+        provider=ModelProvider.DEEPSEEK,
+        context_length=64_000,
+        max_output_tokens=8_000,
+        input_cost_per_million=0.55,
+        output_cost_per_million=2.19,
+        description="DeepSeek-R1 Reasoning LM",
+    ),
+    # Gemini Models
+    GeminiModel.GEMINI_2_FLASH.value: ModelInfo(
+        name=GeminiModel.GEMINI_2_FLASH.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_056_768,
+        max_output_tokens=8192,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 2.0 Flash",
+    ),
+    GeminiModel.GEMINI_1_5_FLASH.value: ModelInfo(
+        name=GeminiModel.GEMINI_1_5_FLASH.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_056_768,
+        max_output_tokens=8192,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 1.5 Flash",
+    ),
+    GeminiModel.GEMINI_1_5_FLASH_8B.value: ModelInfo(
+        name=GeminiModel.GEMINI_1_5_FLASH_8B.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_000_000,
+        max_output_tokens=8192,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 1.5 Flash 8B",
+    ),
+    GeminiModel.GEMINI_1_5_PRO.value: ModelInfo(
+        name=GeminiModel.GEMINI_1_5_PRO.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=2_000_000,
+        max_output_tokens=8192,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 1.5 Pro",
+    ),
+    GeminiModel.GEMINI_2_FLASH_THINKING.value: ModelInfo(
+        name=GeminiModel.GEMINI_2_FLASH_THINKING.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_000_000,
+        max_output_tokens=64_000,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 2.0 Flash Thinking",
+    ),
+}
+
+
+def get_model_info(model: str | ModelName) -> ModelInfo:
+    """Get model information by name or enum value"""
+    if isinstance(model, str):
+        return MODEL_INFO.get(model) or ModelInfo()
+    return MODEL_INFO.get(model.value) or ModelInfo()