langfun 0.1.2.dev202510200805__py3-none-any.whl → 0.1.2.dev202511160804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/__init__.py +1 -0
- langfun/core/agentic/action.py +107 -12
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +25 -0
- langfun/core/async_support.py +32 -3
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +1 -0
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +9 -2
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +1 -0
- langfun/core/eval/v2/checkpointing.py +39 -5
- langfun/core/eval/v2/checkpointing_test.py +1 -1
- langfun/core/eval/v2/eval_test_helper.py +97 -1
- langfun/core/eval/v2/evaluation.py +88 -16
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +45 -39
- langfun/core/eval/v2/example_test.py +3 -3
- langfun/core/eval/v2/experiment.py +51 -8
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +30 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +20 -6
- langfun/core/eval/v2/runners/__init__.py +26 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +22 -124
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +79 -0
- langfun/core/eval/v2/runners/parallel.py +100 -0
- langfun/core/eval/v2/runners/parallel_test.py +98 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +175 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +103 -16
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +7 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +14 -9
- langfun/core/llms/google_genai.py +29 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +36 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +12 -1
- langfun/core/llms/vertexai.py +51 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/client.py +77 -22
- langfun/core/mcp/client_test.py +8 -35
- langfun/core/mcp/session.py +94 -29
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/tool.py +151 -22
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +19 -1
- langfun/core/modalities/mime.py +62 -3
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +215 -142
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/structured/schema/__init__.py +48 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +8 -2
- langfun/env/base_environment.py +320 -128
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +92 -15
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +84 -361
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +1 -1
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +95 -98
- langfun/env/event_handlers/event_logger_test.py +21 -21
- langfun/env/event_handlers/metric_writer.py +225 -140
- langfun/env/event_handlers/metric_writer_test.py +23 -6
- langfun/env/interface.py +854 -40
- langfun/env/interface_test.py +112 -2
- langfun/env/load_balancers_test.py +23 -2
- langfun/env/test_utils.py +126 -84
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/METADATA +1 -1
- langfun-0.1.2.dev202511160804.dist-info/RECORD +211 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun/env/base_test.py +0 -1481
- langfun/env/event_handlers/base.py +0 -350
- langfun-0.1.2.dev202510200805.dist-info/RECORD +0 -195
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/top_level.txt +0 -0
|
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
|
|
|
38
38
|
response_format = ''
|
|
39
39
|
|
|
40
40
|
choices = []
|
|
41
|
-
for k in range(json
|
|
41
|
+
for k in range(json.get('n', 1)):
|
|
42
42
|
if json.get('logprobs'):
|
|
43
43
|
logprobs = dict(
|
|
44
44
|
content=[
|
|
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
|
|
|
89
89
|
c['image_url']['url']
|
|
90
90
|
for c in json['messages'][0]['content'] if c['type'] == 'image_url'
|
|
91
91
|
]
|
|
92
|
-
for k in range(json
|
|
92
|
+
for k in range(json.get('n', 1)):
|
|
93
93
|
choices.append(pg.Dict(
|
|
94
94
|
message=pg.Dict(
|
|
95
95
|
content=f'Sample {k} for message: {"".join(urls)}'
|
|
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
|
|
|
111
111
|
return response
|
|
112
112
|
|
|
113
113
|
|
|
114
|
-
|
|
114
|
+
def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
|
|
115
|
+
del url, kwargs
|
|
116
|
+
_ = json['input']
|
|
117
|
+
|
|
118
|
+
system_message = ''
|
|
119
|
+
if 'instructions' in json:
|
|
120
|
+
system_message = f' system={json["instructions"]}'
|
|
121
|
+
|
|
122
|
+
response_format = ''
|
|
123
|
+
if 'text' in json and 'format' in json['text']:
|
|
124
|
+
response_format = f' format={json["text"]["format"]["type"]}'
|
|
125
|
+
|
|
126
|
+
output = [
|
|
127
|
+
dict(
|
|
128
|
+
type='message',
|
|
129
|
+
content=[
|
|
130
|
+
dict(
|
|
131
|
+
type='output_text',
|
|
132
|
+
text=(
|
|
133
|
+
f'Sample 0 for message.{system_message}{response_format}'
|
|
134
|
+
)
|
|
135
|
+
)
|
|
136
|
+
],
|
|
137
|
+
)
|
|
138
|
+
]
|
|
139
|
+
|
|
140
|
+
response = requests.Response()
|
|
141
|
+
response.status_code = 200
|
|
142
|
+
response._content = pg.to_json_str(
|
|
143
|
+
dict(
|
|
144
|
+
output=output,
|
|
145
|
+
usage=dict(
|
|
146
|
+
input_tokens=100,
|
|
147
|
+
output_tokens=100,
|
|
148
|
+
total_tokens=200,
|
|
149
|
+
),
|
|
150
|
+
)
|
|
151
|
+
).encode()
|
|
152
|
+
return response
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def mock_responses_request_vision(
|
|
156
|
+
url: str, json: dict[str, Any], **kwargs
|
|
157
|
+
):
|
|
158
|
+
del url, kwargs
|
|
159
|
+
urls = [
|
|
160
|
+
c['image_url']
|
|
161
|
+
for c in json['input'][0]['content']
|
|
162
|
+
if c['type'] == 'input_image'
|
|
163
|
+
]
|
|
164
|
+
output = [
|
|
165
|
+
pg.Dict(
|
|
166
|
+
type='message',
|
|
167
|
+
content=[
|
|
168
|
+
pg.Dict(
|
|
169
|
+
type='output_text',
|
|
170
|
+
text=f'Sample 0 for message: {"".join(urls)}',
|
|
171
|
+
)
|
|
172
|
+
],
|
|
173
|
+
)
|
|
174
|
+
]
|
|
175
|
+
response = requests.Response()
|
|
176
|
+
response.status_code = 200
|
|
177
|
+
response._content = pg.to_json_str(
|
|
178
|
+
dict(
|
|
179
|
+
output=output,
|
|
180
|
+
usage=dict(
|
|
181
|
+
input_tokens=100,
|
|
182
|
+
output_tokens=100,
|
|
183
|
+
total_tokens=200,
|
|
184
|
+
),
|
|
185
|
+
)
|
|
186
|
+
).encode()
|
|
187
|
+
return response
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class OpenAIChatCompletionAPITest(unittest.TestCase):
|
|
115
191
|
"""Tests for OpenAI compatible language model."""
|
|
116
192
|
|
|
117
193
|
def test_request_args(self):
|
|
118
194
|
self.assertEqual(
|
|
119
|
-
openai_compatible.
|
|
195
|
+
openai_compatible.OpenAIChatCompletionAPI(
|
|
120
196
|
api_endpoint='https://test-server',
|
|
121
197
|
model='test-model'
|
|
122
198
|
)._request_args(
|
|
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
126
202
|
),
|
|
127
203
|
dict(
|
|
128
204
|
model='test-model',
|
|
129
|
-
top_logprobs=None,
|
|
130
|
-
n=1,
|
|
131
205
|
temperature=1.0,
|
|
132
206
|
stop=['\n'],
|
|
133
207
|
seed=123,
|
|
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
137
211
|
def test_call_chat_completion(self):
|
|
138
212
|
with mock.patch('requests.Session.post') as mock_request:
|
|
139
213
|
mock_request.side_effect = mock_chat_completion_request
|
|
140
|
-
lm = openai_compatible.
|
|
214
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
141
215
|
api_endpoint='https://test-server', model='test-model',
|
|
142
216
|
)
|
|
143
217
|
self.assertEqual(
|
|
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
148
222
|
def test_call_chat_completion_with_logprobs(self):
|
|
149
223
|
with mock.patch('requests.Session.post') as mock_request:
|
|
150
224
|
mock_request.side_effect = mock_chat_completion_request
|
|
151
|
-
lm = openai_compatible.
|
|
225
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
152
226
|
api_endpoint='https://test-server', model='test-model',
|
|
153
227
|
)
|
|
154
228
|
results = lm.sample(['hello'], logprobs=True)
|
|
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
214
288
|
def mime_type(self) -> str:
|
|
215
289
|
return 'image/png'
|
|
216
290
|
|
|
291
|
+
image = FakeImage.from_uri('https://fake/image')
|
|
217
292
|
with mock.patch('requests.Session.post') as mock_request:
|
|
218
293
|
mock_request.side_effect = mock_chat_completion_request_vision
|
|
219
|
-
lm_1 = openai_compatible.
|
|
294
|
+
lm_1 = openai_compatible.OpenAIChatCompletionAPI(
|
|
220
295
|
api_endpoint='https://test-server',
|
|
221
296
|
model='test-model1',
|
|
222
297
|
)
|
|
223
|
-
lm_2 = openai_compatible.
|
|
298
|
+
lm_2 = openai_compatible.OpenAIChatCompletionAPI(
|
|
224
299
|
api_endpoint='https://test-server',
|
|
225
300
|
model='test-model2',
|
|
226
301
|
)
|
|
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
228
303
|
self.assertEqual(
|
|
229
304
|
lm(
|
|
230
305
|
lf.UserMessage(
|
|
231
|
-
'hello <<[[image]]>>',
|
|
232
|
-
|
|
306
|
+
f'hello <<[[{image.id}]]>>',
|
|
307
|
+
referred_modalities=[image],
|
|
233
308
|
),
|
|
234
309
|
sampling_options=lf.LMSamplingOptions(n=2)
|
|
235
310
|
),
|
|
236
311
|
'Sample 0 for message: https://fake/image',
|
|
237
312
|
)
|
|
238
313
|
|
|
239
|
-
class TextOnlyModel(openai_compatible.
|
|
314
|
+
class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
|
|
240
315
|
|
|
241
316
|
class ModelInfo(lf.ModelInfo):
|
|
242
317
|
input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
|
|
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
251
326
|
with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
|
|
252
327
|
lm_3(
|
|
253
328
|
lf.UserMessage(
|
|
254
|
-
'hello <<[[image]]>>',
|
|
255
|
-
|
|
329
|
+
f'hello <<[[{image.id}]]>>',
|
|
330
|
+
referred_modalities=[image],
|
|
256
331
|
),
|
|
257
332
|
)
|
|
258
333
|
|
|
259
334
|
def test_sample_chat_completion(self):
|
|
260
335
|
with mock.patch('requests.Session.post') as mock_request:
|
|
261
336
|
mock_request.side_effect = mock_chat_completion_request
|
|
262
|
-
lm = openai_compatible.
|
|
337
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
263
338
|
api_endpoint='https://test-server', model='test-model'
|
|
264
339
|
)
|
|
265
340
|
results = lm.sample(
|
|
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
400
475
|
def test_sample_with_contextual_options(self):
|
|
401
476
|
with mock.patch('requests.Session.post') as mock_request:
|
|
402
477
|
mock_request.side_effect = mock_chat_completion_request
|
|
403
|
-
lm = openai_compatible.
|
|
478
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
404
479
|
api_endpoint='https://test-server', model='test-model'
|
|
405
480
|
)
|
|
406
481
|
with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
|
|
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
458
533
|
def test_call_with_system_message(self):
|
|
459
534
|
with mock.patch('requests.Session.post') as mock_request:
|
|
460
535
|
mock_request.side_effect = mock_chat_completion_request
|
|
461
|
-
lm = openai_compatible.
|
|
536
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
462
537
|
api_endpoint='https://test-server', model='test-model'
|
|
463
538
|
)
|
|
464
539
|
self.assertEqual(
|
|
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
475
550
|
def test_call_with_json_schema(self):
|
|
476
551
|
with mock.patch('requests.Session.post') as mock_request:
|
|
477
552
|
mock_request.side_effect = mock_chat_completion_request
|
|
478
|
-
lm = openai_compatible.
|
|
553
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
479
554
|
api_endpoint='https://test-server', model='test-model'
|
|
480
555
|
)
|
|
481
556
|
self.assertEqual(
|
|
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
515
590
|
|
|
516
591
|
with mock.patch('requests.Session.post') as mock_request:
|
|
517
592
|
mock_request.side_effect = mock_context_limit_error
|
|
518
|
-
lm = openai_compatible.
|
|
593
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
519
594
|
api_endpoint='https://test-server', model='test-model'
|
|
520
595
|
)
|
|
521
596
|
with self.assertRaisesRegex(
|
|
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
524
599
|
lm(lf.UserMessage('hello'))
|
|
525
600
|
|
|
526
601
|
|
|
602
|
+
class OpenAIResponsesAPITest(unittest.TestCase):
|
|
603
|
+
"""Tests for OpenAI compatible language model on Responses API."""
|
|
604
|
+
|
|
605
|
+
def test_request_args(self):
|
|
606
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
607
|
+
api_endpoint='https://test-server', model='test-model'
|
|
608
|
+
)
|
|
609
|
+
# Test valid args.
|
|
610
|
+
self.assertEqual(
|
|
611
|
+
lm._request_args(
|
|
612
|
+
lf.LMSamplingOptions(
|
|
613
|
+
temperature=1.0, stop=['\n'], n=1, random_seed=123
|
|
614
|
+
)
|
|
615
|
+
),
|
|
616
|
+
dict(
|
|
617
|
+
model='test-model',
|
|
618
|
+
temperature=1.0,
|
|
619
|
+
stop=['\n'],
|
|
620
|
+
seed=123,
|
|
621
|
+
),
|
|
622
|
+
)
|
|
623
|
+
# Test unsupported n.
|
|
624
|
+
with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
|
|
625
|
+
lm._request_args(lf.LMSamplingOptions(n=2))
|
|
626
|
+
|
|
627
|
+
# Test unsupported logprobs.
|
|
628
|
+
with self.assertRaisesRegex(
|
|
629
|
+
ValueError, 'logprobs is not supported on Responses API.'
|
|
630
|
+
):
|
|
631
|
+
lm._request_args(lf.LMSamplingOptions(logprobs=True))
|
|
632
|
+
|
|
633
|
+
def test_call_responses(self):
|
|
634
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
635
|
+
mock_request.side_effect = mock_responses_request
|
|
636
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
637
|
+
api_endpoint='https://test-server',
|
|
638
|
+
model='test-model',
|
|
639
|
+
)
|
|
640
|
+
self.assertEqual(lm('hello'), 'Sample 0 for message.')
|
|
641
|
+
|
|
642
|
+
def test_call_responses_vision(self):
|
|
643
|
+
class FakeImage(lf_modalities.Image):
|
|
644
|
+
@property
|
|
645
|
+
def mime_type(self) -> str:
|
|
646
|
+
return 'image/png'
|
|
647
|
+
|
|
648
|
+
image = FakeImage.from_uri('https://fake/image')
|
|
649
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
650
|
+
mock_request.side_effect = mock_responses_request_vision
|
|
651
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
652
|
+
api_endpoint='https://test-server',
|
|
653
|
+
model='test-model1',
|
|
654
|
+
)
|
|
655
|
+
self.assertEqual(
|
|
656
|
+
lm(
|
|
657
|
+
lf.UserMessage(
|
|
658
|
+
f'hello <<[[{image.id}]]>>',
|
|
659
|
+
referred_modalities=[image],
|
|
660
|
+
)
|
|
661
|
+
),
|
|
662
|
+
'Sample 0 for message: https://fake/image',
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
def test_call_with_system_message(self):
|
|
666
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
667
|
+
mock_request.side_effect = mock_responses_request
|
|
668
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
669
|
+
api_endpoint='https://test-server', model='test-model'
|
|
670
|
+
)
|
|
671
|
+
self.assertEqual(
|
|
672
|
+
lm(
|
|
673
|
+
lf.UserMessage(
|
|
674
|
+
'hello',
|
|
675
|
+
system_message=lf.SystemMessage('hi'),
|
|
676
|
+
)
|
|
677
|
+
),
|
|
678
|
+
'Sample 0 for message. system=hi',
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
def test_call_with_json_schema(self):
|
|
682
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
683
|
+
mock_request.side_effect = mock_responses_request
|
|
684
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
685
|
+
api_endpoint='https://test-server', model='test-model'
|
|
686
|
+
)
|
|
687
|
+
self.assertEqual(
|
|
688
|
+
lm(
|
|
689
|
+
lf.UserMessage(
|
|
690
|
+
'hello',
|
|
691
|
+
json_schema={
|
|
692
|
+
'type': 'object',
|
|
693
|
+
'properties': {
|
|
694
|
+
'name': {'type': 'string'},
|
|
695
|
+
},
|
|
696
|
+
'required': ['name'],
|
|
697
|
+
'title': 'Person',
|
|
698
|
+
},
|
|
699
|
+
)
|
|
700
|
+
),
|
|
701
|
+
'Sample 0 for message. format=json_schema',
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
# Test bad json schema.
|
|
705
|
+
with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
|
|
706
|
+
lm(lf.UserMessage('hello', json_schema='foo'))
|
|
707
|
+
|
|
708
|
+
with self.assertRaisesRegex(
|
|
709
|
+
ValueError, 'The root of `json_schema` must have a `title` field'
|
|
710
|
+
):
|
|
711
|
+
lm(lf.UserMessage('hello', json_schema={}))
|
|
712
|
+
|
|
713
|
+
|
|
527
714
|
if __name__ == '__main__':
|
|
528
715
|
unittest.main()
|
langfun/core/llms/openai_test.py
CHANGED
langfun/core/llms/rest.py
CHANGED
|
@@ -22,7 +22,18 @@ import requests
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class REST(lf.LanguageModel):
|
|
25
|
-
"""
|
|
25
|
+
"""Base class for language models accessed via REST APIs.
|
|
26
|
+
|
|
27
|
+
The `REST` class provides a foundation for implementing language models
|
|
28
|
+
that are accessed through RESTful endpoints. It handles the details of
|
|
29
|
+
making HTTP requests, managing sessions, and handling common errors like
|
|
30
|
+
timeouts and connection issues.
|
|
31
|
+
|
|
32
|
+
Subclasses need to implement the `request` and `result` methods to
|
|
33
|
+
convert Langfun messages to API-specific request formats and to parse
|
|
34
|
+
API responses back into `LMSamplingResult` objects. They also need to
|
|
35
|
+
provide the `api_endpoint` and can override `headers` for authentication.
|
|
36
|
+
"""
|
|
26
37
|
|
|
27
38
|
api_endpoint: Annotated[
|
|
28
39
|
str,
|
langfun/core/llms/vertexai.py
CHANGED
|
@@ -44,13 +44,32 @@ except ImportError:
|
|
|
44
44
|
|
|
45
45
|
@pg.use_init_args(['api_endpoint'])
|
|
46
46
|
class VertexAI(rest.REST):
|
|
47
|
-
"""Base class for
|
|
47
|
+
"""Base class for models served on Vertex AI.
|
|
48
48
|
|
|
49
|
-
This class handles
|
|
50
|
-
|
|
51
|
-
|
|
49
|
+
This class handles authentication for Vertex AI models. Subclasses,
|
|
50
|
+
such as `VertexAIGemini`, `VertexAIAnthropic`, and `VertexAILlama`,
|
|
51
|
+
provide specific implementations for different model families hosted
|
|
52
|
+
on Vertex AI.
|
|
52
53
|
|
|
53
|
-
|
|
54
|
+
**Quick Start:**
|
|
55
|
+
|
|
56
|
+
If you are using Langfun from a Google Cloud environment (e.g., GCE, GKE)
|
|
57
|
+
that has service account credentials, authentication is handled automatically.
|
|
58
|
+
Otherwise, you might need to set up credentials:
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
gcloud auth application-default login
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
Then you can use a Vertex AI model:
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
import langfun as lf
|
|
68
|
+
|
|
69
|
+
lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
|
|
70
|
+
r = lm('Who are you?')
|
|
71
|
+
print(r)
|
|
72
|
+
```
|
|
54
73
|
"""
|
|
55
74
|
|
|
56
75
|
model: pg.typing.Annotated[
|
|
@@ -158,7 +177,21 @@ class VertexAI(rest.REST):
|
|
|
158
177
|
@pg.use_init_args(['model'])
|
|
159
178
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
|
160
179
|
class VertexAIGemini(VertexAI, gemini.Gemini):
|
|
161
|
-
"""Gemini models served
|
|
180
|
+
"""Gemini models served on Vertex AI.
|
|
181
|
+
|
|
182
|
+
**Quick Start:**
|
|
183
|
+
|
|
184
|
+
```python
|
|
185
|
+
import langfun as lf
|
|
186
|
+
|
|
187
|
+
# Call Gemini 1.5 Flash on Vertex AI.
|
|
188
|
+
# If project and location are not specified, they will be read from
|
|
189
|
+
# environment variables 'VERTEXAI_PROJECT' and 'VERTEXAI_LOCATION'.
|
|
190
|
+
lm = lf.llms.VertexAIGemini25Flash(project='my-project', location='global')
|
|
191
|
+
r = lm('Who are you?')
|
|
192
|
+
print(r)
|
|
193
|
+
```
|
|
194
|
+
"""
|
|
162
195
|
|
|
163
196
|
# Set default location to us-central1.
|
|
164
197
|
location = 'us-central1'
|
|
@@ -369,6 +402,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
|
|
|
369
402
|
# pylint: disable=invalid-name
|
|
370
403
|
|
|
371
404
|
|
|
405
|
+
class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
|
|
406
|
+
"""Anthropic's Claude 4.5 Haiku model on VertexAI."""
|
|
407
|
+
model = 'claude-haiku-4-5@20251001'
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
|
|
411
|
+
"""Anthropic's Claude 4.5 Sonnet model on VertexAI."""
|
|
412
|
+
model = 'claude-sonnet-4-5@20250929'
|
|
413
|
+
|
|
414
|
+
|
|
372
415
|
class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
|
|
373
416
|
"""Anthropic's Claude 4 Opus model on VertexAI."""
|
|
374
417
|
model = 'claude-opus-4@20250514'
|
|
@@ -487,7 +530,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
|
|
|
487
530
|
|
|
488
531
|
@pg.use_init_args(['model'])
|
|
489
532
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
|
490
|
-
class VertexAILlama(VertexAI, openai_compatible.
|
|
533
|
+
class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
|
|
491
534
|
"""Llama models on VertexAI."""
|
|
492
535
|
|
|
493
536
|
model: pg.typing.Annotated[
|
|
@@ -600,7 +643,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
|
|
|
600
643
|
|
|
601
644
|
@pg.use_init_args(['model'])
|
|
602
645
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
|
603
|
-
class VertexAIMistral(VertexAI, openai_compatible.
|
|
646
|
+
class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
|
|
604
647
|
"""Mistral AI models on VertexAI."""
|
|
605
648
|
|
|
606
649
|
model: pg.typing.Annotated[
|
langfun/core/logging.py
CHANGED
|
@@ -310,7 +310,7 @@ def warning(
|
|
|
310
310
|
console: bool = False,
|
|
311
311
|
**kwargs
|
|
312
312
|
) -> LogEntry:
|
|
313
|
-
"""Logs
|
|
313
|
+
"""Logs a warning message to the session."""
|
|
314
314
|
return log('warning', message, indent=indent, console=console, **kwargs)
|
|
315
315
|
|
|
316
316
|
|
langfun/core/mcp/client.py
CHANGED
|
@@ -23,33 +23,53 @@ import pyglove as pg
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class McpClient(pg.Object):
|
|
26
|
-
"""
|
|
26
|
+
"""Interface for Model Context Protocol (MCP) client.
|
|
27
27
|
|
|
28
|
-
|
|
28
|
+
An MCP client serves as a bridge to an MCP server, enabling users to interact
|
|
29
|
+
with tools hosted on the server. It provides methods for listing available
|
|
30
|
+
tools and creating sessions for tool interaction.
|
|
31
|
+
|
|
32
|
+
There are three types of MCP clients:
|
|
33
|
+
|
|
34
|
+
* **Stdio-based client**: Ideal for interacting with tools exposed as
|
|
35
|
+
command-line executables through stdin/stdout.
|
|
36
|
+
Created by `lf.mcp.McpClient.from_command`.
|
|
37
|
+
* **HTTP-based client**: Designed for tools accessible via HTTP,
|
|
38
|
+
supporting Server-Sent Events (SSE) for streaming.
|
|
39
|
+
Created by `lf.mcp.McpClient.from_url`.
|
|
40
|
+
* **In-memory client**: Useful for testing or embedding MCP servers
|
|
41
|
+
within the same process.
|
|
42
|
+
Created by `lf.mcp.McpClient.from_fastmcp`.
|
|
43
|
+
|
|
44
|
+
**Example Usage:**
|
|
29
45
|
|
|
30
46
|
```python
|
|
47
|
+
import langfun as lf
|
|
31
48
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
49
|
+
# Example 1: Stdio-based client
|
|
50
|
+
client = lf.mcp.McpClient.from_command('<MCP_CMD>', ['<ARG1>', 'ARG2'])
|
|
51
|
+
tools = client.list_tools()
|
|
52
|
+
tool_cls = tools['<TOOL_NAME>']
|
|
36
53
|
|
|
37
|
-
|
|
38
|
-
|
|
54
|
+
# Print the Python definition of the tool.
|
|
55
|
+
print(tool_cls.python_definition())
|
|
39
56
|
|
|
40
|
-
|
|
41
|
-
|
|
57
|
+
with client.session() as session:
|
|
58
|
+
result = tool_cls(x=1, y=2)(session)
|
|
59
|
+
print(result)
|
|
42
60
|
|
|
43
|
-
|
|
61
|
+
# Example 2: HTTP-based client (async)
|
|
62
|
+
async def main():
|
|
44
63
|
client = lf.mcp.McpClient.from_url('http://localhost:8000/mcp')
|
|
45
64
|
tools = client.list_tools()
|
|
46
65
|
tool_cls = tools['<TOOL_NAME>']
|
|
47
66
|
|
|
48
|
-
# Print the
|
|
67
|
+
# Print the Python definition of the tool.
|
|
49
68
|
print(tool_cls.python_definition())
|
|
50
69
|
|
|
51
70
|
async with client.session() as session:
|
|
52
|
-
|
|
71
|
+
result = await tool_cls(x=1, y=2).acall(session)
|
|
72
|
+
print(result)
|
|
53
73
|
```
|
|
54
74
|
"""
|
|
55
75
|
|
|
@@ -60,7 +80,15 @@ class McpClient(pg.Object):
|
|
|
60
80
|
def list_tools(
|
|
61
81
|
self, refresh: bool = False
|
|
62
82
|
) -> dict[str, Type[mcp_tool.McpTool]]:
|
|
63
|
-
"""Lists all MCP
|
|
83
|
+
"""Lists all available tools on the MCP server.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
refresh: If True, forces a refresh of the tool list from the server.
|
|
87
|
+
Otherwise, a cached list may be returned.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
A dictionary mapping tool names to their corresponding `McpTool` classes.
|
|
91
|
+
"""
|
|
64
92
|
if self._tools is None or refresh:
|
|
65
93
|
with self.session() as session:
|
|
66
94
|
self._tools = session.list_tools()
|
|
@@ -68,11 +96,23 @@ class McpClient(pg.Object):
|
|
|
68
96
|
|
|
69
97
|
@abc.abstractmethod
|
|
70
98
|
def session(self) -> mcp_session.McpSession:
|
|
71
|
-
"""Creates a MCP
|
|
99
|
+
"""Creates a new session for interacting with MCP tools.
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
An `McpSession` object.
|
|
103
|
+
"""
|
|
72
104
|
|
|
73
105
|
@classmethod
|
|
74
106
|
def from_command(cls, command: str, args: list[str]) -> 'McpClient':
|
|
75
|
-
"""Creates
|
|
107
|
+
"""Creates an MCP client from a command-line executable.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
command: The command to execute.
|
|
111
|
+
args: A list of arguments to pass to the command.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
A `McpClient` instance that communicates via stdin/stdout.
|
|
115
|
+
"""
|
|
76
116
|
return _StdioMcpClient(command=command, args=args)
|
|
77
117
|
|
|
78
118
|
@classmethod
|
|
@@ -81,12 +121,27 @@ class McpClient(pg.Object):
|
|
|
81
121
|
url: str,
|
|
82
122
|
headers: dict[str, str] | None = None
|
|
83
123
|
) -> 'McpClient':
|
|
84
|
-
"""Creates
|
|
124
|
+
"""Creates an MCP client from an HTTP URL.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
url: The URL of the MCP server.
|
|
128
|
+
headers: An optional dictionary of HTTP headers to include in requests.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
A `McpClient` instance that communicates via HTTP.
|
|
132
|
+
"""
|
|
85
133
|
return _HttpMcpClient(url=url, headers=headers or {})
|
|
86
134
|
|
|
87
135
|
@classmethod
|
|
88
136
|
def from_fastmcp(cls, fastmcp: fastmcp_lib.FastMCP) -> 'McpClient':
|
|
89
|
-
"""Creates
|
|
137
|
+
"""Creates an MCP client from an in-memory FastMCP instance.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
fastmcp: An instance of `fastmcp_lib.FastMCP`.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
A `McpClient` instance that communicates with the in-memory server.
|
|
144
|
+
"""
|
|
90
145
|
return _InMemoryFastMcpClient(fastmcp=fastmcp)
|
|
91
146
|
|
|
92
147
|
|
|
@@ -97,18 +152,18 @@ class _StdioMcpClient(McpClient):
|
|
|
97
152
|
args: Annotated[list[str], 'Arguments to pass to the command.']
|
|
98
153
|
|
|
99
154
|
def session(self) -> mcp_session.McpSession:
|
|
100
|
-
"""Creates
|
|
155
|
+
"""Creates an McpSession from command."""
|
|
101
156
|
return mcp_session.McpSession.from_command(self.command, self.args)
|
|
102
157
|
|
|
103
158
|
|
|
104
159
|
class _HttpMcpClient(McpClient):
|
|
105
|
-
"""
|
|
160
|
+
"""HTTP-based MCP client."""
|
|
106
161
|
|
|
107
162
|
url: Annotated[str, 'URL to connect to.']
|
|
108
163
|
headers: Annotated[dict[str, str], 'Headers to send with the request.'] = {}
|
|
109
164
|
|
|
110
165
|
def session(self) -> mcp_session.McpSession:
|
|
111
|
-
"""Creates
|
|
166
|
+
"""Creates an McpSession from URL."""
|
|
112
167
|
return mcp_session.McpSession.from_url(self.url, self.headers)
|
|
113
168
|
|
|
114
169
|
|
|
@@ -118,5 +173,5 @@ class _InMemoryFastMcpClient(McpClient):
|
|
|
118
173
|
fastmcp: Annotated[fastmcp_lib.FastMCP, 'MCP server to connect to.']
|
|
119
174
|
|
|
120
175
|
def session(self) -> mcp_session.McpSession:
|
|
121
|
-
"""Creates
|
|
176
|
+
"""Creates an McpSession from an in-memory FastMCP instance."""
|
|
122
177
|
return mcp_session.McpSession.from_fastmcp(self.fastmcp)
|