langfun 0.1.2.dev202510240805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of langfun might be problematic.
Files changed (41)
  1. langfun/core/concurrent_test.py +1 -0
  2. langfun/core/data/conversion/anthropic_test.py +8 -6
  3. langfun/core/data/conversion/gemini_test.py +12 -9
  4. langfun/core/data/conversion/openai.py +134 -30
  5. langfun/core/data/conversion/openai_test.py +161 -17
  6. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  7. langfun/core/langfunc_test.py +4 -2
  8. langfun/core/language_model.py +6 -6
  9. langfun/core/language_model_test.py +9 -3
  10. langfun/core/llms/__init__.py +2 -1
  11. langfun/core/llms/cache/base.py +3 -1
  12. langfun/core/llms/cache/in_memory_test.py +14 -4
  13. langfun/core/llms/deepseek.py +1 -1
  14. langfun/core/llms/groq.py +1 -1
  15. langfun/core/llms/llama_cpp.py +1 -1
  16. langfun/core/llms/openai.py +7 -2
  17. langfun/core/llms/openai_compatible.py +134 -27
  18. langfun/core/llms/openai_compatible_test.py +207 -20
  19. langfun/core/llms/openai_test.py +0 -2
  20. langfun/core/llms/vertexai.py +2 -2
  21. langfun/core/message.py +78 -44
  22. langfun/core/message_test.py +56 -81
  23. langfun/core/modalities/__init__.py +8 -0
  24. langfun/core/modalities/mime.py +9 -0
  25. langfun/core/modality.py +104 -27
  26. langfun/core/modality_test.py +42 -12
  27. langfun/core/sampling_test.py +20 -4
  28. langfun/core/structured/completion.py +2 -7
  29. langfun/core/structured/completion_test.py +23 -43
  30. langfun/core/structured/mapping.py +4 -13
  31. langfun/core/structured/querying.py +13 -11
  32. langfun/core/structured/querying_test.py +65 -29
  33. langfun/core/template.py +39 -13
  34. langfun/core/template_test.py +83 -17
  35. langfun/env/event_handlers/metric_writer_test.py +3 -3
  36. langfun/env/load_balancers_test.py +2 -2
  37. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
  38. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +41 -41
  39. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
  40. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
  41. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/llms/cache/in_memory_test.py CHANGED
@@ -175,18 +175,28 @@ class InMemoryLMCacheTest(unittest.TestCase):
 
     cache = in_memory.InMemory()
     lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
-    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('foo')))
-    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('bar')))
+    image_foo = CustomModality('foo')
+    image_bar = CustomModality('bar')
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_foo.id}]]>>', referred_modalities=[image_foo]
+        )
+    )
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_bar.id}]]>>', referred_modalities=[image_bar]
+        )
+    )
     self.assertEqual(
         list(cache.keys()),
         [
            (
-                'hi <<[[image]]>><image>acbd18db</image>',
+                f'hi <<[[{image_foo.id}]]>>',
                (None, None, 1, 40, None, None),
                0,
            ),
            (
-                'hi <<[[image]]>><image>37b51d19</image>',
+                f'hi <<[[{image_bar.id}]]>>',
                (None, None, 1, 40, None, None),
                0,
            ),
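The hunk above reflects this release's new modality-referencing API: modalities are now passed explicitly via `referred_modalities` and referenced in the text by their `id`, and cache keys no longer embed a content hash of the modality. A minimal sketch of the new calling convention; the `CustomModality` definition below is assumed from its usage in this hunk, not shown in the diff:

import langfun as lf

# Assumed shape of the test helper (the real definition lives in the test file).
class CustomModality(lf.Modality):
  content: str

  def to_bytes(self) -> bytes:
    return self.content.encode()

image = CustomModality('foo')
message = lf.UserMessage(
    f'hi <<[[{image.id}]]>>',      # placeholder now references the modality id
    referred_modalities=[image],   # replaces the old `image=...` keyword
)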
langfun/core/llms/deepseek.py CHANGED
@@ -93,7 +93,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 # DeepSeek API uses an API format compatible with OpenAI.
 # Reference: https://api-docs.deepseek.com/
 @lf.use_init_args(['model'])
-class DeepSeek(openai_compatible.OpenAICompatible):
+class DeepSeek(openai_compatible.OpenAIChatCompletionAPI):
   """DeepSeek model."""
 
   model: pg.typing.Annotated[
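This rename, repeated below for Groq, LlamaCppRemote, VertexAILlama, and VertexAIMistral, means any third-party subclass of the old `OpenAICompatible` base needs the same one-line change. A minimal sketch for a hypothetical downstream provider (class name and endpoint invented for illustration):

from langfun.core.llms import openai_compatible

# Before this release: class MyProvider(openai_compatible.OpenAICompatible): ...
class MyProvider(openai_compatible.OpenAIChatCompletionAPI):
  """Hypothetical provider that stays on the ChatCompletion wire format."""
  api_endpoint: str = 'https://example.com/v1/chat/completions'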
langfun/core/llms/groq.py CHANGED
@@ -259,7 +259,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class Groq(openai_compatible.OpenAICompatible):
+class Groq(openai_compatible.OpenAIChatCompletionAPI):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
langfun/core/llms/llama_cpp.py CHANGED
@@ -20,7 +20,7 @@ import pyglove as pg
 
 @pg.use_init_args(['url', 'model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class LlamaCppRemote(openai_compatible.OpenAICompatible):
+class LlamaCppRemote(openai_compatible.OpenAIChatCompletionAPI):
   """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
langfun/core/llms/openai.py CHANGED
@@ -1031,7 +1031,7 @@ _SUPPORTED_MODELS_BY_MODEL_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class OpenAI(openai_compatible.OpenAICompatible):
+class OpenAI(openai_compatible.OpenAIResponsesAPI):
   """OpenAI model."""
 
   model: pg.typing.Annotated[
@@ -1041,7 +1041,12 @@ class OpenAI(openai_compatible.OpenAICompatible):
       'The name of the model to use.',
   ]
 
-  api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
+  # Disable message storage by default.
+  sampling_options = lf.LMSamplingOptions(
+      extras={'store': False}
+  )
+
+  api_endpoint: str = 'https://api.openai.com/v1/responses'
 
   api_key: Annotated[
       str | None,
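With this hunk, `OpenAI` defaults to the Responses endpoint and disables server-side message storage by sending `store=False` through the sampling-options extras. A hedged sketch of working with the new default, assuming extras keys are forwarded verbatim into the request body (the model name is illustrative):

import langfun as lf
from langfun.core.llms import openai

# Default after this release: requests go to /v1/responses with store=False.
lm = openai.OpenAI(model='gpt-4o')

# Hypothetical opt-in to stored responses for a single instance.
lm_stored = openai.OpenAI(
    model='gpt-4o',
    sampling_options=lf.LMSamplingOptions(extras={'store': True}),
)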
langfun/core/llms/openai_compatible.py CHANGED
@@ -23,8 +23,13 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class OpenAICompatible(rest.REST):
-  """Base for OpenAI compatible models."""
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base for OpenAI compatible models based on ChatCompletion API.
+
+  See https://platform.openai.com/docs/api-reference/chat
+  As of 2025-10-23, OpenAI is migrating from ChatCompletion API to Responses
+  API.
+  """
 
   model: Annotated[
       str, 'The name of the model to use.',
@@ -42,12 +47,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args = dict(
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
@@ -74,27 +81,13 @@ class OpenAICompatible(rest.REST):
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
       request_args.update(
           response_format=dict(
               type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
+              json_schema=output_schema,
           )
       )
       prompt.metadata.formatted_text = (
@@ -120,17 +113,43 @@ class OpenAICompatible(rest.REST):
       assert isinstance(system_message, lf.SystemMessage), type(system_message)
       messages.append(
           system_message.as_format(
-              'openai', chunk_preprocessor=modality_check
+              'openai_chat_completion_api', chunk_preprocessor=modality_check
           )
       )
     messages.append(
-        prompt.as_format('openai', chunk_preprocessor=modality_check)
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
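Factoring the validation into `_structure_output_schema` gives both API flavors one caller contract: `json_schema` arrives through message metadata, must be a dict, and must carry a root `title`, which becomes the `name` of the structured-output format. A minimal sketch of a conforming message (the `Person` schema mirrors the one used in the tests below):

import langfun as lf

message = lf.UserMessage(
    'Extract the person mentioned in the text.',
    json_schema={
        'type': 'object',
        'properties': {'name': {'type': 'string'}},
        'required': ['name'],
        'title': 'Person',  # required: becomes `name` in the request payload
    },
)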
@@ -146,7 +165,10 @@ class OpenAICompatible(rest.REST):
           for t in choice_logprobs['content']
       ]
     return lf.LMSample(
-        lf.Message.from_value(choice['message'], format='openai'),
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -171,3 +193,88 @@ class OpenAICompatible(rest.REST):
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base for OpenAI compatible models based on Responses API.
+
+  https://platform.openai.com/docs/api-reference/responses/create
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )
langfun/core/llms/openai_compatible_test.py CHANGED
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
   response_format = ''
 
   choices = []
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     if json.get('logprobs'):
       logprobs = dict(
           content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
       c['image_url']['url']
       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(json['n']):
+  for k in range(json.get('n', 1)):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
   return response
 
 
-class OpenAIComptibleTest(unittest.TestCase):
+def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  _ = json['input']
+
+  system_message = ''
+  if 'instructions' in json:
+    system_message = f' system={json["instructions"]}'
+
+  response_format = ''
+  if 'text' in json and 'format' in json['text']:
+    response_format = f' format={json["text"]["format"]["type"]}'
+
+  output = [
+      dict(
+          type='message',
+          content=[
+              dict(
+                  type='output_text',
+                  text=(
+                      f'Sample 0 for message.{system_message}{response_format}'
+                  )
+              )
+          ],
+      )
+  ]
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+def mock_responses_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
+  urls = [
+      c['image_url']
+      for c in json['input'][0]['content']
+      if c['type'] == 'input_image'
+  ]
+  output = [
+      pg.Dict(
+          type='message',
+          content=[
+              pg.Dict(
+                  type='output_text',
+                  text=f'Sample 0 for message: {"".join(urls)}',
+              )
+          ],
+      )
+  ]
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          output=output,
+          usage=dict(
+              input_tokens=100,
+              output_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class OpenAIChatCompletionAPITest(unittest.TestCase):
   """Tests for OpenAI compatible language model."""
 
   def test_request_args(self):
     self.assertEqual(
-        openai_compatible.OpenAICompatible(
+        openai_compatible.OpenAIChatCompletionAPI(
            api_endpoint='https://test-server',
            model='test-model'
        )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
        ),
        dict(
            model='test-model',
-           top_logprobs=None,
-           n=1,
            temperature=1.0,
            stop=['\n'],
            seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server', model='test-model',
       )
       self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_chat_completion_with_logprobs(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
           api_endpoint='https://test-server', model='test-model',
       )
       results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
       def mime_type(self) -> str:
         return 'image/png'
 
+    image = FakeImage.from_uri('https://fake/image')
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai_compatible.OpenAICompatible(
+      lm_1 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model1',
      )
-      lm_2 = openai_compatible.OpenAICompatible(
+      lm_2 = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server',
          model='test-model2',
      )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
       self.assertEqual(
          lm(
              lf.UserMessage(
-                 'hello <<[[image]]>>',
-                 image=FakeImage.from_uri('https://fake/image')
+                 f'hello <<[[{image.id}]]>>',
+                 referred_modalities=[image],
             ),
            sampling_options=lf.LMSamplingOptions(n=2)
         ),
        'Sample 0 for message: https://fake/image',
      )
 
-    class TextOnlyModel(openai_compatible.OpenAICompatible):
+    class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
 
       class ModelInfo(lf.ModelInfo):
         input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
       lm_3(
          lf.UserMessage(
-             'hello <<[[image]]>>',
-             image=FakeImage.from_uri('https://fake/image')
+             f'hello <<[[{image.id}]]>>',
+             referred_modalities=[image],
         ),
      )
 
   def test_sample_chat_completion(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
       results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_sample_with_contextual_options(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
       self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
   def test_call_with_json_schema(self):
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
       self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
 
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_context_limit_error
-      lm = openai_compatible.OpenAICompatible(
+      lm = openai_compatible.OpenAIChatCompletionAPI(
          api_endpoint='https://test-server', model='test-model'
      )
       with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
         lm(lf.UserMessage('hello'))
 
 
+class OpenAIResponsesAPITest(unittest.TestCase):
+  """Tests for OpenAI compatible language model on Responses API."""
+
+  def test_request_args(self):
+    lm = openai_compatible.OpenAIResponsesAPI(
+        api_endpoint='https://test-server', model='test-model'
+    )
+    # Test valid args.
+    self.assertEqual(
+        lm._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='test-model',
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+    # Test unsupported n.
+    with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+      lm._request_args(lf.LMSamplingOptions(n=2))
+
+    # Test unsupported logprobs.
+    with self.assertRaisesRegex(
+        ValueError, 'logprobs is not supported on Responses API.'
+    ):
+      lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+  def test_call_responses(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model',
+      )
+      self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+  def test_call_responses_vision(self):
+    class FakeImage(lf_modalities.Image):
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
+    image = FakeImage.from_uri('https://fake/image')
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request_vision
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server',
+          model='test-model1',
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  f'hello <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              )
+          ),
+          'Sample 0 for message: https://fake/image',
+      )
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message=lf.SystemMessage('hi'),
+              )
+          ),
+          'Sample 0 for message. system=hi',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_responses_request
+      lm = openai_compatible.OpenAIResponsesAPI(
+          api_endpoint='https://test-server', model='test-model'
+      )
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  },
+              )
+          ),
+          'Sample 0 for message. format=json_schema',
+      )
+
+      # Test bad json schema.
+      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+        lm(lf.UserMessage('hello', json_schema='foo'))
+
+      with self.assertRaisesRegex(
+          ValueError, 'The root of `json_schema` must have a `title` field'
+      ):
+        lm(lf.UserMessage('hello', json_schema={}))
+
+
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/openai_test.py CHANGED
@@ -61,8 +61,6 @@ class OpenAITest(unittest.TestCase):
         ),
         dict(
             model='gpt-4',
-            top_logprobs=None,
-            n=1,
             temperature=1.0,
             stop=['\n'],
             seed=123,
langfun/core/llms/vertexai.py CHANGED
@@ -497,7 +497,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Llama models on VertexAI."""
 
   model: pg.typing.Annotated[
@@ -610,7 +610,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
 
 @pg.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
+class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
   """Mistral AI models on VertexAI."""
 
   model: pg.typing.Annotated[