langroid 0.53.16__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/tools/mcp/fastmcp_client.py +16 -2
- langroid/embedding_models/models.py +1 -1
- langroid/language_models/openai_gpt.py +27 -20
- langroid/language_models/provider_params.py +171 -0
- {langroid-0.53.16.dist-info → langroid-0.54.0.dist-info}/METADATA +5 -1
- {langroid-0.53.16.dist-info → langroid-0.54.0.dist-info}/RECORD +8 -8
- langroid/language_models/mcp_client_lm.py +0 -128
- {langroid-0.53.16.dist-info → langroid-0.54.0.dist-info}/WHEEL +0 -0
- {langroid-0.53.16.dist-info → langroid-0.54.0.dist-info}/licenses/LICENSE +0 -0
@@ -277,8 +277,22 @@ class FastMCPClient:
|
|
277
277
|
result: CallToolResult,
|
278
278
|
) -> List[str] | str | None:
|
279
279
|
if result.isError:
|
280
|
-
|
281
|
-
|
280
|
+
# Log more detailed error information
|
281
|
+
error_content = None
|
282
|
+
if result.content and len(result.content) > 0:
|
283
|
+
try:
|
284
|
+
error_content = [
|
285
|
+
item.text if hasattr(item, "text") else str(item)
|
286
|
+
for item in result.content
|
287
|
+
]
|
288
|
+
except Exception as e:
|
289
|
+
error_content = [f"Could not extract error content: {str(e)}"]
|
290
|
+
|
291
|
+
self.logger.error(
|
292
|
+
f"Error calling MCP tool {tool_name}. Details: {error_content}"
|
293
|
+
)
|
294
|
+
return f"ERROR: Tool call failed - {error_content}"
|
295
|
+
|
282
296
|
has_nontext_results = any(
|
283
297
|
not isinstance(item, TextContent) for item in result.content
|
284
298
|
)
|
@@ -10,7 +10,7 @@ from openai import AzureOpenAI, OpenAI
|
|
10
10
|
|
11
11
|
from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
|
12
12
|
from langroid.exceptions import LangroidImportError
|
13
|
-
from langroid.language_models.
|
13
|
+
from langroid.language_models.provider_params import LangDBParams
|
14
14
|
from langroid.mytypes import Embeddings
|
15
15
|
from langroid.parsing.utils import batched
|
16
16
|
|
@@ -60,6 +60,11 @@ from langroid.language_models.prompt_formatter.hf_formatter import (
|
|
60
60
|
HFFormatter,
|
61
61
|
find_hf_formatter,
|
62
62
|
)
|
63
|
+
from langroid.language_models.provider_params import (
|
64
|
+
DUMMY_API_KEY,
|
65
|
+
LangDBParams,
|
66
|
+
PortkeyParams,
|
67
|
+
)
|
63
68
|
from langroid.language_models.utils import (
|
64
69
|
async_retry_with_exponential_backoff,
|
65
70
|
retry_with_exponential_backoff,
|
@@ -81,9 +86,7 @@ DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
|
|
81
86
|
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
82
87
|
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
83
88
|
GLHF_BASE_URL = "https://glhf.chat/api/openai/v1"
|
84
|
-
LANGDB_BASE_URL = "https://api.us-east-1.langdb.ai"
|
85
89
|
OLLAMA_API_KEY = "ollama"
|
86
|
-
DUMMY_API_KEY = "xxx"
|
87
90
|
|
88
91
|
VLLM_API_KEY = os.environ.get("VLLM_API_KEY", DUMMY_API_KEY)
|
89
92
|
LLAMACPP_API_KEY = os.environ.get("LLAMA_API_KEY", DUMMY_API_KEY)
|
@@ -183,24 +186,6 @@ def noop() -> None:
|
|
183
186
|
return None
|
184
187
|
|
185
188
|
|
186
|
-
class LangDBParams(BaseSettings):
|
187
|
-
"""
|
188
|
-
Parameters specific to LangDB integration.
|
189
|
-
"""
|
190
|
-
|
191
|
-
api_key: str = DUMMY_API_KEY
|
192
|
-
project_id: str = ""
|
193
|
-
label: Optional[str] = None
|
194
|
-
run_id: Optional[str] = None
|
195
|
-
thread_id: Optional[str] = None
|
196
|
-
base_url: str = LANGDB_BASE_URL
|
197
|
-
|
198
|
-
class Config:
|
199
|
-
# allow setting of fields via env vars,
|
200
|
-
# e.g. LANGDB_PROJECT_ID=1234
|
201
|
-
env_prefix = "LANGDB_"
|
202
|
-
|
203
|
-
|
204
189
|
class OpenAICallParams(BaseModel):
|
205
190
|
"""
|
206
191
|
Various params that can be sent to an OpenAI API chat-completion call.
|
@@ -289,6 +274,7 @@ class OpenAIGPTConfig(LLMConfig):
|
|
289
274
|
formatter: str | None = None
|
290
275
|
hf_formatter: HFFormatter | None = None
|
291
276
|
langdb_params: LangDBParams = LangDBParams()
|
277
|
+
portkey_params: PortkeyParams = PortkeyParams()
|
292
278
|
headers: Dict[str, str] = {}
|
293
279
|
|
294
280
|
def __init__(self, **kwargs) -> None: # type: ignore
|
@@ -535,6 +521,7 @@ class OpenAIGPT(LanguageModel):
|
|
535
521
|
self.is_glhf = self.config.chat_model.startswith("glhf/")
|
536
522
|
self.is_openrouter = self.config.chat_model.startswith("openrouter/")
|
537
523
|
self.is_langdb = self.config.chat_model.startswith("langdb/")
|
524
|
+
self.is_portkey = self.config.chat_model.startswith("portkey/")
|
538
525
|
self.is_litellm_proxy = self.config.chat_model.startswith("litellm-proxy/")
|
539
526
|
|
540
527
|
if self.is_groq:
|
@@ -610,6 +597,26 @@ class OpenAIGPT(LanguageModel):
|
|
610
597
|
self.config.headers["x-run-id"] = params.run_id
|
611
598
|
if params.thread_id:
|
612
599
|
self.config.headers["x-thread-id"] = params.thread_id
|
600
|
+
elif self.is_portkey:
|
601
|
+
# Parse the model string and extract provider/model
|
602
|
+
provider, model = self.config.portkey_params.parse_model_string(
|
603
|
+
self.config.chat_model
|
604
|
+
)
|
605
|
+
self.config.chat_model = model
|
606
|
+
if provider:
|
607
|
+
self.config.portkey_params.provider = provider
|
608
|
+
|
609
|
+
# Set Portkey base URL
|
610
|
+
self.api_base = self.config.portkey_params.base_url + "/v1"
|
611
|
+
|
612
|
+
# Set API key - use provider's API key from env if available
|
613
|
+
if self.api_key == OPENAI_API_KEY:
|
614
|
+
self.api_key = self.config.portkey_params.get_provider_api_key(
|
615
|
+
self.config.portkey_params.provider, DUMMY_API_KEY
|
616
|
+
)
|
617
|
+
|
618
|
+
# Add Portkey-specific headers
|
619
|
+
self.config.headers.update(self.config.portkey_params.get_headers())
|
613
620
|
|
614
621
|
self.client = OpenAI(
|
615
622
|
api_key=self.api_key,
|
@@ -0,0 +1,171 @@
|
|
1
|
+
"""
|
2
|
+
Provider-specific parameter configurations for various LLM providers.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Any, Dict, Optional
|
6
|
+
|
7
|
+
from langroid.pydantic_v1 import BaseSettings
|
8
|
+
|
9
|
+
# Constants
# Placeholder API key meaning "no real key supplied"; PortkeyParams.get_headers
# compares against it to decide whether to fall back to the environment.
DUMMY_API_KEY = "xxx"

# Default gateway endpoints, used as the `base_url` defaults of the
# provider-parameter classes in this module.
LANGDB_BASE_URL = "https://api.us-east-1.langdb.ai"
PORTKEY_BASE_URL = "https://api.portkey.ai"
|
13
|
+
|
14
|
+
|
15
|
+
class LangDBParams(BaseSettings):
    """
    Settings for routing requests through the LangDB gateway.

    Fields may be supplied through environment variables carrying the
    ``LANGDB_`` prefix (for example ``LANGDB_PROJECT_ID=1234``).
    """

    # LangDB API key; DUMMY_API_KEY when not supplied
    api_key: str = DUMMY_API_KEY
    # LangDB project identifier
    project_id: str = ""
    # Optional request identifiers; OpenAIGPT forwards run_id/thread_id
    # as the x-run-id / x-thread-id request headers.
    label: Optional[str] = None
    run_id: Optional[str] = None
    thread_id: Optional[str] = None
    # Gateway endpoint
    base_url: str = LANGDB_BASE_URL

    class Config:
        # Env-var prefix for BaseSettings, e.g. LANGDB_PROJECT_ID=1234
        env_prefix = "LANGDB_"
|
31
|
+
|
32
|
+
|
33
|
+
class PortkeyParams(BaseSettings):
    """
    Parameters specific to Portkey integration.

    Portkey is an AI gateway that provides a unified API for multiple LLM providers,
    with features like automatic retries, fallbacks, load balancing, and observability.

    Fields may also be set via environment variables with the ``PORTKEY_``
    prefix (e.g. ``PORTKEY_API_KEY``, ``PORTKEY_PROVIDER``).

    Example usage:
        # Use Portkey with Anthropic
        config = OpenAIGPTConfig(
            chat_model="portkey/anthropic/claude-3-sonnet-20240229",
            portkey_params=PortkeyParams(
                api_key="your-portkey-api-key",
                provider="anthropic"
            )
        )
    """

    api_key: str = DUMMY_API_KEY  # Portkey API key
    provider: str = ""  # Required: e.g., "openai", "anthropic", "cohere", etc.
    virtual_key: Optional[str] = None  # Optional: virtual key for the provider
    trace_id: Optional[str] = None  # Optional: trace ID for request tracking
    metadata: Optional[Dict[str, Any]] = None  # Optional: metadata for logging
    retry: Optional[Dict[str, Any]] = None  # Optional: retry configuration
    cache: Optional[Dict[str, Any]] = None  # Optional: cache configuration
    cache_force_refresh: Optional[bool] = None  # Optional: force cache refresh
    user: Optional[str] = None  # Optional: user identifier
    organization: Optional[str] = None  # Optional: organization identifier
    custom_headers: Optional[Dict[str, str]] = None  # Optional: additional headers
    base_url: str = PORTKEY_BASE_URL

    class Config:
        # allow setting of fields via env vars,
        # e.g. PORTKEY_API_KEY=xxx, PORTKEY_PROVIDER=anthropic
        env_prefix = "PORTKEY_"

    def get_headers(self) -> Dict[str, str]:
        """
        Generate Portkey-specific request headers from these parameters.

        Only parameters that are set produce a header. ``custom_headers`` is
        merged last, so it can override any generated header.

        Returns:
            Dict mapping header names to string values.
        """
        import json
        import os

        headers: Dict[str, str] = {}

        # API key: an explicitly configured key wins; otherwise fall back to
        # PORTKEY_API_KEY, read at call time.
        if self.api_key and self.api_key != DUMMY_API_KEY:
            headers["x-portkey-api-key"] = self.api_key
        else:
            portkey_key = os.getenv("PORTKEY_API_KEY", "")
            if portkey_key:
                headers["x-portkey-api-key"] = portkey_key

        if self.provider:
            headers["x-portkey-provider"] = self.provider

        if self.virtual_key:
            headers["x-portkey-virtual-key"] = self.virtual_key

        if self.trace_id:
            headers["x-portkey-trace-id"] = self.trace_id

        # Structured settings are sent as JSON strings; json.dumps escapes
        # non-ASCII by default, keeping header values ASCII.
        if self.metadata:
            headers["x-portkey-metadata"] = json.dumps(self.metadata)

        if self.retry:
            headers["x-portkey-retry"] = json.dumps(self.retry)

        if self.cache:
            headers["x-portkey-cache"] = json.dumps(self.cache)

        # Explicit False is meaningful here, hence the `is not None` check;
        # the value is sent as "true"/"false".
        if self.cache_force_refresh is not None:
            headers["x-portkey-cache-force-refresh"] = str(
                self.cache_force_refresh
            ).lower()

        if self.user:
            headers["x-portkey-user"] = self.user

        if self.organization:
            headers["x-portkey-organization"] = self.organization

        if self.custom_headers:
            headers.update(self.custom_headers)

        return headers

    def parse_model_string(self, model_string: str) -> tuple[str, str]:
        """
        Parse a model string like "portkey/anthropic/claude-3-sonnet"
        and extract provider and model name.

        Strings of the form ``portkey/<provider>/<model>`` yield
        ``(provider, model)``; ``<model>`` may itself contain slashes.
        Anything else yields an empty provider, with only a *leading*
        ``portkey/`` prefix (if present) removed from the model name.

        Returns:
            tuple: (provider, model_name)
        """
        # maxsplit=2 keeps any further slashes inside the model name.
        parts = model_string.split("/", 2)
        if len(parts) == 3 and parts[0] == "portkey":
            _, provider, model = parts
            return provider, model
        # Fallback: strip only the leading prefix. str.replace would also
        # remove "portkey/" occurring later in the string.
        return "", model_string.removeprefix("portkey/")

    def get_provider_api_key(
        self, provider: str, default_key: str = DUMMY_API_KEY
    ) -> str:
        """
        Get the API key for the provider from environment variables.

        Looks up ``<PROVIDER>_API_KEY`` then ``<PROVIDER>_KEY``, where
        ``<PROVIDER>`` is the upper-cased provider name. For provider slugs
        containing hyphens (e.g. "azure-openai"), the underscore form
        (``AZURE_OPENAI_API_KEY``, ``AZURE_OPENAI_KEY``) is tried afterwards.

        Args:
            provider: The provider name (e.g., "anthropic", "openai")
            default_key: Default key to return if not found (or if
                ``provider`` is empty)

        Returns:
            The API key for the provider
        """
        import os

        # Without a provider the patterns would degenerate to "_API_KEY" /
        # "_KEY", which are not meaningful env vars.
        if not provider:
            return default_key

        upper = provider.upper()
        prefixes = [upper]
        normalized = upper.replace("-", "_")
        if normalized != upper:
            # Tried after the literal form so existing lookups win.
            prefixes.append(normalized)

        for prefix in prefixes:
            for suffix in ("_API_KEY", "_KEY"):
                key = os.getenv(f"{prefix}{suffix}", "")
                if key:
                    return key

        return default_key
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: langroid
|
3
|
-
Version: 0.53.16
|
3
|
+
Version: 0.54.0
|
4
4
|
Summary: Harness LLMs with Multi-Agent Programming
|
5
5
|
Author-email: Prasad Chalasani <pchalasani@gmail.com>
|
6
6
|
License: MIT
|
@@ -254,6 +254,10 @@ and works with [practically any LLM](https://langroid.github.io/langroid/tutoria
|
|
254
254
|
:fire: Read the (WIP) [overview of the langroid architecture](https://langroid.github.io/langroid/blog/2024/08/15/overview-of-langroids-multi-agent-architecture-prelim/),
|
255
255
|
and a [quick tour of Langroid](https://langroid.github.io/langroid/tutorials/langroid-tour/).
|
256
256
|
|
257
|
+
:fire: MCP Support: Allow any LLM-Agent to leverage MCP Servers via Langroid's simple
|
258
|
+
[MCP tool adapter](https://langroid.github.io/langroid/notes/mcp-tools/) that converts
|
259
|
+
the server's tools into Langroid's `ToolMessage` instances.
|
260
|
+
|
257
261
|
📢 Companies are using/adapting Langroid in **production**. Here is a quote:
|
258
262
|
|
259
263
|
>[Nullify](https://www.nullify.ai) uses AI Agents for secure software development.
|
@@ -56,13 +56,13 @@ langroid/agent/tools/segment_extract_tool.py,sha256=__srZ_VGYLVOdPrITUM8S0HpmX4q
|
|
56
56
|
langroid/agent/tools/tavily_search_tool.py,sha256=soI-j0HdgVQLf09wRQScaEK4b5RpAX9C4cwOivRFWWI,1903
|
57
57
|
langroid/agent/tools/mcp/__init__.py,sha256=DJNM0VeFnFS3pJKCyFGggT8JVjVu0rBzrGzasT1HaSM,387
|
58
58
|
langroid/agent/tools/mcp/decorators.py,sha256=h7dterhsmvWJ8q4mp_OopmuG2DF71ty8cZwOyzdDZuk,1127
|
59
|
-
langroid/agent/tools/mcp/fastmcp_client.py,sha256=
|
59
|
+
langroid/agent/tools/mcp/fastmcp_client.py,sha256=WF3MhksDH2MzwXZF8cilMhux0hUmj6Z0dDdBYQMZwRs,18008
|
60
60
|
langroid/cachedb/__init__.py,sha256=G2KyNnk3Qkhv7OKyxTOnpsxfDycx3NY0O_wXkJlalNY,96
|
61
61
|
langroid/cachedb/base.py,sha256=ztVjB1DtN6pLCujCWnR6xruHxwVj3XkYniRTYAKKqk0,1354
|
62
62
|
langroid/cachedb/redis_cachedb.py,sha256=7kgnbf4b5CKsCrlL97mHWKvdvlLt8zgn7lc528jEpiE,5141
|
63
63
|
langroid/embedding_models/__init__.py,sha256=KyYxR3jDFUCfYjSuCL86qjAmrq6mXXjOT4lFNOKVj6Y,955
|
64
64
|
langroid/embedding_models/base.py,sha256=Ml7oA6PzQm0wZmIYn3fhF7dvZCi-amviWUwOeBegH3A,2562
|
65
|
-
langroid/embedding_models/models.py,sha256=
|
65
|
+
langroid/embedding_models/models.py,sha256=52S7pZOXtocnAz50vL6MKBFxhJZZurdxSEWoqs7Qfow,20792
|
66
66
|
langroid/embedding_models/remote_embeds.py,sha256=6_kjXByVbqhY9cGwl9R83ZcYC2km-nGieNNAo1McHaY,5151
|
67
67
|
langroid/embedding_models/protoc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
68
68
|
langroid/embedding_models/protoc/embeddings.proto,sha256=_O-SgFpTaylQeOTgSpxhEJ7CUw7PeCQQJLaPqpPYKJg,321
|
@@ -73,10 +73,10 @@ langroid/language_models/__init__.py,sha256=3aD2qC1lz8v12HX4B-dilv27gNxYdGdeu1Qv
|
|
73
73
|
langroid/language_models/azure_openai.py,sha256=SW0Fp_y6HpERr9l6TtF6CYsKgKwjUf_hSL_2mhTV4wI,5034
|
74
74
|
langroid/language_models/base.py,sha256=253xcwXZ0yxSQ1W4SR50tAPZKCDc35yyU1o35EqB9b8,28484
|
75
75
|
langroid/language_models/config.py,sha256=9Q8wk5a7RQr8LGMT_0WkpjY8S4ywK06SalVRjXlfCiI,378
|
76
|
-
langroid/language_models/mcp_client_lm.py,sha256=wyDvlc26E_En5u_ZNZxajCHm8KBNi4jzG-dL76QCdt4,4098
|
77
76
|
langroid/language_models/mock_lm.py,sha256=tA9JpURznsMZ59iRhFYMmaYQzAc0D0BT-PiJIV58sAk,4079
|
78
77
|
langroid/language_models/model_info.py,sha256=0e011vJZMi7XU9OkKT6doxlybrNJfMlP54klLDDNgFg,14939
|
79
|
-
langroid/language_models/openai_gpt.py,sha256=
|
78
|
+
langroid/language_models/openai_gpt.py,sha256=Xyg2VHGmA3VgPIS5ppLZeeU2Aai0qMKF9ia-oIjqRNM,86616
|
79
|
+
langroid/language_models/provider_params.py,sha256=fX25NAmYUIc1-nliMKpmTGZO6D6RpyTXtSDdZCZdb5w,5464
|
80
80
|
langroid/language_models/utils.py,sha256=n55Oe2_V_4VNGhytvPWLYC-0tFS07RTjN83KWl-p_MI,6032
|
81
81
|
langroid/language_models/prompt_formatter/__init__.py,sha256=2-5cdE24XoFDhifOLl8yiscohil1ogbP1ECkYdBlBsk,372
|
82
82
|
langroid/language_models/prompt_formatter/base.py,sha256=eDS1sgRNZVnoajwV_ZIha6cba5Dt8xjgzdRbPITwx3Q,1221
|
@@ -133,7 +133,7 @@ langroid/vector_store/pineconedb.py,sha256=otxXZNaBKb9f_H75HTaU3lMHiaR2NUp5MqwLZ
|
|
133
133
|
langroid/vector_store/postgres.py,sha256=wHPtIi2qM4fhO4pMQr95pz1ZCe7dTb2hxl4VYspGZoA,16104
|
134
134
|
langroid/vector_store/qdrantdb.py,sha256=O6dSBoDZ0jzfeVBd7LLvsXu083xs2fxXtPa9gGX3JX4,18443
|
135
135
|
langroid/vector_store/weaviatedb.py,sha256=Yn8pg139gOy3zkaPfoTbMXEEBCiLiYa1MU5d_3UA1K4,11847
|
136
|
-
langroid-0.
|
137
|
-
langroid-0.
|
138
|
-
langroid-0.
|
139
|
-
langroid-0.
|
136
|
+
langroid-0.54.0.dist-info/METADATA,sha256=QlVhCDwV-tQE6I1leUqSUmXeMjT0vozzoYn4l1GcVWs,65180
|
137
|
+
langroid-0.54.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
138
|
+
langroid-0.54.0.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
|
139
|
+
langroid-0.54.0.dist-info/RECORD,,
|
@@ -1,128 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
An API for an Agent in an MCP Server to use for chat-completions
|
3
|
-
"""
|
4
|
-
|
5
|
-
from typing import Awaitable, Callable, Dict, List, Optional, Union
|
6
|
-
|
7
|
-
from fastmcp.server import Context
|
8
|
-
|
9
|
-
import langroid.language_models as lm
|
10
|
-
from langroid.language_models import LLMResponse
|
11
|
-
from langroid.language_models.base import (
|
12
|
-
LanguageModel,
|
13
|
-
LLMConfig,
|
14
|
-
OpenAIJsonSchemaSpec,
|
15
|
-
OpenAIToolSpec,
|
16
|
-
ToolChoiceTypes,
|
17
|
-
)
|
18
|
-
from langroid.utils.types import to_string
|
19
|
-
|
20
|
-
|
21
|
-
def none_fn(x: str) -> None | str:
|
22
|
-
return None
|
23
|
-
|
24
|
-
|
25
|
-
class MCPClientLMConfig(LLMConfig):
|
26
|
-
"""
|
27
|
-
Mock Language Model Configuration.
|
28
|
-
|
29
|
-
Attributes:
|
30
|
-
response_dict (Dict[str, str]): A "response rule-book", in the form of a
|
31
|
-
dictionary; if last msg in dialog is x,then respond with response_dict[x]
|
32
|
-
"""
|
33
|
-
|
34
|
-
response_dict: Dict[str, str] = {}
|
35
|
-
response_fn: Callable[[str], None | str] = none_fn
|
36
|
-
response_fn_async: Optional[Callable[[str], Awaitable[Optional[str]]]] = None
|
37
|
-
default_response: str = "Mock response"
|
38
|
-
|
39
|
-
type: str = "mock"
|
40
|
-
|
41
|
-
|
42
|
-
class MockLM(LanguageModel):
|
43
|
-
|
44
|
-
def __init__(self, config: MockLMConfig = MockLMConfig()):
|
45
|
-
super().__init__(config)
|
46
|
-
self.config: MockLMConfig = config
|
47
|
-
|
48
|
-
def _response(self, msg: str) -> LLMResponse:
|
49
|
-
# response is based on this fallback order:
|
50
|
-
# - response_dict
|
51
|
-
# - response_fn
|
52
|
-
# - default_response
|
53
|
-
mapped_response = self.config.response_dict.get(
|
54
|
-
msg, self.config.response_fn(msg) or self.config.default_response
|
55
|
-
)
|
56
|
-
return lm.LLMResponse(
|
57
|
-
message=to_string(mapped_response),
|
58
|
-
cached=False,
|
59
|
-
)
|
60
|
-
|
61
|
-
async def _response_async(self, msg: str) -> LLMResponse:
|
62
|
-
# response is based on this fallback order:
|
63
|
-
# - response_dict
|
64
|
-
# - response_fn_async
|
65
|
-
# - response_fn
|
66
|
-
# - default_response
|
67
|
-
if self.config.response_fn_async is not None:
|
68
|
-
response = await self.config.response_fn_async(msg)
|
69
|
-
else:
|
70
|
-
response = self.config.response_fn(msg)
|
71
|
-
|
72
|
-
mapped_response = self.config.response_dict.get(
|
73
|
-
msg, response or self.config.default_response
|
74
|
-
)
|
75
|
-
return lm.LLMResponse(
|
76
|
-
message=to_string(mapped_response),
|
77
|
-
cached=False,
|
78
|
-
)
|
79
|
-
|
80
|
-
def chat(
|
81
|
-
self,
|
82
|
-
messages: Union[str, List[lm.LLMMessage]],
|
83
|
-
max_tokens: int = 200,
|
84
|
-
tools: Optional[List[OpenAIToolSpec]] = None,
|
85
|
-
tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
|
86
|
-
functions: Optional[List[lm.LLMFunctionSpec]] = None,
|
87
|
-
function_call: str | Dict[str, str] = "auto",
|
88
|
-
response_format: Optional[OpenAIJsonSchemaSpec] = None,
|
89
|
-
) -> lm.LLMResponse:
|
90
|
-
"""
|
91
|
-
Mock chat function for testing
|
92
|
-
"""
|
93
|
-
last_msg = messages[-1].content if isinstance(messages, list) else messages
|
94
|
-
return self._response(last_msg)
|
95
|
-
|
96
|
-
async def achat(
|
97
|
-
self,
|
98
|
-
messages: Union[str, List[lm.LLMMessage]],
|
99
|
-
max_tokens: int = 200,
|
100
|
-
tools: Optional[List[OpenAIToolSpec]] = None,
|
101
|
-
tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
|
102
|
-
functions: Optional[List[lm.LLMFunctionSpec]] = None,
|
103
|
-
function_call: str | Dict[str, str] = "auto",
|
104
|
-
response_format: Optional[OpenAIJsonSchemaSpec] = None,
|
105
|
-
) -> lm.LLMResponse:
|
106
|
-
"""
|
107
|
-
Mock chat function for testing
|
108
|
-
"""
|
109
|
-
last_msg = messages[-1].content if isinstance(messages, list) else messages
|
110
|
-
return await self._response_async(last_msg)
|
111
|
-
|
112
|
-
def generate(self, prompt: str, max_tokens: int = 200) -> lm.LLMResponse:
|
113
|
-
"""
|
114
|
-
Mock generate function for testing
|
115
|
-
"""
|
116
|
-
return self._response(prompt)
|
117
|
-
|
118
|
-
async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
|
119
|
-
"""
|
120
|
-
Mock generate function for testing
|
121
|
-
"""
|
122
|
-
return await self._response_async(prompt)
|
123
|
-
|
124
|
-
def get_stream(self) -> bool:
|
125
|
-
return False
|
126
|
-
|
127
|
-
def set_stream(self, stream: bool) -> bool:
|
128
|
-
return False
|
File without changes
|
File without changes
|