langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""A simple MCP client for testing."""
|
|
15
|
+
|
|
16
|
+
from absl import app
|
|
17
|
+
from absl import flags
|
|
18
|
+
from langfun.core import mcp
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Command-line flag: endpoint of the MCP server whose tools will be listed.
_URL = flags.DEFINE_string(
    name='url',
    default='http://localhost:8000/mcp',
    help='URL of the MCP server.',
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main(_):
  """Connects to the server given by `--url` and prints its tool classes."""
  client = mcp.McpClient.from_url(url=_URL.value)
  tools = client.list_tools()
  print(tools)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if __name__ == '__main__':
  # absl parses command-line flags (e.g. --url) before calling `main`.
  app.run(main)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Simple MCP server for testing."""
|
|
15
|
+
|
|
16
|
+
from absl import app as absl_app
|
|
17
|
+
from mcp.server import fastmcp as fastmcp_lib
|
|
18
|
+
|
|
19
|
+
# The FastMCP server instance; tools below are registered via `@mcp.tool()`.
mcp = fastmcp_lib.FastMCP(port=8000, host='0.0.0.0')
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@mcp.tool()
async def add(a: int, b: int) -> int:
  """Adds two integers and returns their sum."""
  total = a + b
  return total
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main(_):
  """Runs the test MCP server using the streamable-HTTP transport."""
  # NOTE(review): in some `mcp` SDK versions `mount_path` is only honored by
  # the SSE transport; the streamable-HTTP path may come from server settings
  # instead — confirm against the installed SDK.
  mcp.run(mount_path='/mcp', transport='streamable-http')
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if __name__ == '__main__':
  # Delegate startup (including flag parsing) to absl.
  absl_app.run(main)
|
langfun/core/mcp/tool.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""MCP tool."""
|
|
15
|
+
|
|
16
|
+
import base64
|
|
17
|
+
from typing import Annotated, Any, ClassVar
|
|
18
|
+
|
|
19
|
+
from langfun.core import async_support
|
|
20
|
+
from langfun.core import message as lf_message
|
|
21
|
+
from langfun.core import modalities as lf_modalities
|
|
22
|
+
from langfun.core.structured import schema as lf_schema
|
|
23
|
+
import mcp
|
|
24
|
+
import pyglove as pg
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _McpToolMeta(pg.symbolic.ObjectMeta):
  """Metaclass giving MCP tool classes a compact, readable repr."""

  def __repr__(self) -> str:
    return "<tool-class '%s'>" % self.__name__
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class McpTool(pg.Object, metaclass=_McpToolMeta):
  """Represents a tool available on an MCP server.

  `McpTool` is the base class for all tool proxies generated from an MCP
  server's tool definitions. Users do not typically subclass `McpTool` directly.
  Instead, tool classes are obtained by calling `lf.mcp.McpClient.list_tools()`
  or `lf.mcp.McpSession.list_tools()`.

  Once a tool class is obtained, it can be instantiated with input parameters
  and called via an `McpSession` to execute the tool on the server.

  **Example Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_command(...)
  with client.session() as session:
    # List tools and get the 'math' tool class.
    math_tool_cls = session.list_tools()['math']

    # Instantiate the tool with parameters and call it.
    result = math_tool_cls(x=1, y=2, op='+')(session)
    print(result)
  ```
  """

  TOOL_NAME: Annotated[
      ClassVar[str],
      'Tool name.'
  ]

  @classmethod
  def python_definition(cls, markdown: bool = True) -> str:
    """Returns the Python definition of this tool's input schema.

    This is useful for generating prompts that instruct a language model
    on how to use the tool.

    Args:
      markdown: If True, formats the output as a Markdown code block.

    Returns:
      A string containing the Python definition of the tool's input schema.
    """
    return lf_schema.Schema.from_value(cls).schema_repr(markdown=markdown)

  @classmethod
  def result_to_message(
      cls, result: mcp.types.CallToolResult
  ) -> lf_message.ToolMessage:
    """Converts an `mcp.types.CallToolResult` to an `lf.ToolMessage`.

    Text items become text chunks; image and audio items are base64-decoded
    into `lf.Image` / `lf.Audio` chunks. Any `structuredContent` entries are
    merged into the message metadata.

    Args:
      result: The `mcp.types.CallToolResult` object to convert.

    Returns:
      An `lf.ToolMessage` instance representing the tool result.

    Raises:
      ValueError: If a content item has a type other than text, image or
        audio.
    """
    chunks = []
    for item in result.content:
      if isinstance(item, mcp.types.TextContent):
        chunk = item.text
      elif isinstance(item, mcp.types.ImageContent):
        chunk = lf_modalities.Image.from_bytes(_base64_decode(item.data))
      elif isinstance(item, mcp.types.AudioContent):
        chunk = lf_modalities.Audio.from_bytes(_base64_decode(item.data))
      else:
        raise ValueError(f'Unsupported item type: {type(item)}')
      chunks.append(chunk)
    message = lf_message.ToolMessage.from_chunks(chunks)
    if result.structuredContent:
      message.metadata.update(result.structuredContent)
    return message

  def __call__(
      self,
      session,
      *,
      returns_message: bool = False) -> Any:
    """Calls the MCP tool synchronously within a given session.

    Args:
      session: An `McpSession` object.
      returns_message: If True, the raw `lf.ToolMessage` is returned.
        If False (default), the method attempts to return a more specific
        result:
        - The `result` field of the `ToolMessage` if it is not None.
        - The `ToolMessage` itself if it contains multi-modal content.
        - The `text` field of the `ToolMessage` otherwise.

    Returns:
      The result of the tool call, processed according to `returns_message`.
    """
    return async_support.invoke_sync(
        self.acall,
        session,
        returns_message=returns_message
    )

  async def acall(
      self,
      session,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls the MCP tool asynchronously within a given session.

    Args:
      session: An `McpSession` object or an `mcp.ClientSession`.
      returns_message: If True, the raw `lf.ToolMessage` is returned.
        If False (default), the method attempts to return a more specific
        result:
        - The `result` field of the `ToolMessage` if it is not None.
        - The `ToolMessage` itself if it contains multi-modal content.
        - The `text` field of the `ToolMessage` otherwise.

    Returns:
      The result of the tool call, processed according to `returns_message`.

    Raises:
      ValueError: If `session` is not an `mcp.ClientSession` and does not
        expose an entered underlying client session.
    """
    if not isinstance(session, mcp.ClientSession):
      # Wrapper sessions keep the underlying `mcp.ClientSession` in
      # `_session`; it is expected to be None until the session is entered.
      session = getattr(session, '_session', None)
      # An explicit check instead of `assert`, which is stripped under `-O`.
      if session is None:
        raise ValueError('MCP session is not entered.')
    tool_call_result = await session.call_tool(
        self.TOOL_NAME, self.input_parameters()
    )
    message = self.result_to_message(tool_call_result)
    if returns_message:
      return message
    # Compare against None so falsy results such as 0, False or '' are
    # returned as-is rather than falling back to the textual representation.
    if message.result is not None:
      return message.result
    if message.referred_modalities:
      return message
    return message.text

  def input_parameters(self) -> dict[str, Any]:
    """Returns the input parameters for the tool call.

    Returns:
      A dictionary containing the input parameters, formatted for an
      MCP `call_tool` request.
    """
    # Optional fields are represented as fields with default values. Therefore,
    # we need to remove the default values from the JSON representation of the
    # tool.
    json = self.to_json(hide_default_values=True)

    # Remove the type name key from the JSON representation of the tool.
    def _transform(path: pg.KeyPath, x: Any) -> Any:
      del path
      if isinstance(x, dict):
        x.pop(pg.JSONConvertible.TYPE_NAME_KEY, None)
      return x
    return pg.utils.transform(json, _transform)

  @classmethod
  def make_class(cls, tool_definition: mcp.Tool) -> type['McpTool']:
    """Creates an `McpTool` subclass from an MCP tool definition.

    Args:
      tool_definition: An `mcp.Tool` object containing the tool's metadata.

    Returns:
      A dynamically generated class that inherits from `McpTool` and
      represents the defined tool.
    """

    class _McpTool(cls):
      # The schema is applied explicitly below from the tool's JSON schema.
      auto_schema = False

    tool_cls = _McpTool
    tool_cls.TOOL_NAME = tool_definition.name
    tool_cls.__name__ = _snake_to_camel(tool_definition.name)
    tool_cls.__doc__ = tool_definition.description
    schema = pg.Schema.from_json_schema(
        tool_definition.inputSchema, class_fn=McpToolInput.make_class
    )
    tool_cls.apply_schema(schema)
    return tool_cls
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class _McpToolInputMeta(pg.symbolic.ObjectMeta):
  """Metaclass giving generated MCP input classes a compact, readable repr."""

  def __repr__(self) -> str:
    return "<input-class '%s'>" % self.__name__
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class McpToolInput(pg.Object, metaclass=_McpToolInputMeta):
  """Base class for input types generated from MCP tool JSON schemas."""

  @classmethod
  def make_class(cls, name: str, schema: pg.Schema):
    """Builds a new `McpToolInput` subclass with the given schema.

    Args:
      name: Name of the input type (converted to CamelCase for the class).
      schema: The `pg.Schema` describing the input fields; its description
        becomes the class docstring.

    Returns:
      A freshly created subclass of `McpToolInput`.
    """

    class _McpToolInput(cls):
      pass

    _McpToolInput.__name__ = _snake_to_camel(name)
    _McpToolInput.__doc__ = schema.description
    _McpToolInput.apply_schema(schema)
    return _McpToolInput
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _snake_to_camel(name: str) -> str:
|
|
248
|
+
"""Converts a snake_case name to a CamelCase name."""
|
|
249
|
+
return ''.join(x.capitalize() for x in name.split('_'))
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _base64_decode(data: str) -> bytes:
|
|
253
|
+
"""Decodes a base64 string."""
|
|
254
|
+
return base64.b64decode(data.encode('utf-8'))
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Tests for MCP tool."""
|
|
15
|
+
|
|
16
|
+
import base64
|
|
17
|
+
import inspect
|
|
18
|
+
import unittest
|
|
19
|
+
|
|
20
|
+
from langfun.core import async_support
|
|
21
|
+
from langfun.core import message as lf_message
|
|
22
|
+
from langfun.core import modalities as lf_modalities
|
|
23
|
+
from langfun.core.mcp import client as mcp_client
|
|
24
|
+
from langfun.core.mcp import tool as mcp_tool
|
|
25
|
+
import mcp
|
|
26
|
+
from mcp.server import fastmcp as fastmcp_lib
|
|
27
|
+
import pyglove as pg
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# MCP server setup for testing.
|
|
31
|
+
# In-process MCP server used by the tests (accessed via `from_fastmcp`).
_mcp_server = fastmcp_lib.FastMCP(port=1235, host='0.0.0.0')
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@_mcp_server.tool()
async def add(a: int, b: int) -> int:
  """Adds two integers."""
  total = a + b
  return total
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class McpToolTest(unittest.TestCase):
  """Tests for `McpTool`, `McpToolInput` and the module helpers."""

  def setUp(self):
    super().setUp()
    # Connect to the in-process FastMCP server and fetch its tool classes.
    self.client = mcp_client.McpClient.from_fastmcp(_mcp_server)
    self.tools = self.client.list_tools()

  def test_snake_to_camel(self):
    for snake, camel in [('foo_bar', 'FooBar'), ('foo', 'Foo')]:
      self.assertEqual(mcp_tool._snake_to_camel(snake), camel)

  def test_base64_decode(self):
    encoded = base64.b64encode(b'foo').decode('utf-8')
    self.assertEqual(mcp_tool._base64_decode(encoded), b'foo')

  def test_make_input_class(self):
    fields = [
        pg.typing.Field('x', pg.typing.Int(), 'Integer x.'),
        pg.typing.Field('y', pg.typing.Str(), 'String y.'),
    ]
    schema = pg.Schema(description='Foo input.', fields=fields)
    input_cls = mcp_tool.McpToolInput.make_class('foo_input', schema)

    self.assertTrue(issubclass(input_cls, mcp_tool.McpToolInput))
    self.assertEqual(input_cls.__name__, 'FooInput')
    self.assertEqual(input_cls.__doc__, 'Foo input.')
    self.assertEqual(list(input_cls.__schema__.fields.keys()), ['x', 'y'])
    self.assertEqual(repr(input_cls), "<input-class 'FooInput'>")
    self.assertEqual(repr(input_cls(x=1, y='abc')), "FooInput(x=1, y='abc')")

  def test_make_tool_class(self):
    input_schema = {
        'type': 'object',
        'properties': {
            'a': {'type': 'integer', 'description': 'Integer a.'},
            'b': {'type': 'string', 'description': 'String b.'},
        },
        'required': ['a'],
    }
    tool_def = mcp.Tool(
        name='my_tool', description='My tool.', inputSchema=input_schema
    )
    tool_cls = mcp_tool.McpTool.make_class(tool_def)

    self.assertTrue(issubclass(tool_cls, mcp_tool.McpTool))
    self.assertEqual(tool_cls.__name__, 'MyTool')
    self.assertEqual(tool_cls.TOOL_NAME, 'my_tool')
    self.assertEqual(tool_cls.__doc__, 'My tool.')

    fields = tool_cls.__schema__.fields
    self.assertEqual(list(fields.keys()), ['a', 'b'])
    self.assertEqual(repr(tool_cls), "<tool-class 'MyTool'>")
    self.assertEqual(fields['a'].description, 'Integer a.')
    self.assertEqual(fields['b'].description, 'String b.')

    # Expected class body shared by the markdown and plain renderings.
    class_def = inspect.cleandoc(
        '''
        class MyTool:
          """My tool."""
          # Integer a.
          a: int
          # String b.
          b: str | None
        '''
    )
    self.assertEqual(
        tool_cls.python_definition(markdown=True),
        f'MyTool\n\n```python\n{class_def}\n```',
    )
    self.assertEqual(
        tool_cls.python_definition(markdown=False),
        f'MyTool\n\n{class_def}',
    )

  def test_input_parameters(self):
    add_cls = self.tools['add']
    params = add_cls(a=1, b=2).input_parameters()
    self.assertEqual(params, {'a': 1, 'b': 2})

  def test_result_to_message(self):
    image_bytes = b'image-data'
    audio_bytes = b'audio-data'

    def _b64(raw: bytes) -> str:
      return base64.b64encode(raw).decode('utf-8')

    contents = [
        mcp.types.TextContent(type='text', text='hello'),
        mcp.types.ImageContent(
            type='image', data=_b64(image_bytes), mimeType='image/png'
        ),
        mcp.types.AudioContent(
            type='audio', data=_b64(audio_bytes), mimeType='audio/wav'
        ),
    ]
    call_result = mcp.types.CallToolResult(
        content=contents, structuredContent={'x': 1}
    )
    message = self.tools['add'].result_to_message(call_result)

    self.assertIsInstance(message, lf_message.ToolMessage)
    for marker in ('hello', '<<[[image', '<<[[audio'):
      self.assertIn(marker, message.text)
    self.assertEqual(message.metadata, {'x': 1})

    modalities = message.modalities()
    self.assertEqual(len(modalities), 2)
    image, audio = modalities
    self.assertIsInstance(image, lf_modalities.Image)
    self.assertEqual(image.to_bytes(), image_bytes)
    self.assertIsInstance(audio, lf_modalities.Audio)
    self.assertEqual(audio.to_bytes(), audio_bytes)

  def test_sync_call(self):
    add_cls = self.tools['add']
    with self.client.session() as session:
      # By default the structured result is returned.
      self.assertEqual(add_cls(a=1, b=2)(session), 3)

      # With `returns_message=True`, the full ToolMessage is returned.
      self.assertEqual(
          add_cls(a=1, b=2)(session, returns_message=True),
          lf_message.ToolMessage(text='3', result=3),
      )

  def test_async_call(self):
    add_cls = self.tools['add']

    async def _run():
      async with self.client.session() as session:
        # By default the structured result is returned.
        self.assertEqual(await add_cls(a=1, b=2).acall(session), 3)

        # With `returns_message=True`, the full ToolMessage is returned.
        message = await add_cls(a=1, b=2).acall(
            session, returns_message=True
        )
        self.assertEqual(message, lf_message.ToolMessage(text='3', result=3))

    async_support.invoke_sync(_run)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
if __name__ == '__main__':
  # Run all test cases defined in this module.
  unittest.main()
|