langroid 0.56.10__tar.gz → 0.56.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langroid-0.56.10 → langroid-0.56.12}/PKG-INFO +1 -1
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/base.py +5 -3
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/chat_agent.py +12 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/base.py +25 -19
- langroid-0.56.12/langroid/language_models/client_cache.py +255 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/model_info.py +57 -3
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/openai_gpt.py +102 -41
- {langroid-0.56.10 → langroid-0.56.12}/pyproject.toml +1 -1
- {langroid-0.56.10 → langroid-0.56.12}/.gitignore +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/LICENSE +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/README.md +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/batch.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/callbacks/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/callbacks/chainlit.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/chat_document.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/done_sequence_parser.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/openai_assistant.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/arangodb/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/arangodb/arangodb_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/arangodb/system_messages.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/arangodb/tools.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/arangodb/utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/doc_chat_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/doc_chat_task.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_doc_chat_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_rag/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_rag/critic_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_rag/lance_rag_task.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_rag/query_planner_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/lance_tools.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/neo4j/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/neo4j/csv_kg_chat.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/neo4j/system_messages.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/neo4j/tools.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/relevance_extractor_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/retriever_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/sql_chat_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/utils/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/utils/description_extractors.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/utils/populate_metadata.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/utils/system_message.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/sql/utils/tools.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/special/table_chat_agent.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/task.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tool_message.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/duckduckgo_search_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/exa_search_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/file_tools.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/google_search_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/mcp/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/mcp/decorators.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/mcp/fastmcp_client.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/metaphor_search_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/orchestration.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/recipient_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/retrieval_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/rewind_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/segment_extract_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/task_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/tools/tavily_search_tool.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/agent/xml_tool_message.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/cachedb/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/cachedb/base.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/cachedb/redis_cachedb.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/base.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/models.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/protoc/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/protoc/embeddings.proto +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/protoc/embeddings_pb2.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/embedding_models/remote_embeds.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/exceptions.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/azure_openai.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/config.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/mock_lm.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/prompt_formatter/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/prompt_formatter/base.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/prompt_formatter/hf_formatter.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/prompt_formatter/llama2_formatter.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/provider_params.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/language_models/utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/mcp/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/mcp/server/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/mytypes.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/agent_chats.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/code_parser.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/document_parser.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/file_attachment.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/md_parser.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/para_sentence_split.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/parse_json.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/parser.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/pdf_utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/repo_loader.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/routing.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/search.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/spider.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/table_loader.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/url_loader.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/urls.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/parsing/web_search.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/prompts/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/prompts/dialog.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/prompts/prompts_config.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/prompts/templates.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/py.typed +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/pydantic_v1/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/pydantic_v1/main.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/algorithms/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/algorithms/graph.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/configuration.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/constants.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/git_utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/globals.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/logging.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/object_registry.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/output/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/output/citations.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/output/printing.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/output/status.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/pandas_utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/pydantic_utils.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/system.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/utils/types.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/__init__.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/base.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/chromadb.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/lancedb.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/meilisearch.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/pineconedb.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/postgres.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/qdrantdb.py +0 -0
- {langroid-0.56.10 → langroid-0.56.12}/langroid/vector_store/weaviatedb.py +0 -0
langroid/agent/base.py

@@ -2142,7 +2142,7 @@ class Agent(ABC):
         completion_tokens = self.num_tokens(response.message)
         if response.function_call is not None:
             completion_tokens += self.num_tokens(str(response.function_call))
-        cost = self.compute_token_cost(prompt_tokens, completion_tokens)
+        cost = self.compute_token_cost(prompt_tokens, 0, completion_tokens)
         response.usage = LLMTokenUsage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,

@@ -2166,9 +2166,11 @@ class Agent(ABC):
         if print_response_stats:
             print(self.indent + self.token_stats_str)

-    def compute_token_cost(self, prompt: int, completion: int) -> float:
+    def compute_token_cost(self, prompt: int, cached: int, completion: int) -> float:
         price = cast(LanguageModel, self.llm).chat_cost()
-        return (price[0] * prompt + price[1] * completion) / 1000
+        return (
+            price[0] * (prompt - cached) + price[1] * cached + price[2] * completion
+        ) / 1000

     def ask_agent(
         self,
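Note on compute_token_cost: the new signature bills cached prompt tokens at the middle rate of the three-tuple returned by chat_cost(). A minimal worked example of the arithmetic, with illustrative per-1k prices (GPT-4o's per-million rates from model_info.py divided by 1000):

    # price = (input, cached, output) in USD per 1k tokens -- illustrative values
    price = (0.0025, 0.00125, 0.01)
    prompt, cached, completion = 1200, 800, 300
    cost = (
        price[0] * (prompt - cached)  # 400 uncached prompt tokens at full rate
        + price[1] * cached           # 800 cached prompt tokens at the discount
        + price[2] * completion       # 300 completion tokens
    ) / 1000
    # cost == 0.005 USD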
langroid/agent/chat_agent.py

@@ -2068,3 +2068,15 @@ class ChatAgent(Agent):
             return str(self.message_history[i])
         else:
             return "\n".join([str(m) for m in self.message_history[i:]])
+
+    def __del__(self) -> None:
+        """
+        Cleanup method called when the ChatAgent is garbage collected.
+        Note: We don't close LLM clients here because they may be shared
+        across multiple agents when client caching is enabled.
+        The clients are managed centrally and cleaned up via atexit hooks.
+        """
+        # Previously we closed clients here, but this caused issues when
+        # multiple agents shared the same cached client instance.
+        # Clients are now managed centrally in langroid.language_models.client_cache
+        pass
langroid/language_models/base.py

@@ -91,10 +91,6 @@ class LLMConfig(BaseSettings):
     # reasoning output from reasoning models
     cache_config: None | CacheDBConfig = RedisCacheConfig()
     thought_delimiters: Tuple[str, str] = ("<think>", "</think>")
-
-    # Dict of model -> (input/prompt cost, output/completion cost)
-    chat_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
-    completion_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
     retry_params: RetryParams = RetryParams()

     @property

@@ -131,7 +127,7 @@ class LLMFunctionCall(BaseModel):
         if not isinstance(dict_or_list, dict):
             raise ValueError(
                 f"""
-                Invalid function args: {fun_args_str}
+                Invalid function args: {fun_args_str}
                 parsed as {dict_or_list},
                 which is not a valid dict.
                 """

@@ -224,12 +220,14 @@ class LLMTokenUsage(BaseModel):
     """

     prompt_tokens: int = 0
+    cached_tokens: int = 0
     completion_tokens: int = 0
     cost: float = 0.0
     calls: int = 0  # how many API calls - not used as of 2025-04-04

     def reset(self) -> None:
         self.prompt_tokens = 0
+        self.cached_tokens = 0
         self.completion_tokens = 0
         self.cost = 0.0
         self.calls = 0
@@ -237,7 +235,8 @@ class LLMTokenUsage(BaseModel):
     def __str__(self) -> str:
         return (
             f"Tokens = "
-            f"(prompt {self.prompt_tokens}, completion {self.completion_tokens}), "
+            f"(prompt {self.prompt_tokens}, cached {self.cached_tokens}, "
+            f"completion {self.completion_tokens}), "
             f"Cost={self.cost}, Calls={self.calls}"
         )
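With cached_tokens added to the stats line, a usage object holding the illustrative values prompt=1200, cached=800, completion=300 would now render as:

    Tokens = (prompt 1200, cached 800, completion 300), Cost=0.005, Calls=1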
@@ -462,9 +461,9 @@ class LanguageModel(ABC):
         if type(config) is LLMConfig:
             raise ValueError(
                 """
-                Cannot create a Language Model object from LLMConfig.
-                Please specify a specific subclass of LLMConfig e.g.,
-                OpenAIGPTConfig. If you are creating a ChatAgent from
+                Cannot create a Language Model object from LLMConfig.
+                Please specify a specific subclass of LLMConfig e.g.,
+                OpenAIGPTConfig. If you are creating a ChatAgent from
                 a ChatAgentConfig, please specify the `llm` field of this config
                 as a specific subclass of LLMConfig, e.g., OpenAIGPTConfig.
                 """
@@ -666,8 +665,15 @@ class LanguageModel(ABC):
     def completion_context_length(self) -> int:
         return self.config.completion_context_length or DEFAULT_CONTEXT_LENGTH

-    def chat_cost(self) -> Tuple[float, float]:
-        return self.config.chat_cost_per_1k_tokens
+    def chat_cost(self) -> Tuple[float, float, float]:
+        """
+        Return the cost per 1000 tokens for chat completions.
+
+        Returns:
+            Tuple[float, float, float]: (input_cost, cached_cost, output_cost)
+                per 1000 tokens
+        """
+        return (0.0, 0.0, 0.0)

     def reset_usage_cost(self) -> None:
         for mdl in [self.config.chat_model, self.config.completion_model]:
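The base implementation now returns a zero three-tuple, to be overridden by concrete models. A hedged sketch of such an override, deriving per-1k prices from the per-million fields in ModelInfo (the info() accessor name is an assumption here, not necessarily the package's exact API):

    from typing import Tuple

    def chat_cost(self) -> Tuple[float, float, float]:
        # ModelInfo stores USD per million tokens; compute_token_cost
        # expects USD per 1k tokens, hence the /1000 conversion.
        info = self.info()  # assumed accessor returning a ModelInfo
        return (
            info.input_cost_per_million / 1000,
            info.cached_cost_per_million / 1000,
            info.output_cost_per_million / 1000,
        )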
@@ -754,18 +760,18 @@ class LanguageModel(ABC):

         prompt = f"""
         You are an expert at understanding a CHAT HISTORY between an AI Assistant
-        and a User, and you are highly skilled in rephrasing the User's FOLLOW-UP
-        QUESTION/REQUEST as a STANDALONE QUESTION/REQUEST that can be understood
+        and a User, and you are highly skilled in rephrasing the User's FOLLOW-UP
+        QUESTION/REQUEST as a STANDALONE QUESTION/REQUEST that can be understood
         WITHOUT the context of the chat history.
-
-        Below is the CHAT HISTORY. When the User asks you to rephrase a
-        FOLLOW-UP QUESTION/REQUEST, your ONLY task is to simply return the
-        question REPHRASED as a STANDALONE QUESTION/REQUEST, without any additional
+
+        Below is the CHAT HISTORY. When the User asks you to rephrase a
+        FOLLOW-UP QUESTION/REQUEST, your ONLY task is to simply return the
+        question REPHRASED as a STANDALONE QUESTION/REQUEST, without any additional
         text or context.
-
+
         <CHAT_HISTORY>
         {history}
-        </CHAT_HISTORY>
+        </CHAT_HISTORY>
         """.strip()

         follow_up_question = f"""
langroid/language_models/client_cache.py (new file)

@@ -0,0 +1,255 @@
+"""
+Client caching/singleton pattern for LLM clients to prevent connection pool exhaustion.
+"""
+
+import atexit
+import hashlib
+import weakref
+from typing import Any, Dict, Optional, Union, cast
+
+from cerebras.cloud.sdk import AsyncCerebras, Cerebras
+from groq import AsyncGroq, Groq
+from httpx import Timeout
+from openai import AsyncOpenAI, OpenAI
+
+# Cache for client instances, keyed by hashed configuration parameters
+_client_cache: Dict[str, Any] = {}
+
+# Keep track of clients for cleanup
+_all_clients: weakref.WeakSet[Any] = weakref.WeakSet()
+
+
+def _get_cache_key(client_type: str, **kwargs: Any) -> str:
+    """
+    Generate a cache key from client type and configuration parameters.
+    Uses the same approach as OpenAIGPT._cache_lookup for consistency.
+
+    Args:
+        client_type: Type of client (e.g., "openai", "groq", "cerebras")
+        **kwargs: Configuration parameters (api_key, base_url, timeout, etc.)
+
+    Returns:
+        SHA256 hash of the configuration as a hex string
+    """
+    # Convert kwargs to sorted string representation
+    sorted_kwargs_str = str(sorted(kwargs.items()))
+
+    # Create raw key combining client type and sorted kwargs
+    raw_key = f"{client_type}:{sorted_kwargs_str}"
+
+    # Hash the key for consistent length and to handle complex objects
+    hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
+
+    return hashed_key
+
+
+def get_openai_client(
+    api_key: str,
+    base_url: Optional[str] = None,
+    organization: Optional[str] = None,
+    timeout: Union[float, Timeout] = 120.0,
+    default_headers: Optional[Dict[str, str]] = None,
+) -> OpenAI:
+    """
+    Get or create a singleton OpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key
+        base_url: Optional base URL for API
+        organization: Optional organization ID
+        timeout: Request timeout
+        default_headers: Optional default headers
+
+    Returns:
+        OpenAI client instance
+    """
+    if isinstance(timeout, (int, float)):
+        timeout = Timeout(timeout)
+
+    cache_key = _get_cache_key(
+        "openai",
+        api_key=api_key,
+        base_url=base_url,
+        organization=organization,
+        timeout=timeout,
+        default_headers=default_headers,
+    )
+
+    if cache_key in _client_cache:
+        return cast(OpenAI, _client_cache[cache_key])
+
+    client = OpenAI(
+        api_key=api_key,
+        base_url=base_url,
+        organization=organization,
+        timeout=timeout,
+        default_headers=default_headers,
+    )
+
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def get_async_openai_client(
+    api_key: str,
+    base_url: Optional[str] = None,
+    organization: Optional[str] = None,
+    timeout: Union[float, Timeout] = 120.0,
+    default_headers: Optional[Dict[str, str]] = None,
+) -> AsyncOpenAI:
+    """
+    Get or create a singleton AsyncOpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key
+        base_url: Optional base URL for API
+        organization: Optional organization ID
+        timeout: Request timeout
+        default_headers: Optional default headers
+
+    Returns:
+        AsyncOpenAI client instance
+    """
+    if isinstance(timeout, (int, float)):
+        timeout = Timeout(timeout)
+
+    cache_key = _get_cache_key(
+        "async_openai",
+        api_key=api_key,
+        base_url=base_url,
+        organization=organization,
+        timeout=timeout,
+        default_headers=default_headers,
+    )
+
+    if cache_key in _client_cache:
+        return cast(AsyncOpenAI, _client_cache[cache_key])
+
+    client = AsyncOpenAI(
+        api_key=api_key,
+        base_url=base_url,
+        organization=organization,
+        timeout=timeout,
+        default_headers=default_headers,
+    )
+
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def get_groq_client(api_key: str) -> Groq:
+    """
+    Get or create a singleton Groq client with the given configuration.
+
+    Args:
+        api_key: Groq API key
+
+    Returns:
+        Groq client instance
+    """
+    cache_key = _get_cache_key("groq", api_key=api_key)
+
+    if cache_key in _client_cache:
+        return cast(Groq, _client_cache[cache_key])
+
+    client = Groq(api_key=api_key)
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def get_async_groq_client(api_key: str) -> AsyncGroq:
+    """
+    Get or create a singleton AsyncGroq client with the given configuration.
+
+    Args:
+        api_key: Groq API key
+
+    Returns:
+        AsyncGroq client instance
+    """
+    cache_key = _get_cache_key("async_groq", api_key=api_key)
+
+    if cache_key in _client_cache:
+        return cast(AsyncGroq, _client_cache[cache_key])
+
+    client = AsyncGroq(api_key=api_key)
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def get_cerebras_client(api_key: str) -> Cerebras:
+    """
+    Get or create a singleton Cerebras client with the given configuration.
+
+    Args:
+        api_key: Cerebras API key
+
+    Returns:
+        Cerebras client instance
+    """
+    cache_key = _get_cache_key("cerebras", api_key=api_key)
+
+    if cache_key in _client_cache:
+        return cast(Cerebras, _client_cache[cache_key])
+
+    client = Cerebras(api_key=api_key)
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def get_async_cerebras_client(api_key: str) -> AsyncCerebras:
+    """
+    Get or create a singleton AsyncCerebras client with the given configuration.
+
+    Args:
+        api_key: Cerebras API key
+
+    Returns:
+        AsyncCerebras client instance
+    """
+    cache_key = _get_cache_key("async_cerebras", api_key=api_key)
+
+    if cache_key in _client_cache:
+        return cast(AsyncCerebras, _client_cache[cache_key])
+
+    client = AsyncCerebras(api_key=api_key)
+    _client_cache[cache_key] = client
+    _all_clients.add(client)
+    return client
+
+
+def _cleanup_clients() -> None:
+    """
+    Cleanup function to close all cached clients on exit.
+    Called automatically via atexit.
+    """
+    import inspect
+
+    for client in list(_all_clients):
+        if hasattr(client, "close") and callable(client.close):
+            try:
+                # Check if close is a coroutine function (async)
+                if inspect.iscoroutinefunction(client.close):
+                    # For async clients, we can't await in atexit
+                    # They will be cleaned up by the OS
+                    pass
+                else:
+                    # Sync clients can be closed directly
+                    client.close()
+            except Exception:
+                pass  # Ignore errors during cleanup
+
+
+# Register cleanup function to run on exit
+atexit.register(_cleanup_clients)
+
+
+# For testing purposes
+def _clear_cache() -> None:
+    """Clear the client cache. Only for testing."""
+    _client_cache.clear()
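The point of client_cache.py: agents created with identical LLM configurations now share a single client (and its connection pool) rather than opening a new one per agent. A short usage sketch (placeholder API key; behavior follows the cache-key logic above):

    from langroid.language_models.client_cache import get_openai_client

    a = get_openai_client(api_key="sk-...")  # placeholder key
    b = get_openai_client(api_key="sk-...")
    assert a is b  # identical config -> same cached instance

    # any differing parameter yields a different cache key, hence a new client
    c = get_openai_client(api_key="sk-...", base_url="http://localhost:8000/v1")
    assert c is not a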
langroid/language_models/model_info.py

@@ -69,7 +69,9 @@ class GeminiModel(ModelName):
     GEMINI_1_5_FLASH = "gemini-1.5-flash"
     GEMINI_1_5_FLASH_8B = "gemini-1.5-flash-8b"
     GEMINI_1_5_PRO = "gemini-1.5-pro"
-    GEMINI_2_5_PRO = "gemini-2.5-pro
+    GEMINI_2_5_PRO = "gemini-2.5-pro"
+    GEMINI_2_5_FLASH = "gemini-2.5-flash"
+    GEMINI_2_5_FLASH_LITE_PREVIEW = "gemini-2.5-flash-lite-preview-06-17"
     GEMINI_2_PRO = "gemini-2.0-pro-exp-02-05"
     GEMINI_2_FLASH = "gemini-2.0-flash"
     GEMINI_2_FLASH_LITE = "gemini-2.0-flash-lite-preview"

@@ -108,6 +110,7 @@ class ModelInfo(BaseModel):
     max_cot_tokens: int = 0  # max chain of thought (thinking) tokens where applicable
     max_output_tokens: int = 8192  # Maximum number of output tokens - model dependent
     input_cost_per_million: float = 0.0  # Cost in USD per million input tokens
+    cached_cost_per_million: float = 0.0  # Cost in USD per million cached tokens
     output_cost_per_million: float = 0.0  # Cost in USD per million output tokens
     allows_streaming: bool = True  # Whether model supports streaming output
     allows_system_message: bool = True  # Whether model supports system messages
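The hunks below populate the new field for each provider; e.g., OpenAI's 50% prompt-caching discount on GPT-4o appears as cached_cost_per_million=1.25 against input_cost_per_million=2.5. A quick sanity check, assuming the registry keeps model-name strings as keys:

    from langroid.language_models.model_info import MODEL_INFO

    info = MODEL_INFO["gpt-4o"]
    assert info.input_cost_per_million == 2.5
    assert info.cached_cost_per_million == 1.25  # half the uncached input rate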
@@ -173,6 +176,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=0.10,
+        cached_cost_per_million=0.025,
         output_cost_per_million=0.40,
         description="GPT-4.1",
     ),

@@ -182,6 +186,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=0.40,
+        cached_cost_per_million=0.10,
         output_cost_per_million=1.60,
         description="GPT-4.1 Mini",
     ),

@@ -191,6 +196,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=2.00,
+        cached_cost_per_million=0.50,
         output_cost_per_million=8.00,
         description="GPT-4.1",
     ),

@@ -200,6 +206,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=128_000,
         max_output_tokens=16_384,
         input_cost_per_million=2.5,
+        cached_cost_per_million=1.25,
         output_cost_per_million=10.0,
         has_structured_output=True,
         description="GPT-4o (128K context)",

@@ -210,6 +217,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=128_000,
         max_output_tokens=16_384,
         input_cost_per_million=0.15,
+        cached_cost_per_million=0.075,
         output_cost_per_million=0.60,
         has_structured_output=True,
         description="GPT-4o Mini",

@@ -220,6 +228,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=100_000,
         input_cost_per_million=15.0,
+        cached_cost_per_million=7.50,
         output_cost_per_million=60.0,
         allows_streaming=True,
         allows_system_message=False,

@@ -233,8 +242,9 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         provider=ModelProvider.OPENAI,
         context_length=200_000,
         max_output_tokens=100_000,
-        input_cost_per_million=
-
+        input_cost_per_million=2.0,
+        cached_cost_per_million=0.50,
+        output_cost_per_million=8.0,
         allows_streaming=True,
         allows_system_message=False,
         unsupported_params=["temperature"],

@@ -248,6 +258,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=128_000,
         max_output_tokens=65_536,
         input_cost_per_million=1.1,
+        cached_cost_per_million=0.55,
         output_cost_per_million=4.4,
         allows_streaming=False,
         allows_system_message=False,

@@ -262,6 +273,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=100_000,
         input_cost_per_million=1.1,
+        cached_cost_per_million=0.55,
         output_cost_per_million=4.4,
         allows_streaming=False,
         allows_system_message=False,

@@ -276,6 +288,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=100_000,
         input_cost_per_million=1.10,
+        cached_cost_per_million=0.275,
         output_cost_per_million=4.40,
         allows_streaming=False,
         allows_system_message=False,

@@ -291,6 +304,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=8192,
         input_cost_per_million=3.0,
+        cached_cost_per_million=0.30,
         output_cost_per_million=15.0,
         description="Claude 3.5 Sonnet",
     ),

@@ -300,6 +314,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=4096,
         input_cost_per_million=15.0,
+        cached_cost_per_million=1.50,
         output_cost_per_million=75.0,
         description="Claude 3 Opus",
     ),

@@ -309,6 +324,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=4096,
         input_cost_per_million=3.0,
+        cached_cost_per_million=0.30,
         output_cost_per_million=15.0,
         description="Claude 3 Sonnet",
     ),

@@ -318,6 +334,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=200_000,
         max_output_tokens=4096,
         input_cost_per_million=0.25,
+        cached_cost_per_million=0.03,
         output_cost_per_million=1.25,
         description="Claude 3 Haiku",
     ),

@@ -328,6 +345,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=64_000,
         max_output_tokens=8_000,
         input_cost_per_million=0.27,
+        cached_cost_per_million=0.07,
         output_cost_per_million=1.10,
         description="DeepSeek Chat",
     ),

@@ -337,6 +355,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=64_000,
         max_output_tokens=8_000,
         input_cost_per_million=0.55,
+        cached_cost_per_million=0.14,
         output_cost_per_million=2.19,
         description="DeepSeek-R1 Reasoning LM",
     ),

@@ -347,6 +366,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         context_length=1_056_768,
         max_output_tokens=8192,
         input_cost_per_million=0.10,
+        cached_cost_per_million=0.025,
         output_cost_per_million=0.40,
         rename_params={"max_tokens": "max_completion_tokens"},
         description="Gemini 2.0 Flash",

@@ -401,6 +421,40 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         rename_params={"max_tokens": "max_completion_tokens"},
         description="Gemini 2.0 Flash Thinking",
     ),
+    # Gemini 2.5 Models
+    GeminiModel.GEMINI_2_5_PRO.value: ModelInfo(
+        name=GeminiModel.GEMINI_2_5_PRO.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_048_576,
+        max_output_tokens=65_536,
+        input_cost_per_million=1.25,
+        cached_cost_per_million=0.31,
+        output_cost_per_million=10.0,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 2.5 Pro",
+    ),
+    GeminiModel.GEMINI_2_5_FLASH.value: ModelInfo(
+        name=GeminiModel.GEMINI_2_5_FLASH.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=1_048_576,
+        max_output_tokens=65_536,
+        input_cost_per_million=0.30,
+        cached_cost_per_million=0.075,
+        output_cost_per_million=2.50,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 2.5 Flash",
+    ),
+    GeminiModel.GEMINI_2_5_FLASH_LITE_PREVIEW.value: ModelInfo(
+        name=GeminiModel.GEMINI_2_5_FLASH_LITE_PREVIEW.value,
+        provider=ModelProvider.GOOGLE,
+        context_length=65_536,
+        max_output_tokens=65_536,
+        input_cost_per_million=0.10,
+        cached_cost_per_million=0.025,
+        output_cost_per_million=0.40,
+        rename_params={"max_tokens": "max_completion_tokens"},
+        description="Gemini 2.5 Flash Lite Preview",
+    ),
 }