langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,71 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for MCP client."""
15
+
16
+ import unittest
17
+ from langfun.core import async_support
18
+ from langfun.core import mcp as lf_mcp
19
+ from langfun.core import message as lf_message
20
+ from mcp.server import fastmcp as fastmcp_lib
21
+
22
# FastMCP server hosting the tool(s) under test. The tests below reach it
# in-memory through `lf_mcp.McpClient.from_fastmcp`.
mcp = fastmcp_lib.FastMCP(
    port=1234,
    host='0.0.0.0',
)
23
+
24
+
25
@mcp.tool()
async def add(a: int, b: int) -> int:
  """Adds two integers and returns their sum.

  Args:
    a: The first integer.
    b: The second integer.

  Returns:
    The sum of the two integers.
  """
  # The docstring above is used by FastMCP as the tool description, so it is
  # part of the observable behavior of this test fixture.
  total = a + b
  return total
37
+
38
+
39
class McpTest(unittest.TestCase):
  """Exercises `lf_mcp.McpClient` against the in-memory FastMCP server."""

  def test_sync_usages(self):
    client = lf_mcp.McpClient.from_fastmcp(mcp)
    tools = client.list_tools()
    self.assertEqual(len(tools), 1)
    add_tool_cls = tools['add']
    with client.session() as session:
      # Only `session.call_tool` is checked here; `tool.__call__` has its own
      # coverage in `tool_test.py`.
      sum_value = session.call_tool(add_tool_cls(a=1, b=2))
      self.assertEqual(sum_value, 3)

  def test_async_usages(self):

    async def _run_checks():
      client = lf_mcp.McpClient.from_fastmcp(mcp)
      tools = client.list_tools()
      self.assertEqual(len(tools), 1)
      add_tool_cls = tools['add']
      self.assertEqual(add_tool_cls.__name__, 'Add')
      self.assertEqual(add_tool_cls.TOOL_NAME, 'add')
      async with client.session() as session:
        # Only `session.acall_tool` is checked here; `tool.acall` has its own
        # coverage in `tool_test.py`.
        reply = await session.acall_tool(
            add_tool_cls(a=1, b=2), returns_message=True
        )
        self.assertEqual(reply, lf_message.ToolMessage(text='3', result=3))

    async_support.invoke_sync(_run_checks)
68
+
69
+
70
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  unittest.main()
@@ -0,0 +1,241 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MCP session."""
15
+
16
+ import contextlib
17
+ from typing import Any, Type
18
+ import anyio
19
+ from langfun.core import async_support
20
+ from langfun.core.mcp import tool as mcp_tool
21
+ import mcp
22
+ from mcp.client import sse
23
+ from mcp.client import streamable_http
24
+ from mcp.server import fastmcp as fastmcp_lib
25
+ from mcp.shared import memory
26
+
27
+
28
class McpSession:
  """Represents a session for interacting with an MCP server.

  `McpSession` provides the context for making calls to tools hosted on an
  MCP server. It wraps the standard `mcp.ClientSession` to offer both
  synchronous and asynchronous usage patterns.

  Sessions are created using `lf.mcp.McpClient.session()` (or the `from_*`
  class methods of this class) and should be used as context managers (either
  sync or async) to ensure proper initialization and teardown of the
  connection to the server. A session object can be entered only once.

  **Example Sync Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_command(...)
  with client.session() as session:
    tools = session.list_tools()
    result = tools['my_tool'](x=1)(session)
  ```

  **Example Async Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_url(...)
  async with client.session() as session:
    tools = await session.alist_tools()
    result = await tools['my_tool'](x=1).acall(session)
  ```
  """

  def __init__(self, stream) -> None:
    """Constructor.

    Args:
      stream: An async context manager (e.g. the return value of
        `mcp.stdio_client(...)`) which, when entered, yields a tuple of 2 or 3
        items whose first two are the read and write streams.
    """
    self._stream = stream
    self._session = None
    self._session_exit_stack = None
    # NOTE(review): not read anywhere in this module.
    self._in_session = False

    # For supporting sync context manager.
    self._sync_context_manager_exit_stack = None

  def __enter__(self) -> 'McpSession':
    """Enters the session synchronously via `async_support`."""
    exit_stack = contextlib.ExitStack()
    exit_stack.enter_context(async_support.sync_context_manager(self))
    self._sync_context_manager_exit_stack = exit_stack
    return self

  def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exits a session that was entered with `with`."""
    del exc_type, exc_val, exc_tb
    exit_stack = self._sync_context_manager_exit_stack
    assert exit_stack is not None
    self._sync_context_manager_exit_stack = None
    exit_stack.close()

  async def __aenter__(self) -> 'McpSession':
    """Connects to the server and initializes the MCP client session.

    Returns:
      This session object.

    Raises:
      AssertionError: If the session has already been entered.
    """
    assert self._session_exit_stack is None, 'Session cannot be re-entered.'
    # Mark the session as entered before connecting. The stream context manager
    # may be single-use (e.g. the generator-based
    # `_client_streams_from_fastmcp`), so even a failed attempt must not be
    # retried on this object.
    self._session_exit_stack = contextlib.AsyncExitStack()

    # Connect on a temporary stack. Python does not call `__aexit__` when
    # `__aenter__` raises, so if any step below fails (e.g. `initialize()`),
    # leaving this `async with` tears down everything entered so far instead
    # of leaking the transport.
    async with contextlib.AsyncExitStack() as stack:
      stream_output = await stack.enter_async_context(self._stream)
      assert isinstance(stream_output, tuple) and len(stream_output) in (2, 3)
      read, write = stream_output[:2]
      session = mcp.ClientSession(read, write)
      await stack.enter_async_context(session)
      await session.initialize()
      # Success: transfer ownership of the cleanup callbacks to this object.
      self._session_exit_stack = stack.pop_all()
    self._session = session
    return self

  async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Closes the MCP client session and its underlying transport."""
    del exc_type, exc_val, exc_tb
    if self._session is None:
      return
    assert self._session_exit_stack is not None
    try:
      await self._session_exit_stack.aclose()
    finally:
      self._session = None

  def list_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
    """Lists all available tools on the MCP server synchronously.

    Returns:
      A dictionary mapping tool names to their corresponding `McpTool` classes.
    """
    return async_support.invoke_sync(self.alist_tools)

  async def alist_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
    """Lists all available tools on the MCP server asynchronously.

    Returns:
      A dictionary mapping tool names to their corresponding `McpTool` classes.
    """
    assert self._session is not None, 'MCP session is not entered.'
    return {
        t.name: mcp_tool.McpTool.make_class(t)
        for t in (await self._session.list_tools()).tools
    }

  def call_tool(
      self,
      tool: mcp_tool.McpTool,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls an MCP tool synchronously.

    Args:
      tool: The `McpTool` instance to call.
      returns_message: If True, the tool call returns an `lf.ToolMessage`;
        otherwise, it returns the tool's direct result.

    Returns:
      The result of the tool call.
    """
    return tool(self, returns_message=returns_message)

  async def acall_tool(
      self,
      tool: mcp_tool.McpTool,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls an MCP tool asynchronously.

    Args:
      tool: The `McpTool` instance to call.
      returns_message: If True, the tool call returns an `lf.ToolMessage`;
        otherwise, it returns the tool's direct result.

    Returns:
      The result of the tool call.
    """
    return await tool.acall(self, returns_message=returns_message)

  @classmethod
  def from_command(
      cls,
      command: str,
      args: list[str] | None = None
  ) -> 'McpSession':
    """Creates an MCP session from a command-line executable.

    Args:
      command: The command to execute.
      args: An optional list of arguments to pass to the command.

    Returns:
      An `McpSession` instance.
    """
    return cls(
        mcp.stdio_client(
            mcp.StdioServerParameters(command=command, args=args or [])
        )
    )

  @classmethod
  def from_url(
      cls,
      url: str,
      headers: dict[str, str] | None = None
  ) -> 'McpSession':
    """Creates an MCP session from an HTTP URL.

    The transport protocol is inferred from the last segment of the URL path,
    ignoring any query string or fragment: a path ending in `mcp` selects the
    streamable HTTP transport and one ending in `sse` selects SSE.

    Args:
      url: The URL of the MCP server. It is passed to the transport unchanged.
      headers: An optional dictionary of HTTP headers to include in requests.

    Returns:
      An `McpSession` instance.

    Raises:
      ValueError: If the last path segment is neither `mcp` nor `sse`.
    """
    # Strip the fragment and query so e.g. 'https://host/mcp?key=...' still
    # resolves to the 'mcp' transport.
    path = url.split('#', 1)[0].split('?', 1)[0]
    transport = path.removesuffix('/').split('/')[-1].lower()
    if transport == 'mcp':
      return cls(streamable_http.streamablehttp_client(url, headers or {}))
    elif transport == 'sse':
      return cls(sse.sse_client(url, headers or {}))
    else:
      raise ValueError(f'Unsupported transport: {transport}')

  @classmethod
  def from_fastmcp(
      cls,
      fastmcp: fastmcp_lib.FastMCP
  ) -> 'McpSession':
    """Creates an MCP session from an in-memory FastMCP instance.

    Args:
      fastmcp: An instance of `fastmcp_lib.FastMCP`.

    Returns:
      An `McpSession` instance.
    """
    return cls(_client_streams_from_fastmcp(fastmcp))
220
+
221
+
222
@contextlib.asynccontextmanager
async def _client_streams_from_fastmcp(fastmcp: fastmcp_lib.FastMCP):
  """Yields client streams connected to an in-memory FastMCP server.

  The FastMCP instance's low-level server runs as a background task for the
  lifetime of the context. The task is cancelled when the context exits.

  Args:
    fastmcp: The FastMCP instance to serve.

  Yields:
    A `(read_stream, write_stream)` tuple for an MCP client.
  """
  server = fastmcp._mcp_server  # pylint: disable=protected-access
  async with memory.create_client_server_memory_streams(
  ) as (client_streams, server_streams):
    client_read, client_write = client_streams
    server_read, server_write = server_streams

    # Run the server in a task group whose cancel scope bounds its lifetime.
    async with anyio.create_task_group() as tg:
      tg.start_soon(
          lambda: server.run(
              server_read,
              server_write,
              server.create_initialization_options(),
              raise_exceptions=True,
          )
      )
      try:
        yield client_read, client_write
      finally:
        # Stop the server task. Without this, exiting the task group waits
        # until `server.run` returns on its own.
        tg.cancel_scope.cancel()
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for MCP session."""
15
+
16
+ import unittest
17
+ from unittest import mock
18
+
19
+ from langfun.core.mcp import session as mcp_session
20
+ import mcp
21
+ from mcp.client import sse
22
+ from mcp.client import streamable_http
23
+
24
+
25
class McpSessionTest(unittest.TestCase):
  """Checks which transport each `McpSession` factory method selects."""

  def test_from_command(self):
    with mock.patch.object(
        mcp, 'stdio_client', autospec=True
    ) as stdio_client_mock:
      mcp_session.McpSession.from_command('my-command', ['--foo'])
      expected_params = mcp.StdioServerParameters(
          command='my-command', args=['--foo']
      )
      stdio_client_mock.assert_called_once_with(expected_params)

  def test_from_url_mcp(self):
    with mock.patch.object(
        streamable_http, 'streamablehttp_client', autospec=True
    ) as http_client_mock:
      mcp_session.McpSession.from_url(
          'http://localhost/mcp', headers={'k': 'v'}
      )
      http_client_mock.assert_called_once_with(
          'http://localhost/mcp', {'k': 'v'}
      )

  def test_from_url_sse(self):
    with mock.patch.object(
        sse, 'sse_client', autospec=True
    ) as sse_client_mock:
      mcp_session.McpSession.from_url(
          'http://localhost/sse', headers={'k': 'v'}
      )
      sse_client_mock.assert_called_once_with(
          'http://localhost/sse', {'k': 'v'}
      )

  def test_from_url_unsupported(self):
    expected_error = 'Unsupported transport: foo'
    with self.assertRaisesRegex(ValueError, expected_error):
      mcp_session.McpSession.from_url('http://localhost/foo')
51
+
52
+
53
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  unittest.main()
@@ -0,0 +1,33 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """A simple MCP client for testing."""
15
+
16
+ from absl import app
17
+ from absl import flags
18
+ from langfun.core import mcp
19
+
20
+
21
# Endpoint to query. The default points at port 8000 and path '/mcp', the
# values used by the companion `simple_mcp_server.py`.
_URL = flags.DEFINE_string(
    name='url',
    default='http://localhost:8000/mcp',
    help='URL of the MCP server.',
)
26
+
27
+
28
def main(_):
  """Connects to the configured server and prints its tool classes."""
  client = mcp.McpClient.from_url(url=_URL.value)
  available_tools = client.list_tools()
  print(available_tools)
30
+
31
+
32
if __name__ == '__main__':
  # `app.run` handles absl flag parsing before invoking `main`.
  app.run(main)
@@ -0,0 +1,33 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Simple MCP server for testing."""
15
+
16
+ from absl import app as absl_app
17
+ from mcp.server import fastmcp as fastmcp_lib
18
+
19
# Demo FastMCP server. NOTE(review): host '0.0.0.0' listens on every network
# interface; consider '127.0.0.1' if only local clients should connect.
mcp = fastmcp_lib.FastMCP(
    port=8000,
    host='0.0.0.0',
)
20
+
21
+
22
@mcp.tool()
async def add(a: int, b: int) -> int:
  """Adds two integers and returns their sum."""
  # The docstring doubles as the tool description that FastMCP publishes.
  total = a + b
  return total
26
+
27
+
28
def main(_):
  """Serves the demo tools over the streamable HTTP transport."""
  # NOTE(review): `mount_path` may only be honored by the SSE transport; for
  # 'streamable-http' the endpoint path may come from FastMCP settings.
  # Confirm that '/mcp' is where the server actually listens.
  mcp.run(mount_path='/mcp', transport='streamable-http')
30
+
31
+
32
if __name__ == '__main__':
  # `run` handles absl flag parsing before invoking `main`.
  absl_app.run(main)
@@ -0,0 +1,254 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MCP tool."""
15
+
16
+ import base64
17
+ from typing import Annotated, Any, ClassVar
18
+
19
+ from langfun.core import async_support
20
+ from langfun.core import message as lf_message
21
+ from langfun.core import modalities as lf_modalities
22
+ from langfun.core.structured import schema as lf_schema
23
+ import mcp
24
+ import pyglove as pg
25
+
26
+
27
class _McpToolMeta(pg.symbolic.ObjectMeta):
  """Metaclass giving generated tool classes a compact repr."""

  def __repr__(self) -> str:
    return "<tool-class '%s'>" % self.__name__
31
+
32
+
33
class McpTool(pg.Object, metaclass=_McpToolMeta):
  """Represents a tool available on an MCP server.

  `McpTool` is the base class for all tool proxies generated from an MCP
  server's tool definitions. Users do not typically subclass `McpTool` directly.
  Instead, tool classes are obtained by calling `lf.mcp.McpClient.list_tools()`
  or `lf.mcp.McpSession.list_tools()`.

  Once a tool class is obtained, it can be instantiated with input parameters
  and called via an `McpSession` to execute the tool on the server.

  **Example Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_command(...)
  with client.session() as session:
    # List tools and get the 'math' tool class.
    math_tool_cls = session.list_tools()['math']

    # Instantiate the tool with parameters and call it.
    result = math_tool_cls(x=1, y=2, op='+')(session)
    print(result)
  ```
  """

  # Name of the tool on the MCP server; assigned by `make_class`.
  TOOL_NAME: Annotated[
      ClassVar[str],
      'Tool name.'
  ]

  @classmethod
  def python_definition(cls, markdown: bool = True) -> str:
    """Returns the Python definition of this tool's input schema.

    This is useful for generating prompts that instruct a language model
    on how to use the tool.

    Args:
      markdown: If True, formats the output as a Markdown code block.

    Returns:
      A string containing the Python definition of the tool's input schema.
    """
    return lf_schema.Schema.from_value(cls).schema_repr(markdown=markdown)

  @classmethod
  def result_to_message(
      cls, result: mcp.types.CallToolResult
  ) -> lf_message.ToolMessage:
    """Converts an `mcp.types.CallToolResult` to an `lf.ToolMessage`.

    Text, image and audio content items become the chunks of the message.
    Entries of `result.structuredContent`, when present, are merged into the
    message metadata.

    NOTE(review): `result.isError` is not inspected, so an error result is
    converted the same way as a successful one.

    Args:
      result: The `mcp.types.CallToolResult` object to convert.

    Returns:
      An `lf.ToolMessage` instance representing the tool result.

    Raises:
      ValueError: If a content item is not text, image or audio content.
    """
    chunks = []
    for item in result.content:
      if isinstance(item, mcp.types.TextContent):
        chunk = item.text
      elif isinstance(item, mcp.types.ImageContent):
        chunk = lf_modalities.Image.from_bytes(_base64_decode(item.data))
      elif isinstance(item, mcp.types.AudioContent):
        chunk = lf_modalities.Audio.from_bytes(_base64_decode(item.data))
      else:
        raise ValueError(f'Unsupported item type: {type(item)}')
      chunks.append(chunk)
    message = lf_message.ToolMessage.from_chunks(chunks)
    if result.structuredContent:
      # Presumably this is where `message.result` comes from (e.g. a
      # {'result': ...} payload); verify against `lf.ToolMessage`.
      message.metadata.update(result.structuredContent)
    return message

  def __call__(
      self,
      session,
      *,
      returns_message: bool = False) -> Any:
    """Calls the MCP tool synchronously within a given session.

    Args:
      session: An `McpSession` object.
      returns_message: If True, the raw `lf.ToolMessage` is returned.
        If False (default), the method attempts to return a more specific
        result:
        - The `result` field of the `ToolMessage` if it is not None.
        - The `ToolMessage` itself if it contains multi-modal content.
        - The `text` field of the `ToolMessage` otherwise.

    Returns:
      The result of the tool call, processed according to `returns_message`.
    """
    return async_support.invoke_sync(
        self.acall,
        session,
        returns_message=returns_message
    )

  async def acall(
      self,
      session,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls the MCP tool asynchronously within a given session.

    Args:
      session: An `McpSession` object or an `mcp.ClientSession`.
      returns_message: If True, the raw `lf.ToolMessage` is returned.
        If False (default), the method attempts to return a more specific
        result:
        - The `result` field of the `ToolMessage` if it is not None.
        - The `ToolMessage` itself if it contains multi-modal content.
        - The `text` field of the `ToolMessage` otherwise.

    Returns:
      The result of the tool call, processed according to `returns_message`.
    """
    if not isinstance(session, mcp.ClientSession):
      # Unwrap an `McpSession` into its underlying `mcp.ClientSession`.
      session = getattr(session, '_session', None)
    assert session is not None, 'MCP session is not entered.'
    tool_call_result = await session.call_tool(
        self.TOOL_NAME, self.input_parameters()
    )
    message = self.result_to_message(tool_call_result)
    if returns_message:
      return message
    # Compare with None instead of testing truthiness, so that falsy results
    # such as 0, False or [] are returned as-is rather than falling back to
    # the message text (e.g. the string '0').
    if message.result is not None:
      return message.result
    if message.referred_modalities:
      return message
    return message.text

  def input_parameters(self) -> dict[str, Any]:
    """Returns the input parameters for the tool call.

    Returns:
      A dictionary containing the input parameters, formatted for an
      MCP `call_tool` request.
    """
    # Optional fields are represented as fields with default values. Therefore,
    # we need to remove the default values from the JSON representation of the
    # tool.
    json = self.to_json(hide_default_values=True)

    # Remove the type name key from the JSON representation of the tool.
    def _transform(path: pg.KeyPath, x: Any) -> Any:
      del path
      if isinstance(x, dict):
        x.pop(pg.JSONConvertible.TYPE_NAME_KEY, None)
      return x
    return pg.utils.transform(json, _transform)

  @classmethod
  def make_class(cls, tool_definition: mcp.Tool) -> type['McpTool']:
    """Creates an `McpTool` subclass from an MCP tool definition.

    Args:
      tool_definition: An `mcp.Tool` object containing the tool's metadata.

    Returns:
      A dynamically generated class that inherits from `McpTool` and
      represents the defined tool.
    """

    class _McpTool(cls):
      auto_schema = False

    tool_cls = _McpTool
    tool_cls.TOOL_NAME = tool_definition.name
    tool_cls.__name__ = _snake_to_camel(tool_definition.name)
    tool_cls.__doc__ = tool_definition.description
    schema = pg.Schema.from_json_schema(
        tool_definition.inputSchema, class_fn=McpToolInput.make_class
    )
    tool_cls.apply_schema(schema)
    return tool_cls
214
+
215
+
216
class _McpToolInputMeta(pg.symbolic.ObjectMeta):
  """Metaclass giving generated tool input classes a compact repr."""

  def __repr__(self) -> str:
    return "<input-class '%s'>" % self.__name__
220
+
221
+
222
class McpToolInput(pg.Object, metaclass=_McpToolInputMeta):
  """Base class for generated MCP tool input schemas."""

  @classmethod
  def make_class(cls, name: str, schema: pg.Schema):
    """Builds a subclass of `cls` whose fields follow `schema`.

    `McpTool.make_class` passes this as the `class_fn` callback of
    `pg.Schema.from_json_schema` when converting a tool's input JSON schema.

    Args:
      name: Name for the generated class. It is converted to CamelCase.
      schema: A `pg.Schema` describing the input fields. Its description
        becomes the class docstring.

    Returns:
      The newly generated `McpToolInput` subclass.
    """

    class _McpToolInput(cls):
      pass

    generated_cls = _McpToolInput
    generated_cls.__doc__ = schema.description
    generated_cls.__name__ = _snake_to_camel(name)
    generated_cls.apply_schema(schema)
    return generated_cls
245
+
246
+
247
+ def _snake_to_camel(name: str) -> str:
248
+ """Converts a snake_case name to a CamelCase name."""
249
+ return ''.join(x.capitalize() for x in name.split('_'))
250
+
251
+
252
+ def _base64_decode(data: str) -> bytes:
253
+ """Decodes a base64 string."""
254
+ return base64.b64decode(data.encode('utf-8'))