langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

Files changed (155) hide show
  1. langfun/core/__init__.py +2 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +447 -29
  4. langfun/core/agentic/action_eval.py +9 -2
  5. langfun/core/agentic/action_test.py +149 -21
  6. langfun/core/async_support.py +32 -3
  7. langfun/core/coding/python/correction.py +19 -9
  8. langfun/core/coding/python/execution.py +14 -12
  9. langfun/core/coding/python/generation.py +21 -16
  10. langfun/core/coding/python/sandboxing.py +23 -3
  11. langfun/core/component.py +42 -3
  12. langfun/core/concurrent.py +70 -6
  13. langfun/core/concurrent_test.py +1 -0
  14. langfun/core/console.py +1 -1
  15. langfun/core/data/conversion/anthropic.py +12 -3
  16. langfun/core/data/conversion/anthropic_test.py +8 -6
  17. langfun/core/data/conversion/gemini.py +9 -2
  18. langfun/core/data/conversion/gemini_test.py +12 -9
  19. langfun/core/data/conversion/openai.py +145 -31
  20. langfun/core/data/conversion/openai_test.py +161 -17
  21. langfun/core/eval/base.py +47 -43
  22. langfun/core/eval/base_test.py +5 -5
  23. langfun/core/eval/matching.py +5 -2
  24. langfun/core/eval/patching.py +3 -3
  25. langfun/core/eval/scoring.py +4 -3
  26. langfun/core/eval/v2/__init__.py +1 -0
  27. langfun/core/eval/v2/checkpointing.py +64 -6
  28. langfun/core/eval/v2/checkpointing_test.py +9 -2
  29. langfun/core/eval/v2/eval_test_helper.py +103 -2
  30. langfun/core/eval/v2/evaluation.py +91 -16
  31. langfun/core/eval/v2/evaluation_test.py +9 -3
  32. langfun/core/eval/v2/example.py +50 -40
  33. langfun/core/eval/v2/example_test.py +16 -8
  34. langfun/core/eval/v2/experiment.py +74 -8
  35. langfun/core/eval/v2/experiment_test.py +19 -0
  36. langfun/core/eval/v2/metric_values.py +31 -3
  37. langfun/core/eval/v2/metric_values_test.py +32 -0
  38. langfun/core/eval/v2/metrics.py +157 -44
  39. langfun/core/eval/v2/metrics_test.py +39 -18
  40. langfun/core/eval/v2/progress.py +30 -1
  41. langfun/core/eval/v2/progress_test.py +27 -0
  42. langfun/core/eval/v2/progress_tracking.py +12 -3
  43. langfun/core/eval/v2/progress_tracking_test.py +6 -1
  44. langfun/core/eval/v2/reporting.py +90 -71
  45. langfun/core/eval/v2/reporting_test.py +24 -6
  46. langfun/core/eval/v2/runners/__init__.py +30 -0
  47. langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
  48. langfun/core/eval/v2/runners/beam.py +341 -0
  49. langfun/core/eval/v2/runners/beam_test.py +131 -0
  50. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  51. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  52. langfun/core/eval/v2/runners/debug.py +40 -0
  53. langfun/core/eval/v2/runners/debug_test.py +76 -0
  54. langfun/core/eval/v2/runners/parallel.py +100 -0
  55. langfun/core/eval/v2/runners/parallel_test.py +95 -0
  56. langfun/core/eval/v2/runners/sequential.py +47 -0
  57. langfun/core/eval/v2/runners/sequential_test.py +172 -0
  58. langfun/core/langfunc.py +45 -130
  59. langfun/core/langfunc_test.py +7 -5
  60. langfun/core/language_model.py +141 -21
  61. langfun/core/language_model_test.py +54 -3
  62. langfun/core/llms/__init__.py +9 -1
  63. langfun/core/llms/anthropic.py +157 -2
  64. langfun/core/llms/azure_openai.py +29 -17
  65. langfun/core/llms/cache/base.py +25 -3
  66. langfun/core/llms/cache/in_memory.py +48 -7
  67. langfun/core/llms/cache/in_memory_test.py +14 -4
  68. langfun/core/llms/compositional.py +25 -1
  69. langfun/core/llms/deepseek.py +30 -2
  70. langfun/core/llms/fake.py +32 -1
  71. langfun/core/llms/gemini.py +55 -17
  72. langfun/core/llms/gemini_test.py +84 -0
  73. langfun/core/llms/google_genai.py +34 -1
  74. langfun/core/llms/groq.py +28 -3
  75. langfun/core/llms/llama_cpp.py +23 -4
  76. langfun/core/llms/openai.py +36 -3
  77. langfun/core/llms/openai_compatible.py +148 -27
  78. langfun/core/llms/openai_compatible_test.py +207 -20
  79. langfun/core/llms/openai_test.py +0 -2
  80. langfun/core/llms/rest.py +12 -1
  81. langfun/core/llms/vertexai.py +58 -8
  82. langfun/core/logging.py +1 -1
  83. langfun/core/mcp/client.py +77 -22
  84. langfun/core/mcp/client_test.py +8 -35
  85. langfun/core/mcp/session.py +94 -29
  86. langfun/core/mcp/session_test.py +54 -0
  87. langfun/core/mcp/tool.py +151 -22
  88. langfun/core/mcp/tool_test.py +197 -0
  89. langfun/core/memory.py +1 -0
  90. langfun/core/message.py +160 -55
  91. langfun/core/message_test.py +65 -81
  92. langfun/core/modalities/__init__.py +8 -0
  93. langfun/core/modalities/audio.py +21 -1
  94. langfun/core/modalities/image.py +19 -1
  95. langfun/core/modalities/mime.py +64 -3
  96. langfun/core/modalities/mime_test.py +11 -0
  97. langfun/core/modalities/pdf.py +19 -1
  98. langfun/core/modalities/video.py +21 -1
  99. langfun/core/modality.py +167 -29
  100. langfun/core/modality_test.py +42 -12
  101. langfun/core/natural_language.py +1 -1
  102. langfun/core/sampling.py +4 -4
  103. langfun/core/sampling_test.py +20 -4
  104. langfun/core/structured/__init__.py +2 -24
  105. langfun/core/structured/completion.py +34 -44
  106. langfun/core/structured/completion_test.py +23 -43
  107. langfun/core/structured/description.py +54 -50
  108. langfun/core/structured/function_generation.py +29 -12
  109. langfun/core/structured/mapping.py +81 -37
  110. langfun/core/structured/parsing.py +95 -79
  111. langfun/core/structured/parsing_test.py +0 -3
  112. langfun/core/structured/querying.py +215 -142
  113. langfun/core/structured/querying_test.py +65 -29
  114. langfun/core/structured/schema/__init__.py +49 -0
  115. langfun/core/structured/schema/base.py +664 -0
  116. langfun/core/structured/schema/base_test.py +531 -0
  117. langfun/core/structured/schema/json.py +174 -0
  118. langfun/core/structured/schema/json_test.py +121 -0
  119. langfun/core/structured/schema/python.py +316 -0
  120. langfun/core/structured/schema/python_test.py +410 -0
  121. langfun/core/structured/schema_generation.py +33 -14
  122. langfun/core/structured/scoring.py +47 -36
  123. langfun/core/structured/tokenization.py +26 -11
  124. langfun/core/subscription.py +2 -2
  125. langfun/core/template.py +174 -49
  126. langfun/core/template_test.py +123 -17
  127. langfun/env/__init__.py +8 -2
  128. langfun/env/base_environment.py +320 -128
  129. langfun/env/base_environment_test.py +473 -0
  130. langfun/env/base_feature.py +92 -15
  131. langfun/env/base_feature_test.py +228 -0
  132. langfun/env/base_sandbox.py +84 -361
  133. langfun/env/base_sandbox_test.py +1235 -0
  134. langfun/env/event_handlers/__init__.py +1 -1
  135. langfun/env/event_handlers/chain.py +233 -0
  136. langfun/env/event_handlers/chain_test.py +253 -0
  137. langfun/env/event_handlers/event_logger.py +95 -98
  138. langfun/env/event_handlers/event_logger_test.py +21 -21
  139. langfun/env/event_handlers/metric_writer.py +225 -140
  140. langfun/env/event_handlers/metric_writer_test.py +23 -6
  141. langfun/env/interface.py +854 -40
  142. langfun/env/interface_test.py +112 -2
  143. langfun/env/load_balancers_test.py +23 -2
  144. langfun/env/test_utils.py +126 -84
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  146. langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
  147. langfun/core/eval/v2/runners_test.py +0 -343
  148. langfun/core/structured/schema.py +0 -987
  149. langfun/core/structured/schema_test.py +0 -982
  150. langfun/env/base_test.py +0 -1481
  151. langfun/env/event_handlers/base.py +0 -350
  152. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  153. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  154. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  155. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
@@ -44,13 +44,32 @@ except ImportError:
44
44
 
45
45
  @pg.use_init_args(['api_endpoint'])
46
46
  class VertexAI(rest.REST):
47
- """Base class for VertexAI models.
47
+ """Base class for models served on Vertex AI.
48
48
 
49
- This class handles the authentication of vertex AI models. Subclasses
50
- should implement `request` and `result` methods, as well as the `api_endpoint`
51
- property. Or let users to provide them as __init__ arguments.
49
+ This class handles authentication for Vertex AI models. Subclasses,
50
+ such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
51
+ provide specific implementations for different model families hosted
52
+ on Vertex AI.
52
53
 
53
- Please check out VertexAIGemini in `gemini.py` as an example.
54
+ **Quick Start:**
55
+
56
+ If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
57
+ that has service account credentials, authentication is handled automatically.
58
+ Otherwise, you might need to set up credentials:
59
+
60
+ ```bash
61
+ gcloud auth application-default login
62
+ ```
63
+
64
+ Then you can use a Vertex AI model:
65
+
66
+ ```python
67
+ import langfun as lf
68
+
69
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
70
+ r = lm('Who are you?')
71
+ print(r)
72
+ ```
54
73
  """
55
74
 
56
75
  model: pg.typing.Annotated[
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
158
177
  @pg.use_init_args(['model'])
159
178
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
160
179
  class VertexAIGemini(VertexAI, gemini.Gemini):
161
- """Gemini models served by Vertex AI.."""
180
+ """Gemini models served on Vertex AI.
181
+
182
+ **Quick Start:**
183
+
184
+ ```python
185
+ import langfun as lf
186
+
187
+ # Call Gemini 2.5 Flash on Vertex AI.
188
+ # If project and location are not specified, they will be read from
189
+ # environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
190
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
191
+ r = lm('Who are you?')
192
+ print(r)
193
+ ```
194
+ """
162
195
 
163
196
  # Set default location to us-central1.
164
197
  location = 'us-central1'
@@ -180,6 +213,13 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
180
213
  #
181
214
  # Production models.
182
215
  #
216
+ class VertexAIGemini3ProPreview(VertexAIGemini): # pylint: disable=invalid-name
217
+ """Gemini 3 Pro Preview model launched on 11/18/2025."""
218
+
219
+ model = 'gemini-3-pro-preview'
220
+ location = 'global'
221
+
222
+
183
223
  class VertexAIGemini25Pro(VertexAIGemini): # pylint: disable=invalid-name
184
224
  """Gemini 2.5 Pro GA model launched on 06/17/2025."""
185
225
 
@@ -369,6 +409,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
369
409
  # pylint: disable=invalid-name
370
410
 
371
411
 
412
+ class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
413
+ """Anthropic's Claude 4.5 Haiku model on VertexAI."""
414
+ model = 'claude-haiku-4-5@20251001'
415
+
416
+
417
+ class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
418
+ """Anthropic's Claude 4.5 Sonnet model on VertexAI."""
419
+ model = 'claude-sonnet-4-5@20250929'
420
+
421
+
372
422
  class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
373
423
  """Anthropic's Claude 4 Opus model on VertexAI."""
374
424
  model = 'claude-opus-4@20250514'
@@ -487,7 +537,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
487
537
 
488
538
  @pg.use_init_args(['model'])
489
539
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
490
- class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
540
+ class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
491
541
  """Llama models on VertexAI."""
492
542
 
493
543
  model: pg.typing.Annotated[
@@ -600,7 +650,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
600
650
 
601
651
  @pg.use_init_args(['model'])
602
652
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
603
- class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
653
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
604
654
  """Mistral AI models on VertexAI."""
605
655
 
606
656
  model: pg.typing.Annotated[
langfun/core/logging.py CHANGED
@@ -310,7 +310,7 @@ def warning(
310
310
  console: bool = False,
311
311
  **kwargs
312
312
  ) -> LogEntry:
313
- """Logs an info message to the session."""
313
+ """Logs a warning message to the session."""
314
314
  return log('warning', message, indent=indent, console=console, **kwargs)
315
315
 
316
316
 
@@ -23,33 +23,53 @@ import pyglove as pg
23
23
 
24
24
 
25
25
  class McpClient(pg.Object):
26
- """Base class for MCP client.
26
+ """Interface for Model Context Protocol (MCP) client.
27
27
 
28
- Usage:
28
+ An MCP client serves as a bridge to an MCP server, enabling users to interact
29
+ with tools hosted on the server. It provides methods for listing available
30
+ tools and creating sessions for tool interaction.
31
+
32
+ There are three types of MCP clients:
33
+
34
+ * **Stdio-based client**: Ideal for interacting with tools exposed as
35
+ command-line executables through stdin/stdout.
36
+ Created by `lf.mcp.McpClient.from_command`.
37
+ * **HTTP-based client**: Designed for tools accessible via HTTP,
38
+ supporting Server-Sent Events (SSE) for streaming.
39
+ Created by `lf.mcp.McpClient.from_url`.
40
+ * **In-memory client**: Useful for testing or embedding MCP servers
41
+ within the same process.
42
+ Created by `lf.mcp.McpClient.from_fastmcp`.
43
+
44
+ **Example Usage:**
29
45
 
30
46
  ```python
47
+ import langfun as lf
31
48
 
32
- def tool_use():
33
- client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
34
- tools = client.list_tools()
35
- tool_cls = tools['<TOOL_NAME>']
49
+ # Example 1: Stdio-based client
50
+ client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
51
+ tools = client.list_tools()
52
+ tool_cls = tools['<TOOL_NAME>']
36
53
 
37
- # Print the python definition of the tool.
38
- print(tool_cls.python_definition())
54
+ # Print the Python definition of the tool.
55
+ print(tool_cls.python_definition())
39
56
 
40
- with client.session() as session:
41
- return tool_cls(x=1, y=2)(session)
57
+ with client.session() as session:
58
+ result = tool_cls(x=1, y=2)(session)
59
+ print(result)
42
60
 
43
- async def tool_use_async_version():
61
+ # Example 2: HTTP-based client (async)
62
+ async def main():
44
63
  client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
45
64
  tools = client.list_tools()
46
65
  tool_cls = tools['<TOOL_NAME>']
47
66
 
48
- # Print the python definition of the tool.
67
+ # Print the Python definition of the tool.
49
68
  print(tool_cls.python_definition())
50
69
 
51
70
  async with client.session() as session:
52
- return await tool_cls(x=1, y=2).acall(session)
71
+ result = await tool_cls(x=1, y=2).acall(session)
72
+ print(result)
53
73
  ```
54
74
  """
55
75
 
@@ -60,7 +80,15 @@ class McpClient(pg.Object):
60
80
  def list_tools(
61
81
  self, refresh: bool = False
62
82
  ) -> dict[str, Type[mcp_tool.McpTool]]:
63
- """Lists all MCP tools."""
83
+ """Lists all available tools on the MCP server.
84
+
85
+ Args:
86
+ refresh: If True, forces a refresh of the tool list from the server.
87
+ Otherwise, a cached list may be returned.
88
+
89
+ Returns:
90
+ A dictionary mapping tool names to their corresponding `McpTool` classes.
91
+ """
64
92
  if self._tools is None or refresh:
65
93
  with self.session() as session:
66
94
  self._tools = session.list_tools()
@@ -68,11 +96,23 @@ class McpClient(pg.Object):
68
96
 
69
97
  @abc.abstractmethod
70
98
  def session(self) -> mcp_session.McpSession:
71
- """Creates a MCP session."""
99
+ """Creates a new session for interacting with MCP tools.
100
+
101
+ Returns:
102
+ An `McpSession` object.
103
+ """
72
104
 
73
105
  @classmethod
74
106
  def from_command(cls, command: str, args: list[str]) -> 'McpClient':
75
- """Creates a MCP client from a tool."""
107
+ """Creates an MCP client from a command-line executable.
108
+
109
+ Args:
110
+ command: The command to execute.
111
+ args: A list of arguments to pass to the command.
112
+
113
+ Returns:
114
+ An `McpClient` instance that communicates via stdin/stdout.
115
+ """
76
116
  return _StdioMcpClient(command=command, args=args)
77
117
 
78
118
  @classmethod
@@ -81,12 +121,27 @@ class McpClient(pg.Object):
81
121
  url: str,
82
122
  headers: dict[str, str] | None = None
83
123
  ) -> 'McpClient':
84
- """Creates a MCP client from a URL."""
124
+ """Creates an MCP client from an HTTP URL.
125
+
126
+ Args:
127
+ url: The URL of the MCP server.
128
+ headers: An optional dictionary of HTTP headers to include in requests.
129
+
130
+ Returns:
131
+ An `McpClient` instance that communicates via HTTP.
132
+ """
85
133
  return _HttpMcpClient(url=url, headers=headers or {})
86
134
 
87
135
  @classmethod
88
136
  def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
89
- """Creates a MCP client from a MCP server."""
137
+ """Creates an MCP client from an in-memory FastMCP instance.
138
+
139
+ Args:
140
+ fastmcp: An instance of `fastmcp_lib.FastMCP`.
141
+
142
+ Returns:
143
+ An `McpClient` instance that communicates with the in-memory server.
144
+ """
90
145
  return _InMemoryFastMcpClient(fastmcp=fastmcp)
91
146
 
92
147
 
@@ -97,18 +152,18 @@ class _StdioMcpClient(McpClient):
97
152
  args: Annotated[list[str], 'Arguments to pass to the command.']
98
153
 
99
154
  def session(self) -> mcp_session.McpSession:
100
- """Creates a MCP session."""
155
+ """Creates an McpSession from command."""
101
156
  return mcp_session.McpSession.from_command(self.command, self.args)
102
157
 
103
158
 
104
159
  class _HttpMcpClient(McpClient):
105
- """Server-Sent Events (SSE)/Streamable HTTP-based MCP client."""
160
+ """HTTP-based MCP client."""
106
161
 
107
162
  url: Annotated[str, 'URL to connect to.']
108
163
  headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}
109
164
 
110
165
  def session(self) -> mcp_session.McpSession:
111
- """Creates a MCP session."""
166
+ """Creates an McpSession from URL."""
112
167
  return mcp_session.McpSession.from_url(self.url, self.headers)
113
168
 
114
169
 
@@ -118,5 +173,5 @@ class _InMemoryFastMcpClient(McpClient):
118
173
  fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']
119
174
 
120
175
  def session(self) -> mcp_session.McpSession:
121
- """Creates a MCP session."""
176
+ """Creates an McpSession from an in-memory FastMCP instance."""
122
177
  return mcp_session.McpSession.from_fastmcp(self.fastmcp)
@@ -13,10 +13,10 @@
13
13
  # limitations under the License.
14
14
  """Tests for MCP client."""
15
15
 
16
- import inspect
17
16
  import unittest
18
17
  from langfun.core import async_support
19
18
  from langfun.core import mcp as lf_mcp
19
+ from langfun.core import message as lf_message
20
20
  from mcp.server import fastmcp as fastmcp_lib
21
21
 
22
22
  mcp = fastmcp_lib.FastMCP(host='0.0.0.0', port=1234)
@@ -42,38 +42,11 @@ class McpTest(unittest.TestCase):
42
42
  client = lf_mcp.McpClient.from_fastmcp(mcp)
43
43
  tools = client.list_tools()
44
44
  self.assertEqual(len(tools), 1)
45
- tool_cls = tools['add']
46
- print(tool_cls.python_definition())
47
- self.assertEqual(
48
- tool_cls.python_definition(),
49
- inspect.cleandoc(
50
- '''
51
- Add
52
-
53
- ```python
54
- class Add:
55
- """Adds two integers and returns their sum.
56
-
57
- Args:
58
- a: The first integer.
59
- b: The second integer.
60
-
61
- Returns:
62
- The sum of the two integers.
63
- """
64
- a: int
65
- b: int
66
- ```
67
- '''
68
- )
69
- )
70
- self.assertEqual(repr(tool_cls), '<tool-class \'Add\'>')
71
- self.assertEqual(tool_cls.__name__, 'Add')
72
- self.assertEqual(tool_cls.TOOL_NAME, 'add')
73
- self.assertEqual(tool_cls(a=1, b=2).input_parameters(), {'a': 1, 'b': 2})
74
45
  with client.session() as session:
75
46
  self.assertEqual(
76
- tool_cls(a=1, b=2)(session), 3
47
+ # Test `session.call_tool` method as `tool.__call__` is already tested
48
+ # in `tool_test.py`.
49
+ session.call_tool(tools['add'](a=1, b=2)), 3
77
50
  )
78
51
 
79
52
  def test_async_usages(self):
@@ -86,10 +59,10 @@ class McpTest(unittest.TestCase):
86
59
  self.assertEqual(tool_cls.TOOL_NAME, 'add')
87
60
  async with client.session() as session:
88
61
  self.assertEqual(
89
- (await tool_cls(a=1, b=2).acall(
90
- session, returns_call_result=True))
91
- .structuredContent['result'],
92
- 3
62
+ # Test `session.acall_tool` method as `tool.acall` is already
63
+ # tested in `tool_test.py`.
64
+ await session.acall_tool(tool_cls(a=1, b=2), returns_message=True),
65
+ lf_message.ToolMessage(text='3', result=3)
93
66
  )
94
67
  async_support.invoke_sync(_test)
95
68
 
@@ -26,10 +26,37 @@ from mcp.shared import memory
26
26
 
27
27
 
28
28
  class McpSession:
29
- """Langfun's MCP session.
29
+ """Represents a session for interacting with an MCP server.
30
30
 
31
- Compared to the standard mcp.ClientSession, Langfun's MCP session could be
32
- used both synchronously and asynchronously.
31
+ `McpSession` provides the context for making calls to tools hosted on an
32
+ MCP server. It wraps the standard `mcp.ClientSession` to offer both
33
+ synchronous and asynchronous usage patterns.
34
+
35
+ Sessions are created using `lf.mcp.McpClient.session()` and should be used
36
+ as context managers (either sync or async) to ensure proper initialization
37
+ and teardown of the connection to the server.
38
+
39
+ **Example Sync Usage:**
40
+
41
+ ```python
42
+ import langfun as lf
43
+
44
+ client = lf.mcp.McpClient.from_command(...)
45
+ with client.session() as session:
46
+ tools = session.list_tools()
47
+ result = tools['my_tool'](x=1)(session)
48
+ ```
49
+
50
+ **Example Async Usage:**
51
+
52
+ ```python
53
+ import langfun as lf
54
+
55
+ client = lf.mcp.McpClient.from_url(...)
56
+ async with client.session() as session:
57
+ tools = await session.alist_tools()
58
+ result = await tools['my_tool'](x=1).acall(session)
59
+ ```
33
60
  """
34
61
 
35
62
  def __init__(self, stream) -> None:
@@ -74,11 +101,19 @@ class McpSession:
74
101
  self._session = None
75
102
 
76
103
  def list_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
77
- """Lists all MCP tools synchronously."""
104
+ """Lists all available tools on the MCP server synchronously.
105
+
106
+ Returns:
107
+ A dictionary mapping tool names to their corresponding `McpTool` classes.
108
+ """
78
109
  return async_support.invoke_sync(self.alist_tools)
79
110
 
80
111
  async def alist_tools(self) -> dict[str, Type[mcp_tool.McpTool]]:
81
- """Lists all MCP tools asynchronously."""
112
+ """Lists all available tools on the MCP server asynchronously.
113
+
114
+ Returns:
115
+ A dictionary mapping tool names to their corresponding `McpTool` classes.
116
+ """
82
117
  assert self._session is not None, 'MCP session is not entered.'
83
118
  return {
84
119
  t.name: mcp_tool.McpTool.make_class(t)
@@ -89,34 +124,37 @@ class McpSession:
89
124
  self,
90
125
  tool: mcp_tool.McpTool,
91
126
  *,
92
- returns_call_result: bool = False
127
+ returns_message: bool = False
93
128
  ) -> Any:
94
- """Calls a MCP tool synchronously."""
95
- return async_support.invoke_sync(
96
- self.acall_tool,
97
- tool,
98
- returns_call_result=returns_call_result
99
- )
129
+ """Calls an MCP tool synchronously.
130
+
131
+ Args:
132
+ tool: The `McpTool` instance to call.
133
+ returns_message: If True, the tool call will return a Langfun `ToolMessage`
134
+ object; otherwise, it returns the tool's direct result.
135
+
136
+ Returns:
137
+ The result of the tool call.
138
+ """
139
+ return tool(self, returns_message=returns_message)
100
140
 
101
141
  async def acall_tool(
102
142
  self,
103
143
  tool: mcp_tool.McpTool,
104
144
  *,
105
- returns_call_result: bool = False
145
+ returns_message: bool = False
106
146
  ) -> Any:
107
- """Calls a MCP tool asynchronously."""
108
- assert self._session is not None, 'MCP session is not entered.'
109
- tool_call_result = await self._session.call_tool(
110
- tool.TOOL_NAME, tool.input_parameters()
111
- )
112
- if returns_call_result:
113
- return tool_call_result
114
- if (
115
- tool_call_result.structuredContent
116
- and 'result' in tool_call_result.structuredContent
117
- ):
118
- return tool_call_result.structuredContent['result']
119
- return tool_call_result.content
147
+ """Calls an MCP tool asynchronously.
148
+
149
+ Args:
150
+ tool: The `McpTool` instance to call.
151
+ returns_message: If True, the tool call will return a Langfun `ToolMessage`
152
+ object; otherwise, it returns the tool's direct result.
153
+
154
+ Returns:
155
+ The result of the tool call.
156
+ """
157
+ return await tool.acall(self, returns_message=returns_message)
120
158
 
121
159
  @classmethod
122
160
  def from_command(
@@ -124,7 +162,15 @@ class McpSession:
124
162
  command: str,
125
163
  args: list[str] | None = None
126
164
  ) -> 'McpSession':
127
- """Creates a MCP session from a command."""
165
+ """Creates an MCP session from a command-line executable.
166
+
167
+ Args:
168
+ command: The command to execute.
169
+ args: An optional list of arguments to pass to the command.
170
+
171
+ Returns:
172
+ An `McpSession` instance.
173
+ """
128
174
  return cls(
129
175
  mcp.stdio_client(
130
176
  mcp.StdioServerParameters(command=command, args=args or [])
@@ -137,7 +183,18 @@ class McpSession:
137
183
  url: str,
138
184
  headers: dict[str, str] | None = None
139
185
  ) -> 'McpSession':
140
- """Creates a MCP session from a URL."""
186
+ """Creates an MCP session from an HTTP URL.
187
+
188
+ The transport protocol (e.g., 'mcp' or 'sse') is inferred from the
189
+ last part of the URL path.
190
+
191
+ Args:
192
+ url: The URL of the MCP server.
193
+ headers: An optional dictionary of HTTP headers to include in requests.
194
+
195
+ Returns:
196
+ An `McpSession` instance.
197
+ """
141
198
  transport = url.removesuffix('/').split('/')[-1].lower()
142
199
  if transport == 'mcp':
143
200
  return cls(streamable_http.streamablehttp_client(url, headers or {}))
@@ -151,12 +208,20 @@ class McpSession:
151
208
  cls,
152
209
  fastmcp: fastmcp_lib.FastMCP
153
210
  ):
211
+ """Creates an MCP session from an in-memory FastMCP instance.
212
+
213
+ Args:
214
+ fastmcp: An instance of `fastmcp_lib.FastMCP`.
215
+
216
+ Returns:
217
+ An `McpSession` instance.
218
+ """
154
219
  return cls(_client_streams_from_fastmcp(fastmcp))
155
220
 
156
221
 
157
222
  @contextlib.asynccontextmanager
158
223
  async def _client_streams_from_fastmcp(fastmcp: fastmcp_lib.FastMCP):
159
- """Creates client streams from a MCP server."""
224
+ """Creates client streams from an in-memory FastMCP instance."""
160
225
  server = fastmcp._mcp_server # pylint: disable=protected-access
161
226
  async with memory.create_client_server_memory_streams(
162
227
  ) as (client_streams, server_streams):
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for MCP session."""
15
+
16
+ import unittest
17
+ from unittest import mock
18
+
19
+ from langfun.core.mcp import session as mcp_session
20
+ import mcp
21
+ from mcp.client import sse
22
+ from mcp.client import streamable_http
23
+
24
+
25
+ class McpSessionTest(unittest.TestCase):
26
+
27
+ @mock.patch.object(mcp, 'stdio_client', autospec=True)
28
+ def test_from_command(self, mock_stdio_client):
29
+ mcp_session.McpSession.from_command('my-command', ['--foo'])
30
+ mock_stdio_client.assert_called_once_with(
31
+ mcp.StdioServerParameters(command='my-command', args=['--foo'])
32
+ )
33
+
34
+ @mock.patch.object(streamable_http, 'streamablehttp_client', autospec=True)
35
+ def test_from_url_mcp(self, mock_streamablehttp_client):
36
+ mcp_session.McpSession.from_url(
37
+ 'http://localhost/mcp', headers={'k': 'v'}
38
+ )
39
+ mock_streamablehttp_client.assert_called_once_with(
40
+ 'http://localhost/mcp', {'k': 'v'}
41
+ )
42
+
43
+ @mock.patch.object(sse, 'sse_client', autospec=True)
44
+ def test_from_url_sse(self, mock_sse_client):
45
+ mcp_session.McpSession.from_url('http://localhost/sse', headers={'k': 'v'})
46
+ mock_sse_client.assert_called_once_with('http://localhost/sse', {'k': 'v'})
47
+
48
+ def test_from_url_unsupported(self):
49
+ with self.assertRaisesRegex(ValueError, 'Unsupported transport: foo'):
50
+ mcp_session.McpSession.from_url('http://localhost/foo')
51
+
52
+
53
+ if __name__ == '__main__':
54
+ unittest.main()