langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_compatible_test.py
CHANGED

@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
   response_format = ''

   choices = []
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     if json.get('logprobs'):
       logprobs = dict(
           content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
       c['image_url']['url']
       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
   return response


-class OpenAIComptibleTest(unittest.TestCase):
+def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  _ = json['input']
+
+  system_message = ''
+  if 'instructions' in json:
+    system_message = f' system={json["instructions"]}'
+
+  response_format = ''
+  if 'text' in json and 'format' in json['text']:
+    response_format = f' format={json["text"]["format"]["type"]}'
+
+  output = [
+      dict(
+          type='message',
+          content=[
+              dict(
+                  type='output_text',
+                  text=(
+                      f'Sample 0 for message.{system_message}{response_format}'
+                  )
+              )
+          ],
+      )
+  ]
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+def mock_responses_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
+  urls = [
+      c['image_url']
+      for c in json['input'][0]['content']
+      if c['type'] == 'input_image'
+  ]
+  output = [
+      pg.Dict(
+          type='message',
+          content=[
+              pg.Dict(
+                  type='output_text',
+                  text=f'Sample 0 for message: {"".join(urls)}',
+              )
+          ],
+      )
+  ]
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class OpenAIChatCompletionAPITest(unittest.TestCase):
   """Tests for OpenAI compatible language model."""

   def test_request_args(self):
     self.assertEqual(
-        openai_compatible.OpenAICompatible(
+        openai_compatible.OpenAIChatCompletionAPI(
            api_endpoint='https://test-server',
            model='test-model'
        )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
        ),
        dict(
            model='test-model',
-            top_logprobs=None,
-            n=1,
            temperature=1.0,
            stop=['\n'],
            seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
      )
      self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion_with_logprobs(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
      )
      results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
       def mime_type(self) -> str:
         return 'image/png'

+    image = FakeImage.from_uri('https://fake/image')
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai_compatible.OpenAICompatible(
+      lm_1 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model1',
      )
-      lm_2 = openai_compatible.OpenAICompatible(
+      lm_2 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model2',
      )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
      self.assertEqual(
          lm(
              lf.UserMessage(
-                  'hello <<[[image]]>>',
-                  image=FakeImage.from_uri('https://fake/image'),
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
              ),
              sampling_options=lf.LMSamplingOptions(n=2)
          ),
          'Sample 0 for message: https://fake/image',
      )

-    class TextOnlyModel(openai_compatible.OpenAICompatible):
+    class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):

      class ModelInfo(lf.ModelInfo):
        input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
       lm_3(
          lf.UserMessage(
-              'hello <<[[image]]>>',
-              image=FakeImage.from_uri('https://fake/image'),
+              f'hello <<[[{image.id}]]>>',
+              referred_modalities=[image],
          ),
      )

   def test_sample_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_sample_with_contextual_options(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_json_schema(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):

     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_context_limit_error
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
      with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
        lm(lf.UserMessage('hello'))


+class OpenAIResponsesAPITest(unittest.TestCase):
+  """Tests for OpenAI compatible language model on Responses API."""
+
+  def test_request_args(self):
+    lm = openai_compatible.OpenAIResponsesAPI(
+        api_endpoint='https://test-server', model='test-model'
+    )
+    # Test valid args.
+    self.assertEqual(
+        lm._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='test-model',
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+    # Test unsupported n.
+    with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+      lm._request_args(lf.LMSamplingOptions(n=2))
+
+    # Test unsupported logprobs.
+    with self.assertRaisesRegex(
+        ValueError, 'logprobs is not supported on Responses API.'
+    ):
+      lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+  def test_call_responses(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model',
+      )
+      self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+  def test_call_responses_vision(self):
+    class FakeImage(lf_modalities.Image):
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
+    image = FakeImage.from_uri('https://fake/image')
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request_vision
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model1',
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              )
+          ),
+          'Sample 0 for message: https://fake/image',
+      )
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message=lf.SystemMessage('hi'),
+              )
+          ),
+          'Sample 0 for message. system=hi',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  },
+              )
+          ),
+          'Sample 0 for message. format=json_schema',
+      )
+
+    # Test bad json schema.
+    with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+      lm(lf.UserMessage('hello', json_schema='foo'))
+
+    with self.assertRaisesRegex(
+        ValueError, 'The root of `json_schema` must have a `title` field'
+    ):
+      lm(lf.UserMessage('hello', json_schema={}))
+
+
 if __name__ == '__main__':
   unittest.main()
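The hunks above rename the chat-completions model class to `OpenAIChatCompletionAPI` and introduce a separate `OpenAIResponsesAPI`. A minimal usage sketch inferred from the test code above; the endpoint URLs and model name are placeholders, not values from this diff:

```python
# Sketch inferred from openai_compatible_test.py above; endpoint/model are placeholders.
from langfun.core.llms import openai_compatible

# Chat Completions API (the renamed class).
chat_lm = openai_compatible.OpenAIChatCompletionAPI(
    api_endpoint='https://my-server/v1/chat/completions',  # placeholder
    model='my-model',  # placeholder
)
print(chat_lm('hello'))

# Responses API variant; per the tests above, n must be 1 and logprobs is not supported.
responses_lm = openai_compatible.OpenAIResponsesAPI(
    api_endpoint='https://my-server/v1/responses',  # placeholder
    model='my-model',  # placeholder
)
print(responses_lm('hello'))
```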
langfun/core/llms/openai_test.py
CHANGED
langfun/core/llms/rest.py
CHANGED
@@ -22,7 +22,18 @@ import requests


 class REST(lf.LanguageModel):
-  """
+  """Base class for language models accessed via REST APIs.
+
+  The `REST` class provides a foundation for implementing language models
+  that are accessed through RESTful endpoints. It handles the details of
+  making HTTP requests, managing sessions, and handling common errors like
+  timeouts and connection issues.
+
+  Subclasses need to implement the `request` and `result` methods to
+  convert Langfun messages to API-specific request formats and to parse
+  API responses back into `LMSamplingResult` objects. They also need to
+  provide the `api_endpoint` and can override `headers` for authentication.
+  """

   api_endpoint: Annotated[
       str,
@@ -98,7 +109,9 @@ class REST(lf.LanguageModel):
       raise lf.TemporaryLMError(str(e)) from e
     except (
         requests.exceptions.ConnectionError,
+        requests.exceptions.ChunkedEncodingError,
        ConnectionError,
+        ConnectionResetError,
     ) as e:
       error_message = str(e)
       if 'REJECTED_CLIENT_THROTTLED' in error_message:
@@ -107,6 +120,8 @@ class REST(lf.LanguageModel):
        raise lf.TemporaryLMError(error_message) from e
       if 'UNREACHABLE_ERROR' in error_message:
        raise lf.TemporaryLMError(error_message) from e
+      if 'Connection reset by peer' in error_message:
+        raise lf.TemporaryLMError(error_message) from e
      raise lf.LMError(error_message) from e

   def _error(self, status_code: int, content: str) -> lf.LMError:
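Per the new `REST` docstring, subclasses convert Langfun messages into API-specific payloads and parse responses back into `LMSamplingResult` objects. A rough sketch of that contract; the `request`/`result` signatures and the `headers` override shown here are assumptions for illustration, not taken from this diff:

```python
# Hypothetical REST subclass; names and signatures below are assumptions.
from typing import Any

import langfun as lf
from langfun.core.llms import rest


class MyRestModel(rest.REST):
  """Hypothetical model behind a custom REST endpoint."""

  api_endpoint = 'https://example.com/v1/generate'  # placeholder

  @property
  def headers(self) -> dict[str, Any]:
    # Assumed hook for attaching authentication headers.
    return {'Authorization': 'Bearer <TOKEN>'}

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Convert a Langfun message into the API-specific JSON payload.
    return dict(prompt=prompt.text, temperature=sampling_options.temperature)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Parse the API response back into an LMSamplingResult.
    return lf.LMSamplingResult(
        [lf.LMSample(lf.AIMessage(json['text']), score=0.0)]
    )
```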
langfun/core/llms/vertexai.py
CHANGED
@@ -44,13 +44,32 @@ except ImportError:

 @pg.use_init_args(['api_endpoint'])
 class VertexAI(rest.REST):
-  """Base class for
+  """Base class for models served on Vertex AI.

-  This class handles
-
-
+  This class handles authentication for Vertex AI models. Subclasses,
+  such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
+  provide specific implementations for different model families hosted
+  on Vertex AI.

-
+  **Quick Start:**
+
+  If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
+  that has service account credentials, authentication is handled automatically.
+  Otherwise, you might need to set up credentials:
+
+  ```bash
+  gcloud auth application-default login
+  ```
+
+  Then you can use a Vertex AI model:
+
+  ```python
+  import langfun as lf
+
+  lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
+  r = lm('Who are you?')
+  print(r)
+  ```
   """

   model: pg.typing.Annotated[
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
 class VertexAIGemini(VertexAI, gemini.Gemini):
-  """Gemini models served
+  """Gemini models served on Vertex AI.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Gemini 1.5 Flash on Vertex AI.
+  # If project and location are not specified, they will be read from
+  # environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
+  lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
+  r = lm('Who are you?')
+  print(r)
+  ```
+  """

   # Set default location to us-central1.
   location = 'us-central1'
@@ -180,6 +213,13 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
 #
 # Production models.
 #
+class VertexAIGemini3ProPreview(VertexAIGemini):  # pylint: disable=invalid-name
+  """Gemini 3 Pro Preview model launched on 11/18/2025."""
+
+  model = 'gemini-3-pro-preview'
+  location = 'global'
+
+
 class VertexAIGemini25Pro(VertexAIGemini):  # pylint: disable=invalid-name
   """Gemini 2.5 Pro GA model launched on 06/17/2025."""

@@ -369,6 +409,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
 # pylint: disable=invalid-name


+class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
+  """Anthropic's Claude 4.5 Haiku model on VertexAI."""
+  model = 'claude-haiku-4-5@20251001'
+
+
+class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
+  """Anthropic's Claude 4.5 Sonnet model on VertexAI."""
+  model = 'claude-sonnet-4-5@20250929'
+
+
 class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
   """Anthropic's Claude 4 Opus model on VertexAI."""
   model = 'claude-opus-4@20250514'
@@ -487,7 +537,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}

 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Llama models on VertexAI."""

   model: pg.typing.Annotated[
@@ -600,7 +650,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}

 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Mistral AI models on VertexAI."""

   model: pg.typing.Annotated[
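A short usage sketch for the model classes added above. The class names and constructor arguments mirror the docstring examples in this diff; whether the new classes are re-exported under `lf.llms` is an assumption based on how the existing Vertex AI models are exposed:

```python
# Assumes the new classes are exposed under lf.llms like existing VertexAI models.
import langfun as lf

# Gemini 3 Pro Preview; location defaults to 'global' per the hunk above.
lm = lf.llms.VertexAIGemini3ProPreview(project='my-project')  # placeholder project
print(lm('Who are you?'))

# Claude 4.5 Sonnet on Vertex AI.
claude = lf.llms.VertexAIClaude45Sonnet_20250929(project='my-project')  # placeholder project
print(claude('Hello!'))
```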
langfun/core/logging.py
CHANGED
@@ -310,7 +310,7 @@ def warning(
     console: bool = False,
     **kwargs
 ) -> LogEntry:
-  """Logs
+  """Logs a warning message to the session."""
   return log('warning', message, indent=indent, console=console, **kwargs)


langfun/core/mcp/__init__.py
ADDED

@@ -0,0 +1,10 @@
+"""Langfun MCP support."""
+
+# pylint: disable=g-importing-member
+
+from langfun.core.mcp.client import McpClient
+from langfun.core.mcp.session import McpSession
+from langfun.core.mcp.tool import McpTool
+from langfun.core.mcp.tool import McpToolInput
+
+# pylint: enable=g-importing-member
langfun/core/mcp/client.py
ADDED

@@ -0,0 +1,177 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MCP client."""
+
+import abc
+from typing import Annotated, Type
+
+from langfun.core.mcp import session as mcp_session
+from langfun.core.mcp import tool as mcp_tool
+from mcp.server import fastmcp as fastmcp_lib
+import pyglove as pg
+
+
+class McpClient(pg.Object):
+  """Interface for Model Context Protocol (MCP) client.
+
+  An MCP client serves as a bridge to an MCP server, enabling users to interact
+  with tools hosted on the server. It provides methods for listing available
+  tools and creating sessions for tool interaction.
+
+  There are three types of MCP clients:
+
+  * **Stdio-based client**: Ideal for interacting with tools exposed as
+    command-line executables through stdin/stdout.
+    Created by `lf.mcp.McpClient.from_command`.
+  * **HTTP-based client**: Designed for tools accessible via HTTP,
+    supporting Server-Sent Events (SSE) for streaming.
+    Created by `lf.mcp.McpClient.from_url`.
+  * **In-memory client**: Useful for testing or embedding MCP servers
+    within the same process.
+    Created by `lf.mcp.McpClient.from_fastmcp`.
+
+  **Example Usage:**
+
+  ```python
+  import langfun as lf
+
+  # Example 1: Stdio-based client
+  client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
+  tools = client.list_tools()
+  tool_cls = tools['<TOOL_NAME>']
+
+  # Print the Python definition of the tool.
+  print(tool_cls.python_definition())
+
+  with client.session() as session:
+    result = tool_cls(x=1, y=2)(session)
+    print(result)
+
+  # Example 2: HTTP-based client (async)
+  async def main():
+    client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
+    tools = client.list_tools()
+    tool_cls = tools['<TOOL_NAME>']
+
+    # Print the Python definition of the tool.
+    print(tool_cls.python_definition())
+
+    async with client.session() as session:
+      result = await tool_cls(x=1, y=2).acall(session)
+      print(result)
+  ```
+  """
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._tools = None
+
+  def list_tools(
+      self, refresh: bool = False
+  ) -> dict[str, Type[mcp_tool.McpTool]]:
+    """Lists all available tools on the MCP server.
+
+    Args:
+      refresh: If True, forces a refresh of the tool list from the server.
+        Otherwise, a cached list may be returned.
+
+    Returns:
+      A dictionary mapping tool names to their corresponding `McpTool` classes.
+    """
+    if self._tools is None or refresh:
+      with self.session() as session:
+        self._tools = session.list_tools()
+    return self._tools
+
+  @abc.abstractmethod
+  def session(self) -> mcp_session.McpSession:
+    """Creates a new session for interacting with MCP tools.
+
+    Returns:
+      An `McpSession` object.
+    """
+
+  @classmethod
+  def from_command(cls, command: str, args: list[str]) -> 'McpClient':
+    """Creates an MCP client from a command-line executable.
+
+    Args:
+      command: The command to execute.
+      args: A list of arguments to pass to the command.
+
+    Returns:
+      A `McpClient` instance that communicates via stdin/stdout.
+    """
+    return _StdioMcpClient(command=command, args=args)
+
+  @classmethod
+  def from_url(
+      cls,
+      url: str,
+      headers: dict[str, str] | None = None
+  ) -> 'McpClient':
+    """Creates an MCP client from an HTTP URL.
+
+    Args:
+      url: The URL of the MCP server.
+      headers: An optional dictionary of HTTP headers to include in requests.
+
+    Returns:
+      A `McpClient` instance that communicates via HTTP.
+    """
+    return _HttpMcpClient(url=url, headers=headers or {})
+
+  @classmethod
+  def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
+    """Creates an MCP client from an in-memory FastMCP instance.
+
+    Args:
+      fastmcp: An instance of `fastmcp_lib.FastMCP`.
+
+    Returns:
+      A `McpClient` instance that communicates with the in-memory server.
+    """
+    return _InMemoryFastMcpClient(fastmcp=fastmcp)
+
+
+class _StdioMcpClient(McpClient):
+  """Stdio-based MCP client."""
+
+  command: Annotated[str, 'Command to execute.']
+  args: Annotated[list[str], 'Arguments to pass to the command.']
+
+  def session(self) -> mcp_session.McpSession:
+    """Creates an McpSession from command."""
+    return mcp_session.McpSession.from_command(self.command, self.args)
+
+
+class _HttpMcpClient(McpClient):
+  """HTTP-based MCP client."""
+
+  url: Annotated[str, 'URL to connect to.']
+  headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}
+
+  def session(self) -> mcp_session.McpSession:
+    """Creates an McpSession from URL."""
+    return mcp_session.McpSession.from_url(self.url, self.headers)
+
+
+class _InMemoryFastMcpClient(McpClient):
+  """In-memory MCP client."""
+
+  fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']
+
+  def session(self) -> mcp_session.McpSession:
+    """Creates an McpSession from an in-memory FastMCP instance."""
+    return mcp_session.McpSession.from_fastmcp(self.fastmcp)
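The `McpClient` docstring above demonstrates the stdio and HTTP constructors; the in-memory path added by `from_fastmcp` can be exercised the same way. A sketch assuming a FastMCP server whose tool is registered under its function name (`add`), with the client/session calls following the docstring's pattern:

```python
# In-memory MCP client sketch; the FastMCP server and tool name are illustrative.
import langfun as lf
from mcp.server.fastmcp import FastMCP

server = FastMCP('demo')


@server.tool()
def add(x: int, y: int) -> int:
  """Adds two integers."""
  return x + y


client = lf.mcp.McpClient.from_fastmcp(server)
add_tool = client.list_tools()['add']

# Print the generated Python definition of the tool, then call it in a session.
print(add_tool.python_definition())
with client.session() as session:
  print(add_tool(x=1, y=2)(session))
```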