langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_compatible.py
CHANGED

@@ -23,8 +23,18 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class OpenAICompatible(rest.REST):
-  """Base for OpenAI compatible models."""
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base class for models compatible with OpenAI's Chat Completion API.
+
+  This class provides a common interface for language models that adhere to
+  the OpenAI Chat Completion API format, which is used by providers like
+  Groq, DeepSeek, and others. It standardizes request formatting and
+  response parsing for these models.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/chat
+  """
 
   model: Annotated[
       str, 'The name of the model to use.',
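For orientation, a minimal sketch of how the renamed base class is exercised (mirroring the updated tests later in this diff); the endpoint URL and model name are placeholders, not real services:

    import langfun as lf
    from langfun.core.llms import openai_compatible

    # Any Chat Completion-compatible endpoint can be targeted directly;
    # `api_endpoint` and `model` are the two positional init args declared
    # via `lf.use_init_args` above.
    lm = openai_compatible.OpenAIChatCompletionAPI(
        api_endpoint='https://test-server',  # placeholder
        model='test-model',                  # placeholder
    )
    response = lm('hello')  # Returns an lf.AIMessage with the completion text.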
@@ -42,12 +52,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args = dict(
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
@@ -62,6 +74,8 @@ class OpenAICompatible(rest.REST):
       args['seed'] = options.random_seed
     if options.reasoning_effort is not None:
       args['reasoning_effort'] = options.reasoning_effort
+    if options.extras:
+      args.update(options.extras)
     return args
 
   def request(
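The new `extras` passthrough lets callers forward provider-specific fields that Langfun does not model explicitly. A hedged sketch, assuming `extras` is a plain-dict field on `LMSamplingOptions` in this release (the new branch above implies it exists; `logit_bias` here is just a hypothetical provider knob):

    import langfun as lf

    # Any key in `extras` is merged verbatim into the request body by
    # `_request_args`, after the explicitly modeled options.
    options = lf.LMSamplingOptions(
        temperature=0.7,
        extras={'logit_bias': {'50256': -100}},  # hypothetical example
    )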
@@ -72,27 +86,13 @@ class OpenAICompatible(rest.REST):
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
       request_args.update(
           response_format=dict(
               type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
+              json_schema=output_schema,
           )
       )
       prompt.metadata.formatted_text = (
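After this refactoring, structured output is still driven by `json_schema` metadata on the prompt: the schema's `title` becomes the response-format name, and `strict=True` is always set (see `_structure_output_schema` in the next hunk). A sketch mirroring the tests below:

    import langfun as lf

    message = lf.UserMessage(
        'Extract the person.',
        json_schema={
            'type': 'object',
            'properties': {'name': {'type': 'string'}},
            'required': ['name'],
            'title': 'Person',  # Required: becomes the schema's `name`.
        },
    )
    # lm.request(message, ...) will then include:
    #   response_format={'type': 'json_schema',
    #                    'json_schema': {'schema': {...}, 'name': 'Person',
    #                                    'strict': True}}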
@@ -118,17 +118,43 @@ class OpenAICompatible(rest.REST):
       assert isinstance(system_message, lf.SystemMessage), type(system_message)
       messages.append(
           system_message.as_format(
-              'openai', chunk_preprocessor=modality_check
+              'openai_chat_completion_api', chunk_preprocessor=modality_check
           )
       )
     messages.append(
-        prompt.as_format('openai', chunk_preprocessor=modality_check)
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
@@ -144,7 +170,10 @@ class OpenAICompatible(rest.REST):
             for t in choice_logprobs['content']
         ]
     return lf.LMSample(
-        lf.Message.from_value(choice['message'], format='openai'),
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -169,3 +198,95 @@ class OpenAICompatible(rest.REST):
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base class for models compatible with OpenAI's Responses API.
+
+  This class provides a common interface for language models that adhere to
+  the new OpenAI Responses API format. It standardizes request formatting
+  and response parsing for these models, including handling instructions
+  (system messages) and structured outputs.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/responses
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )
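The new `OpenAIResponsesAPI` subclass reuses the Chat Completion request plumbing but emits `input`/`instructions` instead of `messages`, rejects `n != 1` and `logprobs`, and reads usage from `input_tokens`/`output_tokens`. A usage sketch based on the tests added later in this diff (placeholder endpoint and model):

    import langfun as lf
    from langfun.core.llms import openai_compatible

    lm = openai_compatible.OpenAIResponsesAPI(
        api_endpoint='https://test-server',  # placeholder
        model='test-model',                  # placeholder
    )
    # System messages travel as the top-level `instructions` field of the
    # request body rather than as a chat message.
    response = lm(
        lf.UserMessage('hello', system_message=lf.SystemMessage('hi'))
    )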

langfun/core/llms/openai_compatible_test.py
CHANGED

@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
   response_format = ''
 
   choices = []
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     if json.get('logprobs'):
       logprobs = dict(
           content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
       c['image_url']['url']
       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
   return response
 
 
-class OpenAIComptibleTest(unittest.TestCase):
+def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  _ = json['input']
+
+  system_message = ''
+  if 'instructions' in json:
+    system_message = f' system={json["instructions"]}'
+
+  response_format = ''
+  if 'text' in json and 'format' in json['text']:
+    response_format = f' format={json["text"]["format"]["type"]}'
+
+  output = [
+      dict(
+          type='message',
+          content=[
+              dict(
+                  type='output_text',
+                  text=(
+                      f'Sample 0 for message.{system_message}{response_format}'
+                  )
+              )
+          ],
+      )
+  ]
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+def mock_responses_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
+  urls = [
+      c['image_url']
+      for c in json['input'][0]['content']
+      if c['type'] == 'input_image'
+  ]
+  output = [
+      pg.Dict(
+          type='message',
+          content=[
+              pg.Dict(
+                  type='output_text',
+                  text=f'Sample 0 for message: {"".join(urls)}',
+              )
+          ],
+      )
+  ]
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class OpenAIChatCompletionAPITest(unittest.TestCase):
   """Tests for OpenAI compatible language model."""
 
   def test_request_args(self):
     self.assertEqual(
-        openai_compatible.OpenAICompatible(
+        openai_compatible.OpenAIChatCompletionAPI(
             api_endpoint='https://test-server',
             model='test-model'
         )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
         ),
         dict(
             model='test-model',
-            top_logprobs=None,
-            n=1,
             temperature=1.0,
             stop=['\n'],
             seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server', model='test-model',
       )
       self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion_with_logprobs(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server', model='test-model',
       )
       results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
       def mime_type(self) -> str:
         return 'image/png'
 
+    image = FakeImage.from_uri('https://fake/image')
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai_compatible.OpenAICompatible(
+      lm_1 = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server',
           model='test-model1',
       )
-      lm_2 = openai_compatible.OpenAICompatible(
+      lm_2 = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server',
           model='test-model2',
       )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
       self.assertEqual(
           lm(
               lf.UserMessage(
-                  'hello <<[[image]]>>',
-                  image=FakeImage.from_uri('https://fake/image'),
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
               ),
               sampling_options=lf.LMSamplingOptions(n=2)
           ),
           'Sample 0 for message: https://fake/image',
       )
 
-    class TextOnlyModel(openai_compatible.OpenAICompatible):
+    class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
 
       class ModelInfo(lf.ModelInfo):
         input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
       lm_3(
           lf.UserMessage(
-              'hello <<[[image]]>>',
-              image=FakeImage.from_uri('https://fake/image'),
+              f'hello <<[[{image.id}]]>>',
+              referred_modalities=[image],
           ),
       )
 
   def test_sample_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_sample_with_contextual_options(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_json_schema(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
 
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_context_limit_error
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
         lm(lf.UserMessage('hello'))
 
 
+class OpenAIResponsesAPITest(unittest.TestCase):
+  """Tests for OpenAI compatible language model on Responses API."""
+
+  def test_request_args(self):
+    lm = openai_compatible.OpenAIResponsesAPI(
+        api_endpoint='https://test-server', model='test-model'
+    )
+    # Test valid args.
+    self.assertEqual(
+        lm._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='test-model',
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+    # Test unsupported n.
+    with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+      lm._request_args(lf.LMSamplingOptions(n=2))
+
+    # Test unsupported logprobs.
+    with self.assertRaisesRegex(
+        ValueError, 'logprobs is not supported on Responses API.'
+    ):
+      lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+  def test_call_responses(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model',
+      )
+      self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+  def test_call_responses_vision(self):
+    class FakeImage(lf_modalities.Image):
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
+    image = FakeImage.from_uri('https://fake/image')
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request_vision
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model1',
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              )
+          ),
+          'Sample 0 for message: https://fake/image',
+      )
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message=lf.SystemMessage('hi'),
+              )
+          ),
+          'Sample 0 for message. system=hi',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  },
+              )
+          ),
+          'Sample 0 for message. format=json_schema',
+      )
+
+      # Test bad json schema.
+      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+        lm(lf.UserMessage('hello', json_schema='foo'))
+
+      with self.assertRaisesRegex(
+          ValueError, 'The root of `json_schema` must have a `title` field'
+      ):
+        lm(lf.UserMessage('hello', json_schema={}))
+
+
 if __name__ == '__main__':
   unittest.main()

langfun/core/llms/openai_test.py
CHANGED

langfun/core/llms/rest.py
CHANGED

@@ -22,7 +22,18 @@ import requests
 
 
 class REST(lf.LanguageModel):
-  """REST service-based language models."""
+  """Base class for language models accessed via REST APIs.
+
+  The `REST` class provides a foundation for implementing language models
+  that are accessed through RESTful endpoints. It handles the details of
+  making HTTP requests, managing sessions, and handling common errors like
+  timeouts and connection issues.
+
+  Subclasses need to implement the `request` and `result` methods to
+  convert Langfun messages to API-specific request formats and to parse
+  API responses back into `LMSamplingResult` objects. They also need to
+  provide the `api_endpoint` and can override `headers` for authentication.
+  """
 
   api_endpoint: Annotated[
       str,
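To make the contract in the new docstring concrete, a minimal hypothetical subclass might look like the following; the `request`/`result` signatures are taken from the `OpenAIChatCompletionAPI` implementation earlier in this diff, while `EchoModel` and its endpoint are invented for illustration:

    from typing import Any
    import langfun as lf
    from langfun.core.llms import rest

    class EchoModel(rest.REST):
      """Hypothetical REST model that treats the response body as the sample."""

      def request(
          self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
      ) -> dict[str, Any]:
        # Convert the Langfun message into the service's JSON request body.
        return {'prompt': prompt.text}

      def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
        # Parse the service's JSON response back into Langfun samples.
        return lf.LMSamplingResult(
            samples=[lf.LMSample(lf.AIMessage(json['text']), score=0.0)]
        )

    # `api_endpoint` is supplied at construction time, as in the tests above.
    lm = EchoModel(api_endpoint='https://example.invalid/v1/echo')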
@@ -98,7 +109,9 @@ class REST(lf.LanguageModel):
       raise lf.TemporaryLMError(str(e)) from e
     except (
         requests.exceptions.ConnectionError,
+        requests.exceptions.ChunkedEncodingError,
         ConnectionError,
+        ConnectionResetError,
     ) as e:
       error_message = str(e)
       if 'REJECTED_CLIENT_THROTTLED' in error_message:
@@ -107,6 +120,8 @@ class REST(lf.LanguageModel):
       raise lf.TemporaryLMError(error_message) from e
     if 'UNREACHABLE_ERROR' in error_message:
       raise lf.TemporaryLMError(error_message) from e
+    if 'Connection reset by peer' in error_message:
+      raise lf.TemporaryLMError(error_message) from e
     raise lf.LMError(error_message) from e
 
   def _error(self, status_code: int, content: str) -> lf.LMError: