langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511160804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.

Note: this release of langfun has been flagged as potentially problematic.

Files changed (146)
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/action.py +107 -12
  3. langfun/core/agentic/action_eval.py +9 -2
  4. langfun/core/agentic/action_test.py +25 -0
  5. langfun/core/async_support.py +32 -3
  6. langfun/core/coding/python/correction.py +19 -9
  7. langfun/core/coding/python/execution.py +14 -12
  8. langfun/core/coding/python/generation.py +21 -16
  9. langfun/core/coding/python/sandboxing.py +23 -3
  10. langfun/core/component.py +42 -3
  11. langfun/core/concurrent.py +70 -6
  12. langfun/core/concurrent_test.py +1 -0
  13. langfun/core/console.py +1 -1
  14. langfun/core/data/conversion/anthropic.py +12 -3
  15. langfun/core/data/conversion/anthropic_test.py +8 -6
  16. langfun/core/data/conversion/gemini.py +9 -2
  17. langfun/core/data/conversion/gemini_test.py +12 -9
  18. langfun/core/data/conversion/openai.py +145 -31
  19. langfun/core/data/conversion/openai_test.py +161 -17
  20. langfun/core/eval/base.py +47 -43
  21. langfun/core/eval/base_test.py +4 -4
  22. langfun/core/eval/matching.py +5 -2
  23. langfun/core/eval/patching.py +3 -3
  24. langfun/core/eval/scoring.py +4 -3
  25. langfun/core/eval/v2/__init__.py +1 -0
  26. langfun/core/eval/v2/checkpointing.py +39 -5
  27. langfun/core/eval/v2/checkpointing_test.py +1 -1
  28. langfun/core/eval/v2/eval_test_helper.py +96 -0
  29. langfun/core/eval/v2/evaluation.py +87 -15
  30. langfun/core/eval/v2/evaluation_test.py +9 -3
  31. langfun/core/eval/v2/example.py +45 -39
  32. langfun/core/eval/v2/example_test.py +3 -3
  33. langfun/core/eval/v2/experiment.py +51 -8
  34. langfun/core/eval/v2/metric_values.py +31 -3
  35. langfun/core/eval/v2/metric_values_test.py +32 -0
  36. langfun/core/eval/v2/metrics.py +157 -44
  37. langfun/core/eval/v2/metrics_test.py +39 -18
  38. langfun/core/eval/v2/progress.py +30 -1
  39. langfun/core/eval/v2/progress_test.py +27 -0
  40. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  41. langfun/core/eval/v2/reporting.py +90 -71
  42. langfun/core/eval/v2/reporting_test.py +20 -6
  43. langfun/core/eval/v2/runners/__init__.py +26 -0
  44. langfun/core/eval/v2/{runners.py → runners/base.py} +22 -124
  45. langfun/core/eval/v2/runners/debug.py +40 -0
  46. langfun/core/eval/v2/runners/debug_test.py +79 -0
  47. langfun/core/eval/v2/runners/parallel.py +100 -0
  48. langfun/core/eval/v2/runners/parallel_test.py +98 -0
  49. langfun/core/eval/v2/runners/sequential.py +47 -0
  50. langfun/core/eval/v2/runners/sequential_test.py +175 -0
  51. langfun/core/langfunc.py +45 -130
  52. langfun/core/langfunc_test.py +6 -4
  53. langfun/core/language_model.py +103 -16
  54. langfun/core/language_model_test.py +9 -3
  55. langfun/core/llms/__init__.py +7 -1
  56. langfun/core/llms/anthropic.py +157 -2
  57. langfun/core/llms/azure_openai.py +29 -17
  58. langfun/core/llms/cache/base.py +25 -3
  59. langfun/core/llms/cache/in_memory.py +48 -7
  60. langfun/core/llms/cache/in_memory_test.py +14 -4
  61. langfun/core/llms/compositional.py +25 -1
  62. langfun/core/llms/deepseek.py +30 -2
  63. langfun/core/llms/fake.py +32 -1
  64. langfun/core/llms/gemini.py +14 -9
  65. langfun/core/llms/google_genai.py +29 -1
  66. langfun/core/llms/groq.py +28 -3
  67. langfun/core/llms/llama_cpp.py +23 -4
  68. langfun/core/llms/openai.py +36 -3
  69. langfun/core/llms/openai_compatible.py +148 -27
  70. langfun/core/llms/openai_compatible_test.py +207 -20
  71. langfun/core/llms/openai_test.py +0 -2
  72. langfun/core/llms/rest.py +12 -1
  73. langfun/core/llms/vertexai.py +51 -8
  74. langfun/core/logging.py +1 -1
  75. langfun/core/mcp/client.py +77 -22
  76. langfun/core/mcp/client_test.py +8 -35
  77. langfun/core/mcp/session.py +94 -29
  78. langfun/core/mcp/session_test.py +54 -0
  79. langfun/core/mcp/tool.py +151 -22
  80. langfun/core/mcp/tool_test.py +197 -0
  81. langfun/core/memory.py +1 -0
  82. langfun/core/message.py +160 -55
  83. langfun/core/message_test.py +65 -81
  84. langfun/core/modalities/__init__.py +8 -0
  85. langfun/core/modalities/audio.py +21 -1
  86. langfun/core/modalities/image.py +19 -1
  87. langfun/core/modalities/mime.py +62 -3
  88. langfun/core/modalities/pdf.py +19 -1
  89. langfun/core/modalities/video.py +21 -1
  90. langfun/core/modality.py +167 -29
  91. langfun/core/modality_test.py +42 -12
  92. langfun/core/natural_language.py +1 -1
  93. langfun/core/sampling.py +4 -4
  94. langfun/core/sampling_test.py +20 -4
  95. langfun/core/structured/__init__.py +2 -24
  96. langfun/core/structured/completion.py +34 -44
  97. langfun/core/structured/completion_test.py +23 -43
  98. langfun/core/structured/description.py +54 -50
  99. langfun/core/structured/function_generation.py +29 -12
  100. langfun/core/structured/mapping.py +81 -37
  101. langfun/core/structured/parsing.py +95 -79
  102. langfun/core/structured/parsing_test.py +0 -3
  103. langfun/core/structured/querying.py +215 -142
  104. langfun/core/structured/querying_test.py +65 -29
  105. langfun/core/structured/schema/__init__.py +48 -0
  106. langfun/core/structured/schema/base.py +664 -0
  107. langfun/core/structured/schema/base_test.py +531 -0
  108. langfun/core/structured/schema/json.py +174 -0
  109. langfun/core/structured/schema/json_test.py +121 -0
  110. langfun/core/structured/schema/python.py +316 -0
  111. langfun/core/structured/schema/python_test.py +410 -0
  112. langfun/core/structured/schema_generation.py +33 -14
  113. langfun/core/structured/scoring.py +47 -36
  114. langfun/core/structured/tokenization.py +26 -11
  115. langfun/core/subscription.py +2 -2
  116. langfun/core/template.py +174 -49
  117. langfun/core/template_test.py +123 -17
  118. langfun/env/__init__.py +8 -2
  119. langfun/env/base_environment.py +320 -128
  120. langfun/env/base_environment_test.py +473 -0
  121. langfun/env/base_feature.py +92 -15
  122. langfun/env/base_feature_test.py +228 -0
  123. langfun/env/base_sandbox.py +84 -361
  124. langfun/env/base_sandbox_test.py +1235 -0
  125. langfun/env/event_handlers/__init__.py +1 -1
  126. langfun/env/event_handlers/chain.py +233 -0
  127. langfun/env/event_handlers/chain_test.py +253 -0
  128. langfun/env/event_handlers/event_logger.py +95 -98
  129. langfun/env/event_handlers/event_logger_test.py +21 -21
  130. langfun/env/event_handlers/metric_writer.py +225 -140
  131. langfun/env/event_handlers/metric_writer_test.py +23 -6
  132. langfun/env/interface.py +854 -40
  133. langfun/env/interface_test.py +112 -2
  134. langfun/env/load_balancers_test.py +23 -2
  135. langfun/env/test_utils.py +126 -84
  136. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/METADATA +1 -1
  137. langfun-0.1.2.dev202511160804.dist-info/RECORD +211 -0
  138. langfun/core/eval/v2/runners_test.py +0 -343
  139. langfun/core/structured/schema.py +0 -987
  140. langfun/core/structured/schema_test.py +0 -982
  141. langfun/env/base_test.py +0 -1481
  142. langfun/env/event_handlers/base.py +0 -350
  143. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  144. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/WHEEL +0 -0
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/licenses/LICENSE +0 -0
  146. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_compatible_test.py CHANGED
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
   response_format = ''
 
   choices = []
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     if json.get('logprobs'):
       logprobs = dict(
           content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
       c['image_url']['url']
       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     choices.append(pg.Dict(
         message=pg.Dict(
            content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
   return response
 
 
-class OpenAIComptibleTest(unittest.TestCase):
+def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  _ = json['input']
+
+  system_message = ''
+  if 'instructions' in json:
+    system_message = f' system={json["instructions"]}'
+
+  response_format = ''
+  if 'text' in json and 'format' in json['text']:
+    response_format = f' format={json["text"]["format"]["type"]}'
+
+  output = [
+      dict(
+          type='message',
+          content=[
+              dict(
+                  type='output_text',
+                  text=(
+                      f'Sample 0 for message.{system_message}{response_format}'
+                  )
+              )
+          ],
+      )
+  ]
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+def mock_responses_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
+  urls = [
+      c['image_url']
+      for c in json['input'][0]['content']
+      if c['type'] == 'input_image'
+  ]
+  output = [
+      pg.Dict(
+          type='message',
+          content=[
+              pg.Dict(
+                  type='output_text',
+                  text=f'Sample 0 for message: {"".join(urls)}',
+              )
+          ],
+      )
+  ]
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class OpenAIChatCompletionAPITest(unittest.TestCase):
   """Tests for OpenAI compatible language model."""
 
   def test_request_args(self):
     self.assertEqual(
-        openai_compatible.OpenAICompatible(
+        openai_compatible.OpenAIChatCompletionAPI(
            api_endpoint='https://test-server',
            model='test-model'
        )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
         ),
         dict(
             model='test-model',
-            top_logprobs=None,
-            n=1,
             temperature=1.0,
             stop=['\n'],
             seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
      )
      self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion_with_logprobs(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
      )
      results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
       def mime_type(self) -> str:
         return 'image/png'
 
+    image = FakeImage.from_uri('https://fake/image')
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai_compatible.OpenAICompatible(
+      lm_1 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model1',
      )
-      lm_2 = openai_compatible.OpenAICompatible(
+      lm_2 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model2',
      )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
        self.assertEqual(
            lm(
                lf.UserMessage(
-                    'hello <<[[image]]>>',
-                    image=FakeImage.from_uri('https://fake/image')
+                    f'hello <<[[{image.id}]]>>',
+                    referred_modalities=[image],
                ),
                sampling_options=lf.LMSamplingOptions(n=2)
            ),
            'Sample 0 for message: https://fake/image',
        )
 
-    class TextOnlyModel(openai_compatible.OpenAICompatible):
+    class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
 
      class ModelInfo(lf.ModelInfo):
        input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
       lm_3(
           lf.UserMessage(
-              'hello <<[[image]]>>',
-              image=FakeImage.from_uri('https://fake/image')
+              f'hello <<[[{image.id}]]>>',
+              referred_modalities=[image],
          ),
      )
 
   def test_sample_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_sample_with_contextual_options(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_json_schema(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
 
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_context_limit_error
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
         lm(lf.UserMessage('hello'))
 
 
+class OpenAIResponsesAPITest(unittest.TestCase):
+  """Tests for OpenAI compatible language model on Responses API."""
+
+  def test_request_args(self):
+    lm = openai_compatible.OpenAIResponsesAPI(
+        api_endpoint='https://test-server', model='test-model'
+    )
+    # Test valid args.
+    self.assertEqual(
+        lm._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='test-model',
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+    # Test unsupported n.
+    with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+      lm._request_args(lf.LMSamplingOptions(n=2))
+
+    # Test unsupported logprobs.
+    with self.assertRaisesRegex(
+        ValueError, 'logprobs is not supported on Responses API.'
+    ):
+      lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+  def test_call_responses(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model',
+      )
+      self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+  def test_call_responses_vision(self):
+    class FakeImage(lf_modalities.Image):
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
+    image = FakeImage.from_uri('https://fake/image')
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request_vision
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model1',
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              )
+          ),
+          'Sample 0 for message: https://fake/image',
+      )
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message=lf.SystemMessage('hi'),
+              )
+          ),
+          'Sample 0 for message. system=hi',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  },
+              )
+          ),
+          'Sample 0 for message. format=json_schema',
+      )
+
+      # Test bad json schema.
+      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+        lm(lf.UserMessage('hello', json_schema='foo'))
+
+      with self.assertRaisesRegex(
+          ValueError, 'The root of `json_schema` must have a `title` field'
+      ):
+        lm(lf.UserMessage('hello', json_schema={}))
+
+
 if __name__ == '__main__':
   unittest.main()
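The `OpenAIResponsesAPITest` above pins down the contract of the new `OpenAIResponsesAPI` class: `n` must be 1, `logprobs` is rejected, and a `json_schema` on the message is mapped to the Responses `text.format` field. A minimal sketch distilled from those tests follows; the endpoint and model are the same placeholder values the tests use, not real services, and no credential handling is shown.

```python
# Sketch distilled from the tests above; placeholder endpoint/model only.
import langfun as lf
from langfun.core.llms import openai_compatible

lm = openai_compatible.OpenAIResponsesAPI(
    api_endpoint='https://test-server', model='test-model'
)

# Request args mirror the Chat Completions backend, minus n/top_logprobs.
args = lm._request_args(
    lf.LMSamplingOptions(temperature=1.0, stop=['\n'], n=1, random_seed=123)
)
assert args == dict(model='test-model', temperature=1.0, stop=['\n'], seed=123)

# Per the tests above, n > 1 and logprobs=True raise ValueError at
# request-building time on the Responses API backend.
```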
langfun/core/llms/openai_test.py CHANGED
@@ -61,8 +61,6 @@ class OpenAITest(unittest.TestCase):
         ),
         dict(
             model='gpt-4',
-            top_logprobs=None,
-            n=1,
             temperature=1.0,
             stop=['\n'],
             seed=123,
langfun/core/llms/rest.py CHANGED
@@ -22,7 +22,18 @@ import requests
 
 
 class REST(lf.LanguageModel):
-  """REST-based language model."""
+  """Base class for language models accessed via REST APIs.
+
+  The `REST` class provides a foundation for implementing language models
+  that are accessed through RESTful endpoints. It handles the details of
+  making HTTP requests, managing sessions, and handling common errors like
+  timeouts and connection issues.
+
+  Subclasses need to implement the `request` and `result` methods to
+  convert Langfun messages to API-specific request formats and to parse
+  API responses back into `LMSamplingResult` objects. They also need to
+  provide the `api_endpoint` and can override `headers` for authentication.
+  """
 
   api_endpoint: Annotated[
       str,
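The expanded `REST` docstring describes the subclass contract in prose. A minimal sketch of such a subclass is given below as an illustration only: `MyRESTModel`, its endpoint, and its payload/response fields are hypothetical, and the exact `request`/`result` signatures and the `headers` override are assumptions inferred from the docstring rather than confirmed API.

```python
# Hypothetical sketch of a REST subclass, inferred from the docstring above.
# Method signatures, the headers override, and the payload/response fields are
# assumptions, not code from this release.
from typing import Any

import langfun as lf
from langfun.core.llms import rest


class MyRESTModel(rest.REST):
  """Illustrative model behind a custom REST endpoint."""

  @property
  def headers(self) -> dict[str, Any]:
    # Assumed override point for authentication headers.
    return {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer <token>',
    }

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Convert a Langfun message into the provider-specific JSON payload.
    return dict(prompt=prompt.text, temperature=sampling_options.temperature)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Parse the provider's JSON response back into an LMSamplingResult.
    return lf.LMSamplingResult(
        [lf.LMSample(lf.AIMessage(json['text']), score=0.0)]
    )


# Usage (endpoint is a placeholder):
lm = MyRESTModel(api_endpoint='https://example.com/v1/generate')
```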
langfun/core/llms/vertexai.py CHANGED
@@ -44,13 +44,32 @@ except ImportError:
 
 @pg.use_init_args(['api_endpoint'])
 class VertexAI(rest.REST):
-  """Base class for VertexAI models.
+  """Base class for models served on Vertex AI.
 
-  This class handles the authentication of vertex AI models. Subclasses
-  should implement `request` and `result` methods, as well as the `api_endpoint`
-  property. Or let users to provide them as __init__ arguments.
+  This class handles authentication for Vertex AI models. Subclasses,
+  such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
+  provide specific implementations for different model families hosted
+  on Vertex AI.
 
-  Please check out VertexAIGemini in `gemini.py` as an example.
+  **Quick Start:**
+
+  If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
+  that has service account credentials, authentication is handled automatically.
+  Otherwise, you might need to set up credentials:
+
+  ```bash
+  gcloud auth application-default login
+  ```
+
+  Then you can use a Vertex AI model:
+
+  ```python
+  import langfun as lf
+
+  lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
+  r = lm('Who are you?')
+  print(r)
+  ```
   """
 
   model: pg.typing.Annotated[
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
 class VertexAIGemini(VertexAI, gemini.Gemini):
-  """Gemini models served by Vertex AI.."""
+  """Gemini models served on Vertex AI.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Gemini 1.5 Flash on Vertex AI.
+  # If project and location are not specified, they will be read from
+  # environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
+  lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
+  r = lm('Who are you?')
+  print(r)
+  ```
+  """
 
   # Set default location to us-central1.
   location = 'us-central1'
@@ -369,6 +402,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
 # pylint: disable=invalid-name
 
 
+class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
+  """Anthropic's Claude 4.5 Haiku model on VertexAI."""
+  model = 'claude-haiku-4-5@20251001'
+
+
+class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
+  """Anthropic's Claude 4.5 Sonnet model on VertexAI."""
+  model = 'claude-sonnet-4-5@20250929'
+
+
 class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
   """Anthropic's Claude 4 Opus model on VertexAI."""
   model = 'claude-opus-4@20250514'
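The two classes above add Claude 4.5 Haiku and Claude 4.5 Sonnet to the Vertex AI Anthropic family. A hedged usage sketch, following the Quick Start pattern from the `VertexAI` docstring earlier in this file: the project and location values are placeholders, and the classes are imported directly from `langfun/core/llms/vertexai.py` since this diff does not show whether they are also re-exported via `lf.llms`.

```python
# Usage sketch only; project and location are placeholders.
from langfun.core.llms import vertexai

lm = vertexai.VertexAIClaude45Sonnet_20250929(
    project='my-project',   # placeholder GCP project
    location='us-east5',    # placeholder region; pick one serving the model
)
print(lm('Who are you?'))
```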
@@ -487,7 +530,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Llama models on VertexAI."""
 
   model: pg.typing.Annotated[
@@ -600,7 +643,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Mistral AI models on VertexAI."""
 
   model: pg.typing.Annotated[
langfun/core/logging.py CHANGED
@@ -310,7 +310,7 @@ def warning(
     console: bool = False,
     **kwargs
 ) -> LogEntry:
-  """Logs an info message to the session."""
+  """Logs a warning message to the session."""
   return log('warning', message, indent=indent, console=console, **kwargs)
 
 
langfun/core/mcp/client.py CHANGED
@@ -23,33 +23,53 @@ import pyglove as pg
 
 
 class McpClient(pg.Object):
-  """Base class for MCP client.
+  """Interface for Model Context Protocol (MCP) client.
 
-  Usage:
+  An MCP client serves as a bridge to an MCP server, enabling users to interact
+  with tools hosted on the server. It provides methods for listing available
+  tools and creating sessions for tool interaction.
+
+  There are three types of MCP clients:
+
+  * **Stdio-based client**: Ideal for interacting with tools exposed as
+    command-line executables through stdin/stdout.
+    Created by `lf.mcp.McpClient.from_command`.
+  * **HTTP-based client**: Designed for tools accessible via HTTP,
+    supporting Server-Sent Events (SSE) for streaming.
+    Created by `lf.mcp.McpClient.from_url`.
+  * **In-memory client**: Useful for testing or embedding MCP servers
+    within the same process.
+    Created by `lf.mcp.McpClient.from_fastmcp`.
+
+  **Example Usage:**
 
   ```python
+  import langfun as lf
 
-  def tool_use():
-    client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
-    tools = client.list_tools()
-    tool_cls = tools['<TOOL_NAME>']
+  # Example 1: Stdio-based client
+  client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
+  tools = client.list_tools()
+  tool_cls = tools['<TOOL_NAME>']
 
-    # Print the python definition of the tool.
-    print(tool_cls.python_definition())
+  # Print the Python definition of the tool.
+  print(tool_cls.python_definition())
 
-    with client.session() as session:
-      return tool_cls(x=1, y=2)(session)
+  with client.session() as session:
+    result = tool_cls(x=1, y=2)(session)
+    print(result)
 
-  async def tool_use_async_version():
+  # Example 2: HTTP-based client (async)
+  async def main():
     client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
     tools = client.list_tools()
     tool_cls = tools['<TOOL_NAME>']
 
-    # Print the python definition of the tool.
+    # Print the Python definition of the tool.
     print(tool_cls.python_definition())
 
     async with client.session() as session:
-      return await tool_cls(x=1, y=2).acall(session)
+      result = await tool_cls(x=1, y=2).acall(session)
+      print(result)
   ```
   """
 
@@ -60,7 +80,15 @@ class McpClient(pg.Object):
   def list_tools(
       self, refresh: bool = False
   ) -> dict[str, Type[mcp_tool.McpTool]]:
-    """Lists all MCP tools."""
+    """Lists all available tools on the MCP server.
+
+    Args:
+      refresh: If True, forces a refresh of the tool list from the server.
+        Otherwise, a cached list may be returned.
+
+    Returns:
+      A dictionary mapping tool names to their corresponding `McpTool` classes.
+    """
     if self._tools is None or refresh:
       with self.session() as session:
         self._tools = session.list_tools()
@@ -68,11 +96,23 @@ class McpClient(pg.Object):
 
   @abc.abstractmethod
   def session(self) -> mcp_session.McpSession:
-    """Creates a MCP session."""
+    """Creates a new session for interacting with MCP tools.
+
+    Returns:
+      An `McpSession` object.
+    """
 
   @classmethod
   def from_command(cls, command: str, args: list[str]) -> 'McpClient':
-    """Creates a MCP client from a tool."""
+    """Creates an MCP client from a command-line executable.
+
+    Args:
+      command: The command to execute.
+      args: A list of arguments to pass to the command.
+
+    Returns:
+      A `McpClient` instance that communicates via stdin/stdout.
+    """
     return _StdioMcpClient(command=command, args=args)
 
   @classmethod
@@ -81,12 +121,27 @@ class McpClient(pg.Object):
       url: str,
       headers: dict[str, str] | None = None
   ) -> 'McpClient':
-    """Creates a MCP client from a URL."""
+    """Creates an MCP client from an HTTP URL.
+
+    Args:
+      url: The URL of the MCP server.
+      headers: An optional dictionary of HTTP headers to include in requests.
+
+    Returns:
+      A `McpClient` instance that communicates via HTTP.
+    """
     return _HttpMcpClient(url=url, headers=headers or {})
 
   @classmethod
   def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
-    """Creates a MCP client from a MCP server."""
+    """Creates an MCP client from an in-memory FastMCP instance.
+
+    Args:
+      fastmcp: An instance of `fastmcp_lib.FastMCP`.
+
+    Returns:
+      A `McpClient` instance that communicates with the in-memory server.
+    """
     return _InMemoryFastMcpClient(fastmcp=fastmcp)
 
 
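`from_fastmcp` completes the three client types listed in the `McpClient` docstring, which only demonstrates the stdio and HTTP paths. A hedged sketch of the in-memory path follows; the FastMCP server, its `add` tool, and the decorator form are assumptions about the `fastmcp` library, while the `McpClient` calls mirror the class docstring above.

```python
# Hedged sketch of the in-memory client path. The FastMCP server and its
# `add` tool are hypothetical; the decorator form is an assumption about the
# fastmcp library. The McpClient calls follow the class docstring above.
import langfun as lf
from fastmcp import FastMCP

server = FastMCP('demo')


@server.tool()
def add(x: int, y: int) -> int:
  """Adds two integers."""
  return x + y


client = lf.mcp.McpClient.from_fastmcp(server)
tools = client.list_tools()
add_tool = tools['add']
print(add_tool.python_definition())

with client.session() as session:
  print(add_tool(x=1, y=2)(session))
```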
@@ -97,18 +152,18 @@ class _StdioMcpClient(McpClient):
   args: Annotated[list[str], 'Arguments to pass to the command.']
 
   def session(self) -> mcp_session.McpSession:
-    """Creates a MCP session."""
+    """Creates an McpSession from command."""
     return mcp_session.McpSession.from_command(self.command, self.args)
 
 
 class _HttpMcpClient(McpClient):
-  """Server-Sent Events (SSE)/Streamable HTTP-based MCP client."""
+  """HTTP-based MCP client."""
 
   url: Annotated[str, 'URL to connect to.']
   headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}
 
   def session(self) -> mcp_session.McpSession:
-    """Creates a MCP session."""
+    """Creates an McpSession from URL."""
     return mcp_session.McpSession.from_url(self.url, self.headers)
 
 
@@ -118,5 +173,5 @@ class _InMemoryFastMcpClient(McpClient):
   fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']
 
   def session(self) -> mcp_session.McpSession:
-    """Creates a MCP session."""
+    """Creates an McpSession from an in-memory FastMCP instance."""
     return mcp_session.McpSession.from_fastmcp(self.fastmcp)