langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_compatible_test.py CHANGED
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
38
38
  response_format = ''
39
39
 
40
40
  choices = []
41
- for k in range(json['n']):
41
+ for k in range(json.get('n', 1)):
42
42
  if json.get('logprobs'):
43
43
  logprobs = dict(
44
44
  content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
89
89
  c['image_url']['url']
90
90
  for c in json['messages'][0]['content'] if c['type'] == 'image_url'
91
91
  ]
92
- for k in range(json['n']):
92
+ for k in range(json.get('n', 1)):
93
93
  choices.append(pg.Dict(
94
94
  message=pg.Dict(
95
95
  content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
111
111
  return response
112
112
 
113
113
 
114
- class OpenAIComptibleTest(unittest.TestCase):
114
+ def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
115
+ del url, kwargs
116
+ _ = json['input']
117
+
118
+ system_message = ''
119
+ if 'instructions' in json:
120
+ system_message = f' system={json["instructions"]}'
121
+
122
+ response_format = ''
123
+ if 'text' in json and 'format' in json['text']:
124
+ response_format = f' format={json["text"]["format"]["type"]}'
125
+
126
+ output = [
127
+ dict(
128
+ type='message',
129
+ content=[
130
+ dict(
131
+ type='output_text',
132
+ text=(
133
+ f'Sample 0 for message.{system_message}{response_format}'
134
+ )
135
+ )
136
+ ],
137
+ )
138
+ ]
139
+
140
+ response = requests.Response()
141
+ response.status_code = 200
142
+ response._content = pg.to_json_str(
143
+ dict(
144
+ output=output,
145
+ usage=dict(
146
+ input_tokens=100,
147
+ output_tokens=100,
148
+ total_tokens=200,
149
+ ),
150
+ )
151
+ ).encode()
152
+ return response
153
+
154
+
155
+ def mock_responses_request_vision(
156
+ url: str, json: dict[str, Any], **kwargs
157
+ ):
158
+ del url, kwargs
159
+ urls = [
160
+ c['image_url']
161
+ for c in json['input'][0]['content']
162
+ if c['type'] == 'input_image'
163
+ ]
164
+ output = [
165
+ pg.Dict(
166
+ type='message',
167
+ content=[
168
+ pg.Dict(
169
+ type='output_text',
170
+ text=f'Sample 0 for message: {"".join(urls)}',
171
+ )
172
+ ],
173
+ )
174
+ ]
175
+ response = requests.Response()
176
+ response.status_code = 200
177
+ response._content = pg.to_json_str(
178
+ dict(
179
+ output=output,
180
+ usage=dict(
181
+ input_tokens=100,
182
+ output_tokens=100,
183
+ total_tokens=200,
184
+ ),
185
+ )
186
+ ).encode()
187
+ return response
188
+
189
+
190
+ class OpenAIChatCompletionAPITest(unittest.TestCase):
115
191
  """Tests for OpenAI compatible language model."""
116
192
 
117
193
  def test_request_args(self):
118
194
  self.assertEqual(
119
- openai_compatible.OpenAICompatible(
195
+ openai_compatible.OpenAIChatCompletionAPI(
120
196
  api_endpoint='https://test-server',
121
197
  model='test-model'
122
198
  )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
126
202
  ),
127
203
  dict(
128
204
  model='test-model',
129
- top_logprobs=None,
130
- n=1,
131
205
  temperature=1.0,
132
206
  stop=['\n'],
133
207
  seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
137
211
  def test_call_chat_completion(self):
138
212
  with mock.patch('requests.Session.post') as mock_request:
139
213
  mock_request.side_effect = mock_chat_completion_request
140
- lm = openai_compatible.OpenAICompatible(
214
+ lm = openai_compatible.OpenAIChatCompletionAPI(
141
215
  api_endpoint='https://test-server', model='test-model',
142
216
  )
143
217
  self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
148
222
  def test_call_chat_completion_with_logprobs(self):
149
223
  with mock.patch('requests.Session.post') as mock_request:
150
224
  mock_request.side_effect = mock_chat_completion_request
151
- lm = openai_compatible.OpenAICompatible(
225
+ lm = openai_compatible.OpenAIChatCompletionAPI(
152
226
  api_endpoint='https://test-server', model='test-model',
153
227
  )
154
228
  results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
214
288
  def mime_type(self) -> str:
215
289
  return 'image/png'
216
290
 
291
+ image = FakeImage.from_uri('https://fake/image')
217
292
  with mock.patch('requests.Session.post') as mock_request:
218
293
  mock_request.side_effect = mock_chat_completion_request_vision
219
- lm_1 = openai_compatible.OpenAICompatible(
294
+ lm_1 = openai_compatible.OpenAIChatCompletionAPI(
220
295
  api_endpoint='https://test-server',
221
296
  model='test-model1',
222
297
  )
223
- lm_2 = openai_compatible.OpenAICompatible(
298
+ lm_2 = openai_compatible.OpenAIChatCompletionAPI(
224
299
  api_endpoint='https://test-server',
225
300
  model='test-model2',
226
301
  )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
228
303
  self.assertEqual(
229
304
  lm(
230
305
  lf.UserMessage(
231
- 'hello <<[[image]]>>',
232
- image=FakeImage.from_uri('https://fake/image')
306
+ f'hello <<[[{image.id}]]>>',
307
+ referred_modalities=[image],
233
308
  ),
234
309
  sampling_options=lf.LMSamplingOptions(n=2)
235
310
  ),
236
311
  'Sample 0 for message: https://fake/image',
237
312
  )
238
313
 
239
- class TextOnlyModel(openai_compatible.OpenAICompatible):
314
+ class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
240
315
 
241
316
  class ModelInfo(lf.ModelInfo):
242
317
  input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
251
326
  with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
252
327
  lm_3(
253
328
  lf.UserMessage(
254
- 'hello <<[[image]]>>',
255
- image=FakeImage.from_uri('https://fake/image')
329
+ f'hello <<[[{image.id}]]>>',
330
+ referred_modalities=[image],
256
331
  ),
257
332
  )
258
333
 
259
334
  def test_sample_chat_completion(self):
260
335
  with mock.patch('requests.Session.post') as mock_request:
261
336
  mock_request.side_effect = mock_chat_completion_request
262
- lm = openai_compatible.OpenAICompatible(
337
+ lm = openai_compatible.OpenAIChatCompletionAPI(
263
338
  api_endpoint='https://test-server', model='test-model'
264
339
  )
265
340
  results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
400
475
  def test_sample_with_contextual_options(self):
401
476
  with mock.patch('requests.Session.post') as mock_request:
402
477
  mock_request.side_effect = mock_chat_completion_request
403
- lm = openai_compatible.OpenAICompatible(
478
+ lm = openai_compatible.OpenAIChatCompletionAPI(
404
479
  api_endpoint='https://test-server', model='test-model'
405
480
  )
406
481
  with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
458
533
  def test_call_with_system_message(self):
459
534
  with mock.patch('requests.Session.post') as mock_request:
460
535
  mock_request.side_effect = mock_chat_completion_request
461
- lm = openai_compatible.OpenAICompatible(
536
+ lm = openai_compatible.OpenAIChatCompletionAPI(
462
537
  api_endpoint='https://test-server', model='test-model'
463
538
  )
464
539
  self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
475
550
  def test_call_with_json_schema(self):
476
551
  with mock.patch('requests.Session.post') as mock_request:
477
552
  mock_request.side_effect = mock_chat_completion_request
478
- lm = openai_compatible.OpenAICompatible(
553
+ lm = openai_compatible.OpenAIChatCompletionAPI(
479
554
  api_endpoint='https://test-server', model='test-model'
480
555
  )
481
556
  self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
515
590
 
516
591
  with mock.patch('requests.Session.post') as mock_request:
517
592
  mock_request.side_effect = mock_context_limit_error
518
- lm = openai_compatible.OpenAICompatible(
593
+ lm = openai_compatible.OpenAIChatCompletionAPI(
519
594
  api_endpoint='https://test-server', model='test-model'
520
595
  )
521
596
  with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
524
599
  lm(lf.UserMessage('hello'))
525
600
 
526
601
 
602
+ class OpenAIResponsesAPITest(unittest.TestCase):
603
+ """Tests for OpenAI compatible language model on Responses API."""
604
+
605
+ def test_request_args(self):
606
+ lm = openai_compatible.OpenAIResponsesAPI(
607
+ api_endpoint='https://test-server', model='test-model'
608
+ )
609
+ # Test valid args.
610
+ self.assertEqual(
611
+ lm._request_args(
612
+ lf.LMSamplingOptions(
613
+ temperature=1.0, stop=['\n'], n=1, random_seed=123
614
+ )
615
+ ),
616
+ dict(
617
+ model='test-model',
618
+ temperature=1.0,
619
+ stop=['\n'],
620
+ seed=123,
621
+ ),
622
+ )
623
+ # Test unsupported n.
624
+ with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
625
+ lm._request_args(lf.LMSamplingOptions(n=2))
626
+
627
+ # Test unsupported logprobs.
628
+ with self.assertRaisesRegex(
629
+ ValueError, 'logprobs is not supported on Responses API.'
630
+ ):
631
+ lm._request_args(lf.LMSamplingOptions(logprobs=True))
632
+
633
+ def test_call_responses(self):
634
+ with mock.patch('requests.Session.post') as mock_request:
635
+ mock_request.side_effect = mock_responses_request
636
+ lm = openai_compatible.OpenAIResponsesAPI(
637
+ api_endpoint='https://test-server',
638
+ model='test-model',
639
+ )
640
+ self.assertEqual(lm('hello'), 'Sample 0 for message.')
641
+
642
+ def test_call_responses_vision(self):
643
+ class FakeImage(lf_modalities.Image):
644
+ @property
645
+ def mime_type(self) -> str:
646
+ return 'image/png'
647
+
648
+ image = FakeImage.from_uri('https://fake/image')
649
+ with mock.patch('requests.Session.post') as mock_request:
650
+ mock_request.side_effect = mock_responses_request_vision
651
+ lm = openai_compatible.OpenAIResponsesAPI(
652
+ api_endpoint='https://test-server',
653
+ model='test-model1',
654
+ )
655
+ self.assertEqual(
656
+ lm(
657
+ lf.UserMessage(
658
+ f'hello <<[[{image.id}]]>>',
659
+ referred_modalities=[image],
660
+ )
661
+ ),
662
+ 'Sample 0 for message: https://fake/image',
663
+ )
664
+
665
+ def test_call_with_system_message(self):
666
+ with mock.patch('requests.Session.post') as mock_request:
667
+ mock_request.side_effect = mock_responses_request
668
+ lm = openai_compatible.OpenAIResponsesAPI(
669
+ api_endpoint='https://test-server', model='test-model'
670
+ )
671
+ self.assertEqual(
672
+ lm(
673
+ lf.UserMessage(
674
+ 'hello',
675
+ system_message=lf.SystemMessage('hi'),
676
+ )
677
+ ),
678
+ 'Sample 0 for message. system=hi',
679
+ )
680
+
681
+ def test_call_with_json_schema(self):
682
+ with mock.patch('requests.Session.post') as mock_request:
683
+ mock_request.side_effect = mock_responses_request
684
+ lm = openai_compatible.OpenAIResponsesAPI(
685
+ api_endpoint='https://test-server', model='test-model'
686
+ )
687
+ self.assertEqual(
688
+ lm(
689
+ lf.UserMessage(
690
+ 'hello',
691
+ json_schema={
692
+ 'type': 'object',
693
+ 'properties': {
694
+ 'name': {'type': 'string'},
695
+ },
696
+ 'required': ['name'],
697
+ 'title': 'Person',
698
+ },
699
+ )
700
+ ),
701
+ 'Sample 0 for message. format=json_schema',
702
+ )
703
+
704
+ # Test bad json schema.
705
+ with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
706
+ lm(lf.UserMessage('hello', json_schema='foo'))
707
+
708
+ with self.assertRaisesRegex(
709
+ ValueError, 'The root of `json_schema` must have a `title` field'
710
+ ):
711
+ lm(lf.UserMessage('hello', json_schema={}))
712
+
713
+
527
714
  if __name__ == '__main__':
528
715
  unittest.main()
langfun/core/llms/openai_test.py CHANGED
@@ -61,8 +61,6 @@ class OpenAITest(unittest.TestCase):
61
61
  ),
62
62
  dict(
63
63
  model='gpt-4',
64
- top_logprobs=None,
65
- n=1,
66
64
  temperature=1.0,
67
65
  stop=['\n'],
68
66
  seed=123,
langfun/core/llms/rest.py CHANGED
@@ -22,7 +22,18 @@ import requests
22
22
 
23
23
 
24
24
  class REST(lf.LanguageModel):
25
- """REST-based language model."""
25
+ """Base class for language models accessed via REST APIs.
26
+
27
+ The `REST` class provides a foundation for implementing language models
28
+ that are accessed through RESTful endpoints. It handles the details of
29
+ making HTTP requests, managing sessions, and handling common errors like
30
+ timeouts and connection issues.
31
+
32
+ Subclasses need to implement the `request` and `result` methods to
33
+ convert Langfun messages to API-specific request formats and to parse
34
+ API responses back into `LMSamplingResult` objects. They also need to
35
+ provide the `api_endpoint` and can override `headers` for authentication.
36
+ """
26
37
 
27
38
  api_endpoint: Annotated[
28
39
  str,
@@ -98,7 +109,9 @@ class REST(lf.LanguageModel):
98
109
  raise lf.TemporaryLMError(str(e)) from e
99
110
  except (
100
111
  requests.exceptions.ConnectionError,
112
+ requests.exceptions.ChunkedEncodingError,
101
113
  ConnectionError,
114
+ ConnectionResetError,
102
115
  ) as e:
103
116
  error_message = str(e)
104
117
  if 'REJECTED_CLIENT_THROTTLED' in error_message:
@@ -107,6 +120,8 @@ class REST(lf.LanguageModel):
107
120
  raise lf.TemporaryLMError(error_message) from e
108
121
  if 'UNREACHABLE_ERROR' in error_message:
109
122
  raise lf.TemporaryLMError(error_message) from e
123
+ if 'Connection reset by peer' in error_message:
124
+ raise lf.TemporaryLMError(error_message) from e
110
125
  raise lf.LMError(error_message) from e
111
126
 
112
127
  def _error(self, status_code: int, content: str) -> lf.LMError:
langfun/core/llms/vertexai.py CHANGED
@@ -44,13 +44,32 @@ except ImportError:
44
44
 
45
45
  @pg.use_init_args(['api_endpoint'])
46
46
  class VertexAI(rest.REST):
47
- """Base class for VertexAI models.
47
+ """Base class for models served on Vertex AI.
48
48
 
49
- This class handles the authentication of vertex AI models. Subclasses
50
- should implement `request` and `result` methods, as well as the `api_endpoint`
51
- property. Or let users to provide them as __init__ arguments.
49
+ This class handles authentication for Vertex AI models. Subclasses,
50
+ such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
51
+ provide specific implementations for different model families hosted
52
+ on Vertex AI.
52
53
 
53
- Please check out VertexAIGemini in `gemini.py` as an example.
54
+ **Quick Start:**
55
+
56
+ If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
57
+ that has service account credentials, authentication is handled automatically.
58
+ Otherwise, you might need to set up credentials:
59
+
60
+ ```bash
61
+ gcloud auth application-default login
62
+ ```
63
+
64
+ Then you can use a Vertex AI model:
65
+
66
+ ```python
67
+ import langfun as lf
68
+
69
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
70
+ r = lm('Who are you?')
71
+ print(r)
72
+ ```
54
73
  """
55
74
 
56
75
  model: pg.typing.Annotated[
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
158
177
  @pg.use_init_args(['model'])
159
178
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
160
179
  class VertexAIGemini(VertexAI, gemini.Gemini):
161
- """Gemini models served by Vertex AI.."""
180
+ """Gemini models served on Vertex AI.
181
+
182
+ **Quick Start:**
183
+
184
+ ```python
185
+ import langfun as lf
186
+
187
+ # Call Gemini 2.5 Flash on Vertex AI.
188
+ # If project and location are not specified, they will be read from
189
+ # environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
190
+ lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
191
+ r = lm('Who are you?')
192
+ print(r)
193
+ ```
194
+ """
162
195
 
163
196
  # Set default location to us-central1.
164
197
  location = 'us-central1'
@@ -180,6 +213,13 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
180
213
  #
181
214
  # Production models.
182
215
  #
216
+ class VertexAIGemini3ProPreview(VertexAIGemini): # pylint: disable=invalid-name
217
+ """Gemini 3 Pro Preview model launched on 11/18/2025."""
218
+
219
+ model = 'gemini-3-pro-preview'
220
+ location = 'global'
221
+
222
+
183
223
  class VertexAIGemini25Pro(VertexAIGemini): # pylint: disable=invalid-name
184
224
  """Gemini 2.5 Pro GA model launched on 06/17/2025."""
185
225
 
@@ -369,6 +409,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
369
409
  # pylint: disable=invalid-name
370
410
 
371
411
 
412
+ class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
413
+ """Anthropic's Claude 4.5 Haiku model on VertexAI."""
414
+ model = 'claude-haiku-4-5@20251001'
415
+
416
+
417
+ class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
418
+ """Anthropic's Claude 4.5 Sonnet model on VertexAI."""
419
+ model = 'claude-sonnet-4-5@20250929'
420
+
421
+
372
422
  class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
373
423
  """Anthropic's Claude 4 Opus model on VertexAI."""
374
424
  model = 'claude-opus-4@20250514'
@@ -487,7 +537,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
487
537
 
488
538
  @pg.use_init_args(['model'])
489
539
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
490
- class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
540
+ class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
491
541
  """Llama models on VertexAI."""
492
542
 
493
543
  model: pg.typing.Annotated[
@@ -600,7 +650,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
600
650
 
601
651
  @pg.use_init_args(['model'])
602
652
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
603
- class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
653
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
604
654
  """Mistral AI models on VertexAI."""
605
655
 
606
656
  model: pg.typing.Annotated[
langfun/core/logging.py CHANGED
@@ -310,7 +310,7 @@ def warning(
310
310
  console: bool = False,
311
311
  **kwargs
312
312
  ) -> LogEntry:
313
- """Logs an info message to the session."""
313
+ """Logs a warning message to the session."""
314
314
  return log('warning', message, indent=indent, console=console, **kwargs)
315
315
 
316
316
 
langfun/core/mcp/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """Langfun MCP support."""
2
+
3
+ # pylint: disable=g-importing-member
4
+
5
+ from langfun.core.mcp.client import McpClient
6
+ from langfun.core.mcp.session import McpSession
7
+ from langfun.core.mcp.tool import McpTool
8
+ from langfun.core.mcp.tool import McpToolInput
9
+
10
+ # pylint: enable=g-importing-member
langfun/core/mcp/client.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MCP client."""
15
+
16
+ import abc
17
+ from typing import Annotated, Type
18
+
19
+ from langfun.core.mcp import session as mcp_session
20
+ from langfun.core.mcp import tool as mcp_tool
21
+ from mcp.server import fastmcp as fastmcp_lib
22
+ import pyglove as pg
23
+
24
+
25
+ class McpClient(pg.Object):
26
+ """Interface for Model Context Protocol (MCP) client.
27
+
28
+ An MCP client serves as a bridge to an MCP server, enabling users to interact
29
+ with tools hosted on the server. It provides methods for listing available
30
+ tools and creating sessions for tool interaction.
31
+
32
+ There are three types of MCP clients:
33
+
34
+ * **Stdio-based client**: Ideal for interacting with tools exposed as
35
+ command-line executables through stdin/stdout.
36
+ Created by `lf.mcp.McpClient.from_command`.
37
+ * **HTTP-based client**: Designed for tools accessible via HTTP,
38
+ supporting Server-Sent Events (SSE) for streaming.
39
+ Created by `lf.mcp.McpClient.from_url`.
40
+ * **In-memory client**: Useful for testing or embedding MCP servers
41
+ within the same process.
42
+ Created by `lf.mcp.McpClient.from_fastmcp`.
43
+
44
+ **Example Usage:**
45
+
46
+ ```python
47
+ import langfun as lf
48
+
49
+ # Example 1: Stdio-based client
50
+ client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', '<ARG2>'])
51
+ tools = client.list_tools()
52
+ tool_cls = tools['<TOOL_NAME>']
53
+
54
+ # Print the Python definition of the tool.
55
+ print(tool_cls.python_definition())
56
+
57
+ with client.session() as session:
58
+ result = tool_cls(x=1, y=2)(session)
59
+ print(result)
60
+
61
+ # Example 2: HTTP-based client (async)
62
+ async def main():
63
+ client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
64
+ tools = client.list_tools()
65
+ tool_cls = tools['<TOOL_NAME>']
66
+
67
+ # Print the Python definition of the tool.
68
+ print(tool_cls.python_definition())
69
+
70
+ async with client.session() as session:
71
+ result = await tool_cls(x=1, y=2).acall(session)
72
+ print(result)
73
+ ```
74
+ """
75
+
76
+ def _on_bound(self):
77
+ super()._on_bound()
78
+ self._tools = None
79
+
80
+ def list_tools(
81
+ self, refresh: bool = False
82
+ ) -> dict[str, Type[mcp_tool.McpTool]]:
83
+ """Lists all available tools on the MCP server.
84
+
85
+ Args:
86
+ refresh: If True, forces a refresh of the tool list from the server.
87
+ Otherwise, a cached list may be returned.
88
+
89
+ Returns:
90
+ A dictionary mapping tool names to their corresponding `McpTool` classes.
91
+ """
92
+ if self._tools is None or refresh:
93
+ with self.session() as session:
94
+ self._tools = session.list_tools()
95
+ return self._tools
96
+
97
+ @abc.abstractmethod
98
+ def session(self) -> mcp_session.McpSession:
99
+ """Creates a new session for interacting with MCP tools.
100
+
101
+ Returns:
102
+ An `McpSession` object.
103
+ """
104
+
105
+ @classmethod
106
+ def from_command(cls, command: str, args: list[str]) -> 'McpClient':
107
+ """Creates an MCP client from a command-line executable.
108
+
109
+ Args:
110
+ command: The command to execute.
111
+ args: A list of arguments to pass to the command.
112
+
113
+ Returns:
114
+ An `McpClient` instance that communicates via stdin/stdout.
115
+ """
116
+ return _StdioMcpClient(command=command, args=args)
117
+
118
+ @classmethod
119
+ def from_url(
120
+ cls,
121
+ url: str,
122
+ headers: dict[str, str] | None = None
123
+ ) -> 'McpClient':
124
+ """Creates an MCP client from an HTTP URL.
125
+
126
+ Args:
127
+ url: The URL of the MCP server.
128
+ headers: An optional dictionary of HTTP headers to include in requests.
129
+
130
+ Returns:
131
+ An `McpClient` instance that communicates via HTTP.
132
+ """
133
+ return _HttpMcpClient(url=url, headers=headers or {})
134
+
135
+ @classmethod
136
+ def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
137
+ """Creates an MCP client from an in-memory FastMCP instance.
138
+
139
+ Args:
140
+ fastmcp: An instance of `fastmcp_lib.FastMCP`.
141
+
142
+ Returns:
143
+ An `McpClient` instance that communicates with the in-memory server.
144
+ """
145
+ return _InMemoryFastMcpClient(fastmcp=fastmcp)
146
+
147
+
148
+ class _StdioMcpClient(McpClient):
149
+ """Stdio-based MCP client."""
150
+
151
+ command: Annotated[str, 'Command to execute.']
152
+ args: Annotated[list[str], 'Arguments to pass to the command.']
153
+
154
+ def session(self) -> mcp_session.McpSession:
155
+ """Creates an McpSession from command."""
156
+ return mcp_session.McpSession.from_command(self.command, self.args)
157
+
158
+
159
+ class _HttpMcpClient(McpClient):
160
+ """HTTP-based MCP client."""
161
+
162
+ url: Annotated[str, 'URL to connect to.']
163
+ headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}
164
+
165
+ def session(self) -> mcp_session.McpSession:
166
+ """Creates an McpSession from URL."""
167
+ return mcp_session.McpSession.from_url(self.url, self.headers)
168
+
169
+
170
+ class _InMemoryFastMcpClient(McpClient):
171
+ """In-memory MCP client."""
172
+
173
+ fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']
174
+
175
+ def session(self) -> mcp_session.McpSession:
176
+ """Creates an McpSession from an in-memory FastMCP instance."""
177
+ return mcp_session.McpSession.from_fastmcp(self.fastmcp)