langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
@@ -44,13 +44,32 @@ except ImportError:
44
44
 
45
45
  @pg.use_init_args(['api_endpoint'])
46
46
  class VertexAI(rest.REST):
47
- """Base class for VertexAI models.
47
+ """Base class for models served on Vertex AI.
48
48
 
49
- This class handles the authentication of vertex AI models. Subclasses
50
- should implement `request` and `result` methods, as well as the `api_endpoint`
51
- property. Or let users to provide them as __init__ arguments.
49
+ This class handles authentication for Vertex AI models. Subclasses,
50
+ such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
51
+ provide specific implementations for different model families hosted
52
+ on Vertex AI.
52
53
 
53
- Please check out VertexAIGemini in `gemini.py` as an example.
54
+ **Quick Start:**
55
+
56
+ If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
57
+ that has service account credentials, authentication is handled automatically.
58
+ Otherwise, you might need to set up credentials:
59
+
60
+ ```bash
61
+ gcloud auth application-default login
62
+ ```
63
+
64
+ Then you can use a Vertex AI model:
65
+
66
+ ```python
67
+ import langfun as lf
68
+
69
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
70
+ r = lm('Who are you?')
71
+ print(r)
72
+ ```
54
73
  """
55
74
 
56
75
  model: pg.typing.Annotated[
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
158
177
  @pg.use_init_args(['model'])
159
178
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
160
179
  class VertexAIGemini(VertexAI, gemini.Gemini):
161
- """Gemini models served by Vertex AI.."""
180
+ """Gemini models served on Vertex AI.
181
+
182
+ **Quick Start:**
183
+
184
+ ```python
185
+ import langfun as lf
186
+
187
+ # Call Gemini 2.5 Flash on Vertex AI.
188
+ # If project and location are not specified, they will be read from
189
+ # environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
190
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
191
+ r = lm('Who are you?')
192
+ print(r)
193
+ ```
194
+ """
162
195
 
163
196
  # Set default location to us-central1.
164
197
  location = 'us-central1'
@@ -180,6 +213,33 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
180
213
  #
181
214
  # Production models.
182
215
  #
216
class VertexAIGemini3ProPreview(VertexAIGemini):  # pylint: disable=invalid-name
  """Gemini 3 Pro (preview) on Vertex AI, launched on 11/18/2025."""

  model = 'gemini-3-pro-preview'
  location = 'global'


class VertexAIGemini3ProImagePreview(VertexAIGemini):  # pylint: disable=invalid-name
  """Gemini 3 Pro Image (preview), aimed at high-fidelity image generation.

  Capabilities:
    - Text-to-image generation.
    - Image editing from multimodal input.
    - Visual reasoning.

  Requirements:
    - Served through the v1beta1 API endpoint.
    - `responseModalities` must include 'IMAGE' (set below via
      `response_modalities`).
    - Supported aspect ratios: 1:1, 16:9, 9:16, 4:3, 3:4.
    - Supported image sizes: 1K (default), 2K, 4K.
  """

  model = 'gemini-3-pro-image-preview'
  location = 'global'
  # Ask for both text and image parts in the response.
  response_modalities = ['TEXT', 'IMAGE']
241
+
242
+
183
243
  class VertexAIGemini25Pro(VertexAIGemini): # pylint: disable=invalid-name
184
244
  """Gemini 2.5 Pro GA model launched on 06/17/2025."""
185
245
 
@@ -369,6 +429,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
369
429
  # pylint: disable=invalid-name
370
430
 
371
431
 
432
class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
  """Claude 4.5 Haiku (20251001 release) from Anthropic, hosted on VertexAI."""

  model = 'claude-haiku-4-5@20251001'


class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
  """Claude 4.5 Sonnet (20250929 release) from Anthropic, hosted on VertexAI."""

  model = 'claude-sonnet-4-5@20250929'
440
+
441
+
372
442
  class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
373
443
  """Anthropic's Claude 4 Opus model on VertexAI."""
374
444
  model = 'claude-opus-4@20250514'
@@ -487,7 +557,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
487
557
 
488
558
  @pg.use_init_args(['model'])
489
559
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
490
- class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
560
+ class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
491
561
  """Llama models on VertexAI."""
492
562
 
493
563
  model: pg.typing.Annotated[
@@ -600,7 +670,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
600
670
 
601
671
  @pg.use_init_args(['model'])
602
672
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
603
- class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
673
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
604
674
  """Mistral AI models on VertexAI."""
605
675
 
606
676
  model: pg.typing.Annotated[
langfun/core/logging.py CHANGED
@@ -310,7 +310,7 @@ def warning(
310
310
  console: bool = False,
311
311
  **kwargs
312
312
  ) -> LogEntry:
313
- """Logs an info message to the session."""
313
+ """Logs a warning message to the session."""
314
314
  return log('warning', message, indent=indent, console=console, **kwargs)
315
315
 
316
316
 
@@ -0,0 +1,10 @@
1
+ """Langfun MCP support."""
2
+
3
+ # pylint: disable=g-importing-member
4
+
5
+ from langfun.core.mcp.client import McpClient
6
+ from langfun.core.mcp.session import McpSession
7
+ from langfun.core.mcp.tool import McpTool
8
+ from langfun.core.mcp.tool import McpToolInput
9
+
10
+ # pylint: enable=g-importing-member
@@ -0,0 +1,177 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MCP client."""
15
+
16
+ import abc
17
+ from typing import Annotated, Type
18
+
19
+ from langfun.core.mcp import session as mcp_session
20
+ from langfun.core.mcp import tool as mcp_tool
21
+ from mcp.server import fastmcp as fastmcp_lib
22
+ import pyglove as pg
23
+
24
+
25
class McpClient(pg.Object):
  """Interface for Model Context Protocol (MCP) client.

  An MCP client serves as a bridge to an MCP server, enabling users to interact
  with tools hosted on the server. It provides methods for listing available
  tools and creating sessions for tool interaction.

  There are three types of MCP clients:

  * **Stdio-based client**: Ideal for interacting with tools exposed as
    command-line executables through stdin/stdout.
    Created by `lf.mcp.McpClient.from_command`.
  * **HTTP-based client**: Designed for tools accessible via HTTP, using
    either the Streamable HTTP transport (URL path ending with `/mcp`) or
    Server-Sent Events (URL path ending with `/sse`).
    Created by `lf.mcp.McpClient.from_url`.
  * **In-memory client**: Useful for testing or embedding MCP servers
    within the same process.
    Created by `lf.mcp.McpClient.from_fastmcp`.

  **Example Usage:**

  ```python
  import langfun as lf

  # Example 1: Stdio-based client
  client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', '<ARG2>'])
  tools = client.list_tools()
  tool_cls = tools['<TOOL_NAME>']

  # Print the Python definition of the tool.
  print(tool_cls.python_definition())

  with client.session() as session:
    result = tool_cls(x=1, y=2)(session)
    print(result)

  # Example 2: HTTP-based client (async)
  async def main():
    client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
    tools = client.list_tools()
    tool_cls = tools['<TOOL_NAME>']

    # Print the Python definition of the tool.
    print(tool_cls.python_definition())

    async with client.session() as session:
      result = await tool_cls(x=1, y=2).acall(session)
      print(result)
  ```
  """

  def _on_bound(self):
    super()._on_bound()
    # Cached result of `list_tools`; `None` means "not fetched yet".
    self._tools = None

  def list_tools(
      self, refresh: bool = False
  ) -> dict[str, Type[mcp_tool.McpTool]]:
    """Lists all available tools on the MCP server.

    The first call (or any call with `refresh=True`) opens a new session to
    fetch the tool list; later calls return the cached dictionary.

    Args:
      refresh: If True, forces a refresh of the tool list from the server.
        Otherwise, a cached list may be returned.

    Returns:
      A dictionary mapping tool names to their corresponding `McpTool` classes.
    """
    if self._tools is None or refresh:
      # Listing tools needs a live connection, so a short-lived session is
      # opened just for this call.
      with self.session() as session:
        self._tools = session.list_tools()
    return self._tools

  @abc.abstractmethod
  def session(self) -> mcp_session.McpSession:
    """Creates a new session for interacting with MCP tools.

    Returns:
      An `McpSession` object.
    """

  @classmethod
  def from_command(
      cls, command: str, args: list[str] | None = None
  ) -> 'McpClient':
    """Creates an MCP client from a command-line executable.

    Args:
      command: The command to execute.
      args: An optional list of arguments to pass to the command. Defaults to
        no arguments.

    Returns:
      A `McpClient` instance that communicates via stdin/stdout.
    """
    return _StdioMcpClient(command=command, args=args or [])

  @classmethod
  def from_url(
      cls,
      url: str,
      headers: dict[str, str] | None = None
  ) -> 'McpClient':
    """Creates an MCP client from an HTTP URL.

    Args:
      url: The URL of the MCP server. The last path segment selects the
        transport (`mcp` for Streamable HTTP, `sse` for Server-Sent Events)
        when a session is created.
      headers: An optional dictionary of HTTP headers to include in requests.

    Returns:
      A `McpClient` instance that communicates via HTTP.
    """
    return _HttpMcpClient(url=url, headers=headers or {})

  @classmethod
  def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
    """Creates an MCP client from an in-memory FastMCP instance.

    Args:
      fastmcp: An instance of `fastmcp_lib.FastMCP`.

    Returns:
      A `McpClient` instance that communicates with the in-memory server.
    """
    return _InMemoryFastMcpClient(fastmcp=fastmcp)
146
+
147
+
148
class _StdioMcpClient(McpClient):
  """MCP client that reaches its server through a command's stdin/stdout."""

  command: Annotated[str, 'Command to execute.']
  args: Annotated[list[str], 'Arguments to pass to the command.']

  def session(self) -> mcp_session.McpSession:
    """Returns a fresh stdio session for `command` invoked with `args`."""
    session_cls = mcp_session.McpSession
    return session_cls.from_command(command=self.command, args=self.args)
157
+
158
+
159
class _HttpMcpClient(McpClient):
  """MCP client that reaches its server over HTTP."""

  url: Annotated[str, 'URL to connect to.']
  headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}

  def session(self) -> mcp_session.McpSession:
    """Returns a fresh HTTP session to `url` that sends `headers`."""
    session_cls = mcp_session.McpSession
    return session_cls.from_url(url=self.url, headers=self.headers)
168
+
169
+
170
class _InMemoryFastMcpClient(McpClient):
  """MCP client connected to a FastMCP server living in the same process."""

  fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']

  def session(self) -> mcp_session.McpSession:
    """Returns a fresh in-memory session bound to `fastmcp`."""
    server = self.fastmcp
    return mcp_session.McpSession.from_fastmcp(fastmcp=server)
@@ -0,0 +1,71 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for MCP client."""
15
+
16
+ import unittest
17
+ from langfun.core import async_support
18
+ from langfun.core import mcp as lf_mcp
19
+ from langfun.core import message as lf_message
20
+ from mcp.server import fastmcp as fastmcp_lib
21
+
22
# FastMCP server shared by the tests below. The tests connect to it in memory
# via `McpClient.from_fastmcp`, so the host/port settings are not exercised.
mcp = fastmcp_lib.FastMCP(port=1234, host='0.0.0.0')


@mcp.tool()
async def add(a: int, b: int) -> int:
  """Adds two integers and returns their sum.

  Args:
    a: The first integer.
    b: The second integer.

  Returns:
    The sum of the two integers.
  """
  total = a + b
  return total
37
+
38
+
39
class McpTest(unittest.TestCase):
  """End-to-end tests of `McpClient` against the in-memory server above."""

  def test_sync_usages(self):
    client = lf_mcp.McpClient.from_fastmcp(mcp)
    tools = client.list_tools()
    self.assertEqual(len(tools), 1)
    with client.session() as session:
      # Only `session.call_tool` is exercised here; `tool.__call__` itself is
      # covered by `tool_test.py`.
      outcome = session.call_tool(tools['add'](a=1, b=2))
      self.assertEqual(outcome, 3)

  def test_async_usages(self):

    async def _run():
      client = lf_mcp.McpClient.from_fastmcp(mcp)
      tools = client.list_tools()
      self.assertEqual(len(tools), 1)
      add_cls = tools['add']
      self.assertEqual(add_cls.__name__, 'Add')
      self.assertEqual(add_cls.TOOL_NAME, 'add')
      async with client.session() as session:
        # Only `session.acall_tool` is exercised here; `tool.acall` itself is
        # covered by `tool_test.py`.
        message = await session.acall_tool(
            add_cls(a=1, b=2), returns_message=True
        )
        self.assertEqual(message, lf_message.ToolMessage(text='3', result=3))

    async_support.invoke_sync(_run)


if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,241 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MCP session."""
15
+
16
+ import contextlib
17
+ from typing import Any, Type
18
+ import anyio
19
+ from langfun.core import async_support
20
+ from langfun.core.mcp import tool as mcp_tool
21
+ import mcp
22
+ from mcp.client import sse
23
+ from mcp.client import streamable_http
24
+ from mcp.server import fastmcp as fastmcp_lib
25
+ from mcp.shared import memory
26
+
27
+
28
class McpSession:
  """Represents a session for interacting with an MCP server.

  `McpSession` provides the context for making calls to tools hosted on an
  MCP server. It wraps the standard `mcp.ClientSession` to offer both
  synchronous and asynchronous usage patterns.

  Sessions are created using `lf.mcp.McpClient.session()` and should be used
  as context managers (either sync or async) to ensure proper initialization
  and teardown of the connection to the server. A session object can only be
  entered once.

  **Example Sync Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_command(...)
  with client.session() as session:
    tools = session.list_tools()
    result = tools['my_tool'](x=1)(session)
  ```

  **Example Async Usage:**

  ```python
  import langfun as lf

  client = lf.mcp.McpClient.from_url(...)
  async with client.session() as session:
    tools = await session.alist_tools()
    result = await tools['my_tool'](x=1).acall(session)
  ```
  """

  def __init__(self, stream) -> None:
    """Constructor.

    Args:
      stream: An async context manager which, when entered, yields a tuple of
        2 or 3 items whose first two are the read and write streams connected
        to the MCP server (e.g. the result of `mcp.stdio_client(...)`).
    """
    self._stream = stream
    self._session = None
    self._session_exit_stack = None
    self._in_session = False

    # For supporting sync context manager.
    self._sync_context_manager_exit_stack = None

  def __enter__(self) -> 'McpSession':
    exit_stack = contextlib.ExitStack()
    # Drives `__aenter__`/`__aexit__` from synchronous code.
    exit_stack.enter_context(async_support.sync_context_manager(self))
    self._sync_context_manager_exit_stack = exit_stack
    return self

  def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    assert self._sync_context_manager_exit_stack is not None
    exit_stack = self._sync_context_manager_exit_stack
    self._sync_context_manager_exit_stack = None
    exit_stack.close()

  async def __aenter__(self) -> 'McpSession':
    assert self._session_exit_stack is None, 'Session cannot be re-entered.'

    self._session_exit_stack = contextlib.AsyncExitStack()
    try:
      stream_output = await self._session_exit_stack.enter_async_context(
          self._stream
      )
      assert isinstance(stream_output, tuple) and len(stream_output) in (2, 3)
      read, write = stream_output[:2]
      self._session = mcp.ClientSession(read, write)
      await self._session_exit_stack.enter_async_context(self._session)
      await self._session.initialize()
    except BaseException:
      # `__aexit__` is not invoked when `__aenter__` raises, so release what
      # has been entered so far (the transport stream and/or the client
      # session) before propagating the error.
      self._session = None
      await self._session_exit_stack.aclose()
      raise
    return self

  async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    del exc_type, exc_val, exc_tb
    if self._session is None:
      return
    assert self._session_exit_stack is not None
    # Closes the client session, then the transport stream (LIFO order).
    await self._session_exit_stack.aclose()
    self._session = None

  def list_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
    """Lists all available tools on the MCP server synchronously.

    Returns:
      A dictionary mapping tool names to their corresponding `McpTool` classes.
    """
    return async_support.invoke_sync(self.alist_tools)

  async def alist_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
    """Lists all available tools on the MCP server asynchronously.

    Returns:
      A dictionary mapping tool names to their corresponding `McpTool` classes.
    """
    assert self._session is not None, 'MCP session is not entered.'
    return {
        t.name: mcp_tool.McpTool.make_class(t)
        for t in (await self._session.list_tools()).tools
    }

  def call_tool(
      self,
      tool: mcp_tool.McpTool,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls an MCP tool synchronously.

    Args:
      tool: The `McpTool` instance to call.
      returns_message: If True, the tool call returns an `lf.ToolMessage`
        carrying both the text and the structured result of the call;
        otherwise, it returns the tool's direct result.

    Returns:
      The result of the tool call.
    """
    return tool(self, returns_message=returns_message)

  async def acall_tool(
      self,
      tool: mcp_tool.McpTool,
      *,
      returns_message: bool = False
  ) -> Any:
    """Calls an MCP tool asynchronously.

    Args:
      tool: The `McpTool` instance to call.
      returns_message: If True, the tool call returns an `lf.ToolMessage`
        carrying both the text and the structured result of the call;
        otherwise, it returns the tool's direct result.

    Returns:
      The result of the tool call.
    """
    return await tool.acall(self, returns_message=returns_message)

  @classmethod
  def from_command(
      cls,
      command: str,
      args: list[str] | None = None
  ) -> 'McpSession':
    """Creates an MCP session from a command-line executable.

    Args:
      command: The command to execute.
      args: An optional list of arguments to pass to the command.

    Returns:
      An `McpSession` instance.
    """
    return cls(
        mcp.stdio_client(
            mcp.StdioServerParameters(command=command, args=args or [])
        )
    )

  @classmethod
  def from_url(
      cls,
      url: str,
      headers: dict[str, str] | None = None
  ) -> 'McpSession':
    """Creates an MCP session from an HTTP URL.

    The transport protocol is inferred from the last segment of the URL path:
    `.../mcp` selects Streamable HTTP and `.../sse` selects Server-Sent Events.
    Query strings, fragments and trailing slashes are ignored for this
    inference; the URL itself is passed to the transport unchanged.

    Args:
      url: The URL of the MCP server.
      headers: An optional dictionary of HTTP headers to include in requests.

    Returns:
      An `McpSession` instance.

    Raises:
      ValueError: If the transport cannot be inferred from the URL.
    """
    # Strip fragment and query so e.g. `.../mcp?token=...` is still recognized.
    path = url.split('#', 1)[0].split('?', 1)[0]
    transport = path.rstrip('/').rsplit('/', 1)[-1].lower()
    if transport == 'mcp':
      return cls(streamable_http.streamablehttp_client(url, headers or {}))
    elif transport == 'sse':
      return cls(sse.sse_client(url, headers or {}))
    else:
      raise ValueError(f'Unsupported transport: {transport}')

  @classmethod
  def from_fastmcp(
      cls,
      fastmcp: fastmcp_lib.FastMCP
  ) -> 'McpSession':
    """Creates an MCP session from an in-memory FastMCP instance.

    Args:
      fastmcp: An instance of `fastmcp_lib.FastMCP`.

    Returns:
      An `McpSession` instance.
    """
    return cls(_client_streams_from_fastmcp(fastmcp))
220
+
221
+
222
@contextlib.asynccontextmanager
async def _client_streams_from_fastmcp(fastmcp: fastmcp_lib.FastMCP):
  """Creates client streams connected to an in-memory FastMCP server.

  The FastMCP's underlying server runs as a background task for the lifetime
  of the context and is cancelled when the context exits.

  Args:
    fastmcp: The FastMCP instance to serve.

  Yields:
    A `(read_stream, write_stream)` tuple for the client side.
  """
  server = fastmcp._mcp_server  # pylint: disable=protected-access
  async with memory.create_client_server_memory_streams(
  ) as (client_streams, server_streams):
    client_read, client_write = client_streams
    server_read, server_write = server_streams

    # Create a cancel scope for the server task.
    async with anyio.create_task_group() as tg:
      tg.start_soon(
          lambda: server.run(
              server_read,
              server_write,
              server.create_initialization_options(),
              raise_exceptions=True,
          )
      )
      try:
        yield client_read, client_write
      finally:
        # The memory streams are owned by the enclosing context and are only
        # closed after this task group exits, so stop the server explicitly
        # instead of letting the task group wait on it.
        tg.cancel_scope.cancel()
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for MCP session."""
15
+
16
+ import unittest
17
+ from unittest import mock
18
+
19
+ from langfun.core.mcp import session as mcp_session
20
+ import mcp
21
+ from mcp.client import sse
22
+ from mcp.client import streamable_http
23
+
24
+
25
class McpSessionTest(unittest.TestCase):
  """Checks which transport each `McpSession` factory method selects."""

  @mock.patch.object(mcp, 'stdio_client', autospec=True)
  def test_from_command(self, mock_stdio_client):
    mcp_session.McpSession.from_command('my-command', ['--foo'])
    expected_params = mcp.StdioServerParameters(
        command='my-command', args=['--foo']
    )
    mock_stdio_client.assert_called_once_with(expected_params)

  @mock.patch.object(streamable_http, 'streamablehttp_client', autospec=True)
  def test_from_url_mcp(self, mock_streamablehttp_client):
    endpoint = 'http://localhost/mcp'
    mcp_session.McpSession.from_url(endpoint, headers={'k': 'v'})
    mock_streamablehttp_client.assert_called_once_with(endpoint, {'k': 'v'})

  @mock.patch.object(sse, 'sse_client', autospec=True)
  def test_from_url_sse(self, mock_sse_client):
    endpoint = 'http://localhost/sse'
    mcp_session.McpSession.from_url(endpoint, headers={'k': 'v'})
    mock_sse_client.assert_called_once_with(endpoint, {'k': 'v'})

  def test_from_url_unsupported(self):
    with self.assertRaisesRegex(ValueError, 'Unsupported transport: foo'):
      mcp_session.McpSession.from_url('http://localhost/foo')


if __name__ == '__main__':
  unittest.main()