afm-langchain 0.1.0.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- afm_langchain-0.1.0.dev1/PKG-INFO +17 -0
- afm_langchain-0.1.0.dev1/README.md +0 -0
- afm_langchain-0.1.0.dev1/pyproject.toml +39 -0
- afm_langchain-0.1.0.dev1/src/afm_langchain/__init__.py +8 -0
- afm_langchain-0.1.0.dev1/src/afm_langchain/backend.py +295 -0
- afm_langchain-0.1.0.dev1/src/afm_langchain/providers.py +136 -0
- afm_langchain-0.1.0.dev1/src/afm_langchain/tools/__init__.py +2 -0
- afm_langchain-0.1.0.dev1/src/afm_langchain/tools/mcp.py +244 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: afm-langchain
|
|
3
|
+
Version: 0.1.0.dev1
|
|
4
|
+
Summary: AFM LangChain execution backend
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
Classifier: Development Status :: 3 - Alpha
|
|
7
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
8
|
+
Requires-Dist: afm-core
|
|
9
|
+
Requires-Dist: langchain>=1.2.8
|
|
10
|
+
Requires-Dist: langchain-openai>=1.1.7
|
|
11
|
+
Requires-Dist: langchain-anthropic>=1.3.1
|
|
12
|
+
Requires-Dist: mcp>=1.26.0
|
|
13
|
+
Requires-Dist: langchain-mcp-adapters>=0.2.1
|
|
14
|
+
Requires-Python: >=3.11
|
|
15
|
+
Project-URL: Repository, https://github.com/wso2/reference-implementations-afm
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
File without changes
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "afm-langchain"
|
|
3
|
+
version = "0.1.0.dev1"
|
|
4
|
+
description = "AFM LangChain execution backend"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
classifiers = [
|
|
7
|
+
"Development Status :: 3 - Alpha",
|
|
8
|
+
"Programming Language :: Python :: 3.11",
|
|
9
|
+
]
|
|
10
|
+
license = "Apache-2.0"
|
|
11
|
+
requires-python = ">=3.11"
|
|
12
|
+
urls = { Repository = "https://github.com/wso2/reference-implementations-afm" }
|
|
13
|
+
dependencies = [
|
|
14
|
+
"afm-core",
|
|
15
|
+
"langchain>=1.2.8",
|
|
16
|
+
"langchain-openai>=1.1.7",
|
|
17
|
+
"langchain-anthropic>=1.3.1",
|
|
18
|
+
"mcp>=1.26.0",
|
|
19
|
+
"langchain-mcp-adapters>=0.2.1",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[project.entry-points."afm.runner"]
|
|
23
|
+
langchain = "afm_langchain.backend:LangChainRunner"
|
|
24
|
+
|
|
25
|
+
[build-system]
|
|
26
|
+
requires = ["uv_build>=0.9.28,<0.10.0"]
|
|
27
|
+
build-backend = "uv_build"
|
|
28
|
+
|
|
29
|
+
[tool.uv.build-backend]
|
|
30
|
+
module-name = "afm_langchain"
|
|
31
|
+
|
|
32
|
+
[tool.uv.sources]
|
|
33
|
+
afm-core = { workspace = true }
|
|
34
|
+
|
|
35
|
+
[tool.pytest.ini_options]
|
|
36
|
+
testpaths = ["tests"]
|
|
37
|
+
pythonpath = ["src"]
|
|
38
|
+
asyncio_mode = "auto"
|
|
39
|
+
asyncio_default_fixture_loop_scope = "function"
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# Copyright (c) 2025
|
|
2
|
+
# Licensed under the Apache License, Version 2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from types import TracebackType
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from langchain_core.language_models import BaseChatModel
|
|
12
|
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
13
|
+
from langchain_core.tools import BaseTool
|
|
14
|
+
|
|
15
|
+
from afm.exceptions import AgentError, InputValidationError, OutputValidationError
|
|
16
|
+
from afm.models import (
|
|
17
|
+
AFMRecord,
|
|
18
|
+
Interface,
|
|
19
|
+
Signature,
|
|
20
|
+
)
|
|
21
|
+
from afm.schema_validator import (
|
|
22
|
+
build_output_schema_instruction,
|
|
23
|
+
coerce_output_to_schema,
|
|
24
|
+
validate_input,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from .providers import create_model_provider
|
|
28
|
+
from .tools.mcp import MCPManager
|
|
29
|
+
|
|
30
|
+
# Module-level logger for LangChainRunner's MCP connection and tool-loop diagnostics.
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LangChainRunner:
    """Run an AFM agent definition on top of a LangChain chat model.

    The runner:

    - turns the AFM record's role and instructions into a system prompt;
    - keeps per-session conversation history in memory;
    - drives a tool-calling loop over caller-supplied tools plus tools loaded
      from configured MCP servers;
    - validates input and output against the signature of the first
      declared interface.

    Call ``connect()`` (or use ``async with``) before ``arun`` so that MCP
    tools are loaded and all tools are bound to the model. Until then the
    unbound base model is used.
    """

    def __init__(
        self,
        afm: AFMRecord,
        *,
        model: BaseChatModel | None = None,
        tools: list[BaseTool] | None = None,
    ) -> None:
        """Create a runner for ``afm``.

        Args:
            afm: The parsed AFM record describing the agent.
            model: Optional pre-built chat model. When omitted, one is created
                from ``afm.metadata.model`` via ``create_model_provider``.
            tools: Optional extra LangChain tools, bound to the model together
                with any MCP tools once ``connect`` has run.
        """
        self._afm = afm
        self._base_model = model or create_model_provider(afm.metadata.model)
        # Replaced by a tool-bound model in connect(); reset in disconnect().
        self._model = self._base_model
        self._sessions: dict[str, list[HumanMessage | AIMessage]] = {}

        # MCP management
        self._mcp_manager = MCPManager.from_afm(afm)
        self._external_tools = tools or []
        self._mcp_tools: list[BaseTool] = []
        self._connected = False

        # Cache the active interface for signature validation
        self._interface = self._get_primary_interface()
        self._signature = self._get_signature()

    async def __aenter__(self) -> "LangChainRunner":
        """Connect (load and bind tools) and return the runner."""
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Disconnect on context exit; exceptions are not suppressed."""
        await self.disconnect()

    async def connect(self) -> None:
        """Load MCP tools (if configured) and bind all tools to the model.

        Returns immediately when already connected. Errors raised while
        loading MCP tools propagate and leave the runner disconnected.
        """
        if self._connected:
            return

        if self._mcp_manager is not None:
            logger.info(
                "Connecting to MCP servers: %s", self._mcp_manager.server_names
            )
            self._mcp_tools = await self._mcp_manager.get_tools()
            logger.info("Loaded %d MCP tools", len(self._mcp_tools))

        # Bind tools to model if any are available
        all_tools = self._get_all_tools()
        if all_tools:
            self._model = self._base_model.bind_tools(all_tools)
            logger.info("Bound %d tools to model", len(all_tools))

        self._connected = True

    async def disconnect(self) -> None:
        """Drop MCP tools, restore the unbound model and clear the MCP cache."""
        if not self._connected:
            return

        # Clear MCP tools and reset model
        self._mcp_tools = []
        self._model = self._base_model
        self._connected = False

        if self._mcp_manager is not None:
            self._mcp_manager.clear_cache()
            logger.info("Disconnected from MCP servers")

    def _get_all_tools(self) -> list[BaseTool]:
        """Return external tools followed by MCP tools, as a new list."""
        return self._external_tools + self._mcp_tools

    @property
    def afm(self) -> AFMRecord:
        """The AFM record this runner executes."""
        return self._afm

    @property
    def name(self) -> str:
        """Agent name from metadata, or ``"AFM Agent"`` when unset."""
        return self._afm.metadata.name or "AFM Agent"

    @property
    def description(self) -> str | None:
        """Agent description from metadata (may be None)."""
        return self._afm.metadata.description

    @property
    def system_prompt(self) -> str:
        """System prompt composed of the record's role and instructions."""
        return (
            f"# Role\n\n{self._afm.role}\n\n"
            f"# Instructions\n\n{self._afm.instructions}"
        )

    @property
    def max_iterations(self) -> int | None:
        """Configured tool-loop limit; ``arun`` uses 10 when this is None."""
        return self._afm.metadata.max_iterations

    @property
    def tools(self) -> list[BaseTool]:
        """All currently available tools (external + loaded MCP tools)."""
        return self._get_all_tools()

    @property
    def signature(self) -> Signature:
        """Input/output signature used for validation."""
        return self._signature

    def _get_primary_interface(self) -> Interface | None:
        """Return the first declared interface, or None if there are none."""
        interfaces = self._afm.metadata.interfaces
        return interfaces[0] if interfaces else None

    def _get_signature(self) -> Signature:
        """Return the primary interface's signature, else a default Signature."""
        if self._interface is not None:
            return self._interface.signature
        # Default: string input/output
        return Signature()

    def _get_session_history(self, session_id: str) -> list[HumanMessage | AIMessage]:
        """Return the mutable history list for ``session_id``, creating it if needed."""
        return self._sessions.setdefault(session_id, [])

    def _prepare_input(self, input_data: str | dict[str, Any]) -> str:
        """Validate ``input_data`` and render it as text for the LLM.

        Strings pass through unchanged; other values are JSON-encoded.
        """
        validate_input(input_data, self._signature.input)
        if isinstance(input_data, str):
            return input_data
        return json.dumps(input_data)

    def _build_messages(
        self,
        user_input: str,
        session_history: list[HumanMessage | AIMessage],
    ) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
        """Assemble system prompt, prior history and the new user turn.

        For non-string output schemas, the schema instruction is appended to
        the user message so the model knows the expected output shape.
        """
        messages: list[SystemMessage | HumanMessage | AIMessage | ToolMessage] = [
            SystemMessage(content=self.system_prompt)
        ]
        messages.extend(session_history)

        output_schema = self._signature.output
        if output_schema.type != "string":
            user_input = user_input + build_output_schema_instruction(output_schema)

        messages.append(HumanMessage(content=user_input))
        return messages

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into plain text.

        LangChain message content is either a string or a list whose items
        are strings or dict content blocks. String items and ``"text"``
        blocks are concatenated. Other block types (e.g. tool-use blocks)
        are dropped rather than rendered as a Python ``repr``. Any other
        value falls back to ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return str(content)

    def _extract_response_content(self, response: Any) -> str:
        """Return the text of an LLM response (AIMessage or anything str()-able)."""
        if isinstance(response, AIMessage):
            return self._content_to_text(response.content)
        return str(response)

    async def arun(
        self,
        input_data: str | dict[str, Any],
        *,
        session_id: str = "default",
    ) -> str | dict[str, Any]:
        """Run one agent turn for ``session_id`` and return the coerced output.

        Args:
            input_data: User input. It is validated against the signature's
                input schema and JSON-encoded when it is not a string.
            session_id: Key of the in-memory conversation history to use.

        Returns:
            The final model text coerced to the signature's output schema.

        Raises:
            InputValidationError: If ``input_data`` fails validation.
            OutputValidationError: If the output cannot be coerced.
            AgentError: For any other failure. This includes the model never
                being invoked, which happens when ``max_iterations`` is <= 0.
        """
        try:
            user_input = self._prepare_input(input_data)
            session_history = self._get_session_history(session_id)

            # History stores the raw input, without the schema instruction
            # that _build_messages appends for non-string outputs.
            original_input = user_input
            messages: list[Any] = self._build_messages(user_input, session_history)

            max_iterations = (
                self.max_iterations if self.max_iterations is not None else 10
            )

            # Resolve tools once per run. setdefault keeps the first tool for
            # duplicate names, matching a linear first-match search.
            tools_by_name: dict[str, BaseTool] = {}
            for candidate in self.tools:
                tools_by_name.setdefault(candidate.name, candidate)

            iterations = 0
            response = None

            # Main agent loop to handle tool calls
            while iterations < max_iterations:
                response = await self._model.ainvoke(messages)

                # If no tool calls, we're done
                if not response.tool_calls:
                    break

                # Keep the assistant message so tool results can refer to its calls.
                messages.append(response)

                for tool_call in response.tool_calls:
                    tool_name = tool_call["name"]
                    tool = tools_by_name.get(tool_name)

                    if tool is None:
                        tool_output = f"Error: Tool '{tool_name}' not found."
                    else:
                        try:
                            tool_output = await tool.ainvoke(tool_call["args"])
                        except Exception as e:
                            # Report tool failures to the model instead of aborting the run.
                            tool_output = f"Error executing tool '{tool_name}': {e}"

                    messages.append(
                        ToolMessage(
                            content=str(tool_output),
                            tool_call_id=tool_call["id"],
                        )
                    )

                iterations += 1

            if iterations >= max_iterations and response and response.tool_calls:
                pending = [tc["name"] for tc in response.tool_calls]
                logger.warning(
                    "Max iterations (%s) reached with %d pending tool calls: %s",
                    max_iterations,
                    len(pending),
                    pending,
                )

            if response is None:
                raise AgentError("No response from LLM")

            response_content = self._extract_response_content(response)
            result = coerce_output_to_schema(response_content, self._signature.output)

            # Commit history only after a successful, validated turn.
            session_history.append(HumanMessage(content=original_input))
            session_history.append(AIMessage(content=response_content))

            return result

        except (InputValidationError, OutputValidationError, AgentError):
            raise
        except Exception as e:
            raise AgentError(f"Agent execution failed: {e}") from e

    def clear_history(self, session_id: str = "default") -> None:
        """Forget the conversation history for ``session_id`` (no-op if absent)."""
        self._sessions.pop(session_id, None)
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Copyright (c) 2025
|
|
2
|
+
# Licensed under the Apache License, Version 2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
from langchain_core.language_models import BaseChatModel
|
|
10
|
+
|
|
11
|
+
from afm.exceptions import ProviderError
|
|
12
|
+
from afm.models import ClientAuthentication, Model
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from langchain_anthropic import ChatAnthropic
|
|
16
|
+
from langchain_openai import ChatOpenAI
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# --- OpenAI -----------------------------------------------------------------
# Model used when the AFM model config does not name one.
DEFAULT_OPENAI_MODEL: str = "gpt-4o"
# NOTE(review): not referenced in this module; may be consumed elsewhere.
DEFAULT_OPENAI_URL: str = "https://api.openai.com/v1"
# Environment variable consulted when the config carries no credentials.
OPENAI_API_KEY_ENV: str = "OPENAI_API_KEY"

# --- Anthropic --------------------------------------------------------------
# Model used when the AFM model config does not name one.
DEFAULT_ANTHROPIC_MODEL: str = "claude-sonnet-4-5"
# NOTE(review): not referenced in this module; may be consumed elsewhere.
DEFAULT_ANTHROPIC_URL: str = "https://api.anthropic.com"
# Environment variable consulted when the config carries no credentials.
ANTHROPIC_API_KEY_ENV: str = "ANTHROPIC_API_KEY"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_model_provider(afm_model: Model | None = None) -> BaseChatModel:
    """Build a LangChain chat model for the provider named in ``afm_model``.

    A ``None`` config, or a config without a provider, selects OpenAI.
    Provider names are matched case-insensitively.

    Raises:
        ProviderError: If the provider is not ``openai`` or ``anthropic``.
            The per-provider builders may also raise it.
    """
    if afm_model is None:
        return _create_openai_model(None)

    requested = afm_model.provider or "openai"
    provider = requested.lower()

    if provider == "openai":
        return _create_openai_model(afm_model)
    if provider == "anthropic":
        return _create_anthropic_model(afm_model)

    raise ProviderError(f"Unsupported provider: {provider}", provider=provider)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _create_openai_model(afm_model: Model | None) -> ChatOpenAI:
    """Build a ``ChatOpenAI`` client from an optional AFM model config.

    The API key is resolved by ``_get_api_key``: config authentication
    first, then the OpenAI API-key environment variable. The model name
    defaults to ``DEFAULT_OPENAI_MODEL``. ``base_url`` is passed only when
    the config sets ``url``.

    Raises:
        ProviderError: If ``langchain-openai`` is not installed or the API
            key cannot be resolved.
    """
    try:
        from langchain_openai import ChatOpenAI
    except ImportError as e:
        raise ProviderError(
            "langchain-openai package is required for OpenAI models. "
            "Install it with: pip install langchain-openai",
            provider="openai",
        ) from e

    credentials = afm_model.authentication if afm_model else None
    kwargs: dict = {
        "api_key": _get_api_key(credentials, OPENAI_API_KEY_ENV, "openai"),
        "model": DEFAULT_OPENAI_MODEL,
    }
    if afm_model:
        # Overrides only when the config actually sets a (truthy) value.
        if afm_model.name:
            kwargs["model"] = afm_model.name
        if afm_model.url:
            kwargs["base_url"] = afm_model.url

    return ChatOpenAI(**kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _create_anthropic_model(afm_model: Model) -> ChatAnthropic:
    """Build a ``ChatAnthropic`` client from an AFM model config.

    The API key is resolved by ``_get_api_key``: config authentication
    first, then the Anthropic API-key environment variable. The model name
    defaults to ``DEFAULT_ANTHROPIC_MODEL``. ``base_url`` is passed only
    when the config sets ``url``.

    Raises:
        ProviderError: If ``langchain-anthropic`` is not installed or the
            API key cannot be resolved.
    """
    try:
        from langchain_anthropic import ChatAnthropic
    except ImportError as e:
        raise ProviderError(
            "langchain-anthropic package is required for Anthropic models. "
            "Install it with: pip install langchain-anthropic",
            provider="anthropic",
        ) from e

    # Dict literal evaluates left to right, so the key lookup runs first.
    kwargs: dict = {
        "api_key": _get_api_key(
            afm_model.authentication, ANTHROPIC_API_KEY_ENV, "anthropic"
        ),
        "model": afm_model.name or DEFAULT_ANTHROPIC_MODEL,
    }
    if afm_model.url:
        kwargs["base_url"] = afm_model.url

    return ChatAnthropic(**kwargs)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _get_api_key(
|
|
104
|
+
auth: ClientAuthentication | None,
|
|
105
|
+
env_var: str,
|
|
106
|
+
provider: str,
|
|
107
|
+
) -> str:
|
|
108
|
+
if auth is not None:
|
|
109
|
+
auth_type = auth.type.lower()
|
|
110
|
+
|
|
111
|
+
if auth_type == "bearer" and auth.token:
|
|
112
|
+
return auth.token
|
|
113
|
+
elif auth_type == "api-key" and auth.api_key:
|
|
114
|
+
return auth.api_key
|
|
115
|
+
elif auth_type == "basic":
|
|
116
|
+
raise ProviderError(
|
|
117
|
+
"Basic authentication is not supported for LLM providers. "
|
|
118
|
+
"Use 'bearer' or 'api-key' authentication type.",
|
|
119
|
+
provider=provider,
|
|
120
|
+
)
|
|
121
|
+
else:
|
|
122
|
+
auth_dict = auth.model_dump(exclude_none=True)
|
|
123
|
+
for key in ["token", "api_key", "key", "apiKey"]:
|
|
124
|
+
if key in auth_dict and auth_dict[key]:
|
|
125
|
+
return auth_dict[key]
|
|
126
|
+
|
|
127
|
+
# Fall back to environment variable
|
|
128
|
+
api_key = os.environ.get(env_var)
|
|
129
|
+
if api_key:
|
|
130
|
+
return api_key
|
|
131
|
+
|
|
132
|
+
raise ProviderError(
|
|
133
|
+
f"No API key found. Provide authentication in the model config "
|
|
134
|
+
f"or set the {env_var} environment variable.",
|
|
135
|
+
provider=provider,
|
|
136
|
+
)
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
# Copyright (c) 2025
|
|
2
|
+
# Licensed under the Apache License, Version 2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from langchain_core.tools import BaseTool
|
|
10
|
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
11
|
+
from langchain_mcp_adapters.sessions import StreamableHttpConnection
|
|
12
|
+
|
|
13
|
+
from afm.exceptions import (
|
|
14
|
+
MCPAuthenticationError,
|
|
15
|
+
MCPConnectionError,
|
|
16
|
+
MCPError,
|
|
17
|
+
)
|
|
18
|
+
from afm.models import (
|
|
19
|
+
AFMRecord,
|
|
20
|
+
ClientAuthentication,
|
|
21
|
+
MCPServer,
|
|
22
|
+
ToolFilter,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Module-level logger for MCP tool loading, skipped servers and per-server failures.
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BearerAuth(httpx.Auth):
    """httpx auth hook that sends ``Authorization: Bearer <token>``."""

    def __init__(self, token: str) -> None:
        self.token = token

    def auth_flow(self, request: httpx.Request):
        """Stamp the bearer header onto ``request`` and send it once."""
        header_value = f"Bearer {self.token}"
        request.headers["Authorization"] = header_value
        yield request
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ApiKeyAuth(httpx.Auth):
    """httpx auth hook that sends a raw API key in a configurable header.

    The key is sent verbatim, with no scheme prefix, under ``header_name``.
    The header defaults to ``Authorization``.
    """

    def __init__(self, api_key: str, header_name: str = "Authorization") -> None:
        self.api_key = api_key
        self.header_name = header_name

    def auth_flow(self, request: httpx.Request):
        """Set the API-key header on ``request`` and send it once."""
        outgoing_headers = request.headers
        outgoing_headers[self.header_name] = self.api_key
        yield request
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None:
|
|
48
|
+
if auth is None:
|
|
49
|
+
return None
|
|
50
|
+
|
|
51
|
+
auth_type = auth.type.lower()
|
|
52
|
+
|
|
53
|
+
if auth_type == "bearer":
|
|
54
|
+
if auth.token is None:
|
|
55
|
+
raise MCPAuthenticationError("Bearer auth requires 'token' field")
|
|
56
|
+
return BearerAuth(auth.token)
|
|
57
|
+
|
|
58
|
+
elif auth_type == "basic":
|
|
59
|
+
if auth.username is None or auth.password is None:
|
|
60
|
+
raise MCPAuthenticationError(
|
|
61
|
+
"Basic auth requires 'username' and 'password' fields"
|
|
62
|
+
)
|
|
63
|
+
return httpx.BasicAuth(auth.username, auth.password)
|
|
64
|
+
|
|
65
|
+
elif auth_type == "api-key":
|
|
66
|
+
if auth.api_key is None:
|
|
67
|
+
raise MCPAuthenticationError("API key auth requires 'api_key' field")
|
|
68
|
+
return ApiKeyAuth(auth.api_key)
|
|
69
|
+
|
|
70
|
+
elif auth_type in ("oauth2", "jwt"):
|
|
71
|
+
raise MCPAuthenticationError(
|
|
72
|
+
f"Authentication type '{auth_type}' not yet supported"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
else:
|
|
76
|
+
raise MCPAuthenticationError(f"Unsupported authentication type: {auth_type}")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def filter_tools(
    tools: list[BaseTool],
    tool_filter: ToolFilter | None,
) -> list[BaseTool]:
    """Apply an MCP server's allow/deny lists to its tools.

    A tool survives if its name is in ``allow`` (when ``allow`` is given)
    and not in ``deny``. An empty ``allow`` list therefore keeps nothing.
    Original order is preserved. When there is no filter, or both lists are
    None, the input list object itself is returned.
    """
    allow = tool_filter.allow if tool_filter is not None else None
    deny = tool_filter.deny if tool_filter is not None else None
    if allow is None and deny is None:
        return tools

    permitted = {tool.name for tool in tools}
    if allow is not None:
        permitted &= set(allow)
    if deny is not None:
        permitted -= set(deny)

    return [tool for tool in tools if tool.name in permitted]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class MCPClient:
    """Load LangChain tools from a single MCP server over streamable HTTP."""

    def __init__(
        self,
        name: str,
        url: str,
        authentication: ClientAuthentication | None = None,
        tool_filter: ToolFilter | None = None,
    ) -> None:
        """Store the connection settings; no network activity happens here.

        Args:
            name: Server name, used as the key in ``MultiServerMCPClient``.
            url: Streamable-HTTP endpoint of the server.
            authentication: Optional auth config, turned into ``httpx.Auth``.
            tool_filter: Optional allow/deny filter applied to loaded tools.
        """
        self.name = name
        self.url = url
        self.authentication = authentication
        self.tool_filter = tool_filter
        # NOTE(review): not read or written by any method of this class.
        self._tools: list[BaseTool] | None = None

    @classmethod
    def from_mcp_server(cls, server: MCPServer) -> "MCPClient":
        """Build a client from an AFM ``MCPServer`` entry.

        Raises:
            MCPError: If the server's transport type is not ``http``.
        """
        transport = server.transport
        if transport.type == "http":
            return cls(
                name=server.name,
                url=transport.url,
                authentication=transport.authentication,
                tool_filter=server.tool_filter,
            )
        raise MCPError(
            f"Unsupported transport type: {transport.type}. Only 'http' is supported for now.",
            server_name=server.name,
        )

    def _build_connection_config(self) -> StreamableHttpConnection:
        """Return the ``streamable_http`` connection dict, with auth if configured."""
        config: StreamableHttpConnection = {
            "transport": "streamable_http",
            "url": self.url,
        }
        httpx_auth = build_httpx_auth(self.authentication)
        if httpx_auth is not None:
            config["auth"] = httpx_auth
        return config

    async def get_tools(self) -> list[BaseTool]:
        """Fetch this server's tools and apply the configured filter.

        Raises:
            MCPError: Re-raised unchanged, e.g. an authentication error
                raised while building the connection config.
            MCPConnectionError: Wraps any other failure.
        """
        try:
            single_server_client = MultiServerMCPClient(
                {self.name: self._build_connection_config()}
            )
            fetched = await single_server_client.get_tools(server_name=self.name)
            kept = filter_tools(fetched, self.tool_filter)
            logger.info(
                f"MCP server '{self.name}': loaded {len(kept)} tools "
                f"(filtered from {len(fetched)})"
            )
            return kept
        except MCPError:
            # Preserve specific MCP diagnostics (e.g. MCPAuthenticationError).
            raise
        except Exception as e:
            raise MCPConnectionError(
                f"Failed to connect: {e}",
                server_name=self.name,
            ) from e
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class MCPManager:
    """Collect LangChain tools from every MCP server declared in an AFM record.

    One ``MCPClient`` is created per server. Servers whose configuration is
    rejected with ``MCPError`` (e.g. a non-HTTP transport) are skipped with
    a warning. The combined tool list is cached only when every client
    succeeded, so failed servers are retried on the next call.
    """

    def __init__(self, servers: list[MCPServer]) -> None:
        """Create clients for ``servers``, skipping any that raise ``MCPError``."""
        self._servers = servers
        self._clients: list[MCPClient] = []
        self._tools: list[BaseTool] | None = None

        for server in servers:
            try:
                client = MCPClient.from_mcp_server(server)
            except MCPError as e:
                logger.warning("Skipping MCP server: %s", e)
                continue
            self._clients.append(client)

    @classmethod
    def from_afm(cls, afm: AFMRecord) -> "MCPManager | None":
        """Build a manager from ``afm.metadata.tools.mcp``.

        Returns None when the record has no tools section or no MCP servers.
        """
        tools_config = afm.metadata.tools
        if tools_config is None:
            return None

        mcp_servers = tools_config.mcp
        if not mcp_servers:
            return None

        return cls(mcp_servers)

    @property
    def server_names(self) -> list[str]:
        """Names of the servers that produced a client."""
        return [client.name for client in self._clients]

    async def get_tools(self) -> list[BaseTool]:
        """Return the tools of all servers, in client order.

        Servers failing with ``MCPConnectionError`` are logged and skipped.
        Other exceptions (e.g. authentication errors) propagate.

        Raises:
            MCPConnectionError: If at least one server failed and none
                succeeded.
        """
        if self._tools is not None:
            return self._tools

        all_tools: list[BaseTool] = []
        errors: list[str] = []
        succeeded = 0

        # Get tools from each client individually to handle per-server filtering
        for client in self._clients:
            try:
                tools = await client.get_tools()
            except MCPConnectionError as e:
                errors.append(str(e))
                logger.error("Failed to get tools from server '%s': %s", client.name, e)
                continue
            succeeded += 1
            all_tools.extend(tools)

        # Count successful servers rather than collected tools. A reachable
        # server that exposes (or, after filtering, keeps) zero tools must not
        # be reported as "failed to connect to any MCP server".
        if errors and succeeded == 0:
            raise MCPConnectionError(
                f"Failed to connect to any MCP server: {'; '.join(errors)}"
            )

        # Only cache if all servers succeeded; partial results are not
        # cached so that failed servers can be retried on the next call.
        if not errors:
            self._tools = all_tools

        return all_tools

    def clear_cache(self) -> None:
        """Forget cached tools so the next ``get_tools`` call refetches them."""
        self._tools = None
|