langfun 0.1.2.dev202510240805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/concurrent_test.py +1 -0
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +134 -30
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/langfunc_test.py +4 -2
- langfun/core/language_model.py +6 -6
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +2 -1
- langfun/core/llms/cache/base.py +3 -1
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/deepseek.py +1 -1
- langfun/core/llms/groq.py +1 -1
- langfun/core/llms/llama_cpp.py +1 -1
- langfun/core/llms/openai.py +7 -2
- langfun/core/llms/openai_compatible.py +134 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/vertexai.py +2 -2
- langfun/core/message.py +78 -44
- langfun/core/message_test.py +56 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/mime.py +9 -0
- langfun/core/modality.py +104 -27
- langfun/core/modality_test.py +42 -12
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +2 -7
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/mapping.py +4 -13
- langfun/core/structured/querying.py +13 -11
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/template.py +39 -13
- langfun/core/template_test.py +83 -17
- langfun/env/event_handlers/metric_writer_test.py +3 -3
- langfun/env/load_balancers_test.py +2 -2
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +41 -41
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/llms/cache/in_memory_test.py CHANGED

@@ -175,18 +175,28 @@ class InMemoryLMCacheTest(unittest.TestCase):
 
     cache = in_memory.InMemory()
     lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
-
-
+    image_foo = CustomModality('foo')
+    image_bar = CustomModality('bar')
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_foo.id}]]>>', referred_modalities=[image_foo]
+        )
+    )
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_bar.id}]]>>', referred_modalities=[image_bar]
+        )
+    )
     self.assertEqual(
         list(cache.keys()),
         [
            (
-                'hi <<[[
+                f'hi <<[[{image_foo.id}]]>>',
                 (None, None, 1, 40, None, None),
                 0,
            ),
            (
-                'hi <<[[
+                f'hi <<[[{image_bar.id}]]>>',
                 (None, None, 1, 40, None, None),
                 0,
            ),
langfun/core/llms/deepseek.py CHANGED

@@ -93,7 +93,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 # DeepSeek API uses an API format compatible with OpenAI.
 # Reference: https://api-docs.deepseek.com/
 @lf.use_init_args(['model'])
-class DeepSeek(openai_compatible.
+class DeepSeek(openai_compatible.OpenAIChatCompletionAPI):
   """DeepSeek model."""
 
   model: pg.typing.Annotated[
langfun/core/llms/groq.py CHANGED

@@ -259,7 +259,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class Groq(openai_compatible.
+class Groq(openai_compatible.OpenAIChatCompletionAPI):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
langfun/core/llms/llama_cpp.py CHANGED

@@ -20,7 +20,7 @@ import pyglove as pg
 
 @pg.use_init_args(['url', 'model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class LlamaCppRemote(openai_compatible.
+class LlamaCppRemote(openai_compatible.OpenAIChatCompletionAPI):
   """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
langfun/core/llms/openai.py CHANGED

@@ -1031,7 +1031,7 @@ _SUPPORTED_MODELS_BY_MODEL_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class OpenAI(openai_compatible.
+class OpenAI(openai_compatible.OpenAIResponsesAPI):
   """OpenAI model."""
 
   model: pg.typing.Annotated[

@@ -1041,7 +1041,12 @@ class OpenAI(openai_compatible.OpenAICompatible):
       'The name of the model to use.',
   ]
 
-
+  # Disable message storage by default.
+  sampling_options = lf.LMSamplingOptions(
+      extras={'store': False}
+  )
+
+  api_endpoint: str = 'https://api.openai.com/v1/responses'
 
   api_key: Annotated[
       str | None,
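
With the change above, `OpenAI` now builds on the Responses API base class: requests go to `https://api.openai.com/v1/responses` and `store=False` is sent by default. A minimal usage sketch under those assumptions (the model name is a placeholder; provide `OPENAI_API_KEY` in the environment or pass `api_key`):

import langfun as lf

lm = lf.llms.OpenAI(model='gpt-4o')  # placeholder model id
print(lm('hello'))

# The Responses API base class rejects options it cannot express, so calls
# with n > 1 or logprobs=True now raise ValueError for this class.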
langfun/core/llms/openai_compatible.py CHANGED

@@ -23,8 +23,13 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class
-  """Base for OpenAI compatible models.
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base for OpenAI compatible models based on ChatCompletion API.
+
+  See https://platform.openai.com/docs/api-reference/chat
+  As of 2025-10-23, OpenAI is migrating from ChatCompletion API to Responses
+  API.
+  """
 
   model: Annotated[
     str, 'The name of the model to use.',
@@ -42,12 +47,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args =
-
-        top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
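
A small sketch of the argument dict produced by the reworked `_request_args`, mirroring the expectations asserted in the updated tests further down (endpoint and model names are test placeholders): with `n=1` and no `top_logprobs`, neither key is emitted any longer.

from langfun import core as lf
from langfun.core.llms import openai_compatible

lm = openai_compatible.OpenAIChatCompletionAPI(
    api_endpoint='https://test-server', model='test-model'
)
args = lm._request_args(
    lf.LMSamplingOptions(temperature=1.0, stop=['\n'], n=1, random_seed=123)
)
# -> {'model': 'test-model', 'temperature': 1.0, 'stop': ['\n'], 'seed': 123}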
@@ -74,27 +81,13 @@ class OpenAICompatible(rest.REST):
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    #
-
-
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
       request_args.update(
           response_format=dict(
               type='json_schema',
-              json_schema=
-              schema=json_schema,
-              name=json_schema['title'],
-              strict=True,
-              )
+              json_schema=output_schema,
           )
       )
       prompt.metadata.formatted_text = (
@@ -120,17 +113,43 @@ class OpenAICompatible(rest.REST):
       assert isinstance(system_message, lf.SystemMessage), type(system_message)
       messages.append(
           system_message.as_format(
-              '
+              'openai_chat_completion_api', chunk_preprocessor=modality_check
           )
       )
     messages.append(
-        prompt.as_format(
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
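
For reference, a sketch of the shapes involved in the new `_structure_output_schema` helper, using the `Person` schema from the new tests as an illustrative input:

# `json_schema` attached to the prompt metadata by the caller:
person_schema = {
    'type': 'object',
    'properties': {'name': {'type': 'string'}},
    'required': ['name'],
    'title': 'Person',
}
# _structure_output_schema wraps it as:
#   {'schema': person_schema, 'name': 'Person', 'strict': True}
# The ChatCompletion request then carries it under
#   response_format={'type': 'json_schema', 'json_schema': <wrapped>}
# while the Responses API subclass (added below) sets type='json_schema' on the
# wrapped dict and sends it as text={'format': <wrapped>}.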
@@ -146,7 +165,10 @@ class OpenAICompatible(rest.REST):
           for t in choice_logprobs['content']
       ]
     return lf.LMSample(
-        lf.Message.from_value(
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -171,3 +193,88 @@ class OpenAICompatible(rest.REST):
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base for OpenAI compatible models based on Responses API.
+
+  https://platform.openai.com/docs/api-reference/responses/create
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )
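
Putting the new class together: `request()` emits a Responses API payload and `result()` parses the reply. A hedged sketch of both shapes follows; the per-item structure of `input` comes from the `openai_responses_api` message format (defined in langfun/core/data/conversion/openai.py, also updated in this release), and all values are illustrative only.

# Request body assembled by OpenAIResponsesAPI.request() for a prompt that
# carries a `system_message` in metadata (illustrative values):
request = {
    'model': 'test-model',
    'instructions': 'hi',  # from metadata `system_message`
    'input': [
        # One converted message; content items use Responses-style types
        # such as 'input_text' / 'input_image' (see the test mocks below).
        {'content': [{'type': 'input_text', 'text': 'hello'}]},
    ],
}

# Response body consumed by result() / _parse_output():
response = {
    'output': [
        {'type': 'message',
         'content': [{'type': 'output_text', 'text': 'Sample 0 for message.'}]},
    ],
    'usage': {'input_tokens': 100, 'output_tokens': 100, 'total_tokens': 200},
}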
langfun/core/llms/openai_compatible_test.py CHANGED

@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
   response_format = ''
 
   choices = []
-  for k in range(json
+  for k in range(json.get('n', 1)):
     if json.get('logprobs'):
       logprobs = dict(
           content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
       c['image_url']['url']
       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(json
+  for k in range(json.get('n', 1)):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
   return response
 
 
-
+def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  _ = json['input']
+
+  system_message = ''
+  if 'instructions' in json:
+    system_message = f' system={json["instructions"]}'
+
+  response_format = ''
+  if 'text' in json and 'format' in json['text']:
+    response_format = f' format={json["text"]["format"]["type"]}'
+
+  output = [
+      dict(
+          type='message',
+          content=[
+              dict(
+                  type='output_text',
+                  text=(
+                      f'Sample 0 for message.{system_message}{response_format}'
+                  )
+              )
+          ],
+      )
+  ]
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+def mock_responses_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
+  urls = [
+      c['image_url']
+      for c in json['input'][0]['content']
+      if c['type'] == 'input_image'
+  ]
+  output = [
+      pg.Dict(
+          type='message',
+          content=[
+              pg.Dict(
+                  type='output_text',
+                  text=f'Sample 0 for message: {"".join(urls)}',
+              )
+          ],
+      )
+  ]
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class OpenAIChatCompletionAPITest(unittest.TestCase):
   """Tests for OpenAI compatible language model."""
 
   def test_request_args(self):
     self.assertEqual(
-        openai_compatible.
+        openai_compatible.OpenAIChatCompletionAPI(
             api_endpoint='https://test-server',
             model='test-model'
         )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
         ),
         dict(
             model='test-model',
-            top_logprobs=None,
-            n=1,
             temperature=1.0,
             stop=['\n'],
             seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
       )
       self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion_with_logprobs(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model',
       )
       results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
       def mime_type(self) -> str:
         return 'image/png'
 
+    image = FakeImage.from_uri('https://fake/image')
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai_compatible.
+      lm_1 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model1',
       )
-      lm_2 = openai_compatible.
+      lm_2 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model2',
       )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
       self.assertEqual(
           lm(
               lf.UserMessage(
-                  'hello <<[[image]]>>',
-
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
               ),
               sampling_options=lf.LMSamplingOptions(n=2)
           ),
           'Sample 0 for message: https://fake/image',
       )
 
-    class TextOnlyModel(openai_compatible.
+    class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
 
       class ModelInfo(lf.ModelInfo):
         input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
       lm_3(
           lf.UserMessage(
-              'hello <<[[image]]>>',
-
+              f'hello <<[[{image.id}]]>>',
+              referred_modalities=[image],
           ),
       )
 
   def test_sample_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_sample_with_contextual_options(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_json_schema(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
 
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_context_limit_error
-      lm = openai_compatible.
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
       )
       with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
         lm(lf.UserMessage('hello'))
 
 
+class OpenAIResponsesAPITest(unittest.TestCase):
+  """Tests for OpenAI compatible language model on Responses API."""
+
+  def test_request_args(self):
+    lm = openai_compatible.OpenAIResponsesAPI(
+        api_endpoint='https://test-server', model='test-model'
+    )
+    # Test valid args.
+    self.assertEqual(
+        lm._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='test-model',
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+    # Test unsupported n.
+    with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+      lm._request_args(lf.LMSamplingOptions(n=2))
+
+    # Test unsupported logprobs.
+    with self.assertRaisesRegex(
+        ValueError, 'logprobs is not supported on Responses API.'
+    ):
+      lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+  def test_call_responses(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model',
+      )
+      self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+  def test_call_responses_vision(self):
+    class FakeImage(lf_modalities.Image):
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
+    image = FakeImage.from_uri('https://fake/image')
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request_vision
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model1',
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              )
+          ),
+          'Sample 0 for message: https://fake/image',
+      )
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message=lf.SystemMessage('hi'),
+              )
+          ),
+          'Sample 0 for message. system=hi',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  },
+              )
+          ),
+          'Sample 0 for message. format=json_schema',
+      )
+
+      # Test bad json schema.
+      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+        lm(lf.UserMessage('hello', json_schema='foo'))
+
+      with self.assertRaisesRegex(
+          ValueError, 'The root of `json_schema` must have a `title` field'
+      ):
+        lm(lf.UserMessage('hello', json_schema={}))
+
+
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/openai_test.py CHANGED
langfun/core/llms/vertexai.py CHANGED

@@ -497,7 +497,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAILlama(VertexAI, openai_compatible.
+class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Llama models on VertexAI."""
 
   model: pg.typing.Annotated[

@@ -610,7 +610,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAIMistral(VertexAI, openai_compatible.
+class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Mistral AI models on VertexAI."""
 
   model: pg.typing.Annotated[