langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_compatible.py CHANGED
@@ -23,8 +23,18 @@ import pyglove as pg
 
 
  @lf.use_init_args(['api_endpoint', 'model'])
- class OpenAICompatible(rest.REST):
- """Base for OpenAI compatible models."""
+ class OpenAIChatCompletionAPI(rest.REST):
+ """Base class for models compatible with OpenAI's Chat Completion API.
+
+ This class provides a common interface for language models that adhere to
+ the OpenAI Chat Completion API format, which is used by providers like
+ Groq, DeepSeek, and others. It standardizes request formatting and
+ response parsing for these models.
+
+ **References:**
+
+ * https://platform.openai.com/docs/api-reference/chat
+ """
 
  model: Annotated[
  str, 'The name of the model to use.',
@@ -42,12 +52,14 @@ class OpenAICompatible(rest.REST):
  # Reference:
  # https://platform.openai.com/docs/api-reference/completions/create
  # NOTE(daiyip): options.top_k is not applicable.
- args = dict(
- n=options.n,
- top_logprobs=options.top_logprobs,
- )
+ args = {}
+
  if self.model:
  args['model'] = self.model
+ if options.n != 1:
+ args['n'] = options.n
+ if options.top_logprobs is not None:
+ args['top_logprobs'] = options.top_logprobs
  if options.logprobs:
  args['logprobs'] = options.logprobs
  if options.temperature is not None:
@@ -62,6 +74,8 @@ class OpenAICompatible(rest.REST):
  args['seed'] = options.random_seed
  if options.reasoning_effort is not None:
  args['reasoning_effort'] = options.reasoning_effort
+ if options.extras:
+ args.update(options.extras)
  return args
 
  def request(
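With this change, `_request_args` no longer serializes default values (`n=1`, `top_logprobs=None`) and merges any `options.extras` entries into the request last. A hedged sketch of the resulting arguments, mirroring the expected dict in the updated unit test further below; the `extras` entry is purely illustrative:

import langfun as lf
from langfun.core.llms import openai_compatible

lm = openai_compatible.OpenAIChatCompletionAPI(
    api_endpoint='https://test-server', model='test-model')  # placeholders
args = lm._request_args(
    lf.LMSamplingOptions(
        temperature=1.0, stop=['\n'], n=1, random_seed=123,
        extras=dict(service_tier='flex'),  # illustrative extra request field
    )
)
# -> {'model': 'test-model', 'temperature': 1.0, 'stop': ['\n'],
#     'seed': 123, 'service_tier': 'flex'}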
@@ -72,27 +86,13 @@ class OpenAICompatible(rest.REST):
  """Returns the JSON input for a message."""
  request_args = self._request_args(sampling_options)
 
- # Users could use `metadata_json_schema` to pass additional
- # request arguments.
- json_schema = prompt.metadata.get('json_schema')
- if json_schema is not None:
- if not isinstance(json_schema, dict):
- raise ValueError(
- f'`json_schema` must be a dict, got {json_schema!r}.'
- )
- if 'title' not in json_schema:
- raise ValueError(
- f'The root of `json_schema` must have a `title` field, '
- f'got {json_schema!r}.'
- )
+ # Handle structured output.
+ output_schema = self._structure_output_schema(prompt)
+ if output_schema is not None:
  request_args.update(
  response_format=dict(
  type='json_schema',
- json_schema=dict(
- schema=json_schema,
- name=json_schema['title'],
- strict=True,
- )
+ json_schema=output_schema,
  )
  )
  prompt.metadata.formatted_text = (
@@ -118,17 +118,43 @@ class OpenAICompatible(rest.REST):
  assert isinstance(system_message, lf.SystemMessage), type(system_message)
  messages.append(
  system_message.as_format(
- 'openai', chunk_preprocessor=modality_check
+ 'openai_chat_completion_api', chunk_preprocessor=modality_check
  )
  )
  messages.append(
- prompt.as_format('openai', chunk_preprocessor=modality_check)
+ prompt.as_format(
+ 'openai_chat_completion_api',
+ chunk_preprocessor=modality_check
+ )
  )
  request = dict()
  request.update(request_args)
  request['messages'] = messages
  return request
 
+ def _structure_output_schema(
+ self, prompt: lf.Message
+ ) -> dict[str, Any] | None:
+ # Users could use `metadata_json_schema` to pass additional
+ # request arguments.
+ json_schema = prompt.metadata.get('json_schema')
+ if json_schema is not None:
+ if not isinstance(json_schema, dict):
+ raise ValueError(
+ f'`json_schema` must be a dict, got {json_schema!r}.'
+ )
+ if 'title' not in json_schema:
+ raise ValueError(
+ f'The root of `json_schema` must have a `title` field, '
+ f'got {json_schema!r}.'
+ )
+ return dict(
+ schema=json_schema,
+ name=json_schema['title'],
+ strict=True,
+ )
+ return None
+
  def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
  # Reference:
  # https://platform.openai.com/docs/api-reference/chat/object
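The schema validation that previously lived inline in `request()` is now factored into `_structure_output_schema`, shared by both API flavors. A hedged sketch of the flow; the schema contents are illustrative:

import langfun as lf

schema = {
    'type': 'object',
    'properties': {'name': {'type': 'string'}},
    'required': ['name'],
    'title': 'Person',  # a root `title` is required by the checks above
}
message = lf.UserMessage('Extract the person mentioned.', json_schema=schema)
# For the Chat Completion API, `request()` above then emits:
# response_format=dict(
#     type='json_schema',
#     json_schema=dict(schema=schema, name='Person', strict=True),
# )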
@@ -144,7 +170,10 @@ class OpenAICompatible(rest.REST):
  for t in choice_logprobs['content']
  ]
  return lf.LMSample(
- lf.Message.from_value(choice['message'], format='openai'),
+ lf.Message.from_value(
+ choice['message'],
+ format='openai_chat_completion_api'
+ ),
  score=0.0,
  logprobs=logprobs,
  )
@@ -169,3 +198,95 @@ class OpenAICompatible(rest.REST):
  or (status_code == 400 and b'string_above_max_length' in content)):
  return lf.ContextLimitError(f'{status_code}: {content}')
  return super()._error(status_code, content)
+
+
+ class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+ """Base class for models compatible with OpenAI's Responses API.
+
+ This class provides a common interface for language models that adhere to
+ the new OpenAI Responses API format. It standardizes request formatting
+ and response parsing for these models, including handling instructions
+ (system messages) and structured outputs.
+
+ **References:**
+
+ * https://platform.openai.com/docs/api-reference/responses
+ """
+
+ def _request_args(
+ self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+ """Returns a dict as request arguments."""
+ if options.logprobs:
+ raise ValueError('logprobs is not supported on Responses API.')
+ if options.n != 1:
+ raise ValueError('n must be 1 for Responses API.')
+ return super()._request_args(options)
+
+ def request(
+ self,
+ prompt: lf.Message,
+ sampling_options: lf.LMSamplingOptions
+ ) -> dict[str, Any]:
+ """Returns the JSON input for a message."""
+ request_args = self._request_args(sampling_options)
+
+ # Handle structured output.
+ output_schema = self._structure_output_schema(prompt)
+ if output_schema is not None:
+ output_schema['type'] = 'json_schema'
+ request_args.update(text=dict(format=output_schema))
+ prompt.metadata.formatted_text = (
+ prompt.text
+ + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+ + pg.to_json_str(request_args['text'], json_indent=2)
+ )
+
+ request = dict()
+ request.update(request_args)
+
+ # Users could use `metadata_system_message` to pass system message.
+ system_message = prompt.metadata.get('system_message')
+ if system_message:
+ assert isinstance(system_message, lf.SystemMessage), type(system_message)
+ request['instructions'] = system_message.text
+
+ # Prepare input.
+ def modality_check(chunk: str | lf.Modality) -> Any:
+ if (isinstance(chunk, lf_modalities.Mime)
+ and not self.supports_input(chunk.mime_type)):
+ raise ValueError(
+ f'Unsupported modality: {chunk!r}.'
+ )
+ return chunk
+
+ request['input'] = [
+ prompt.as_format(
+ 'openai_responses_api',
+ chunk_preprocessor=modality_check
+ )
+ ]
+ return request
+
+ def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+ for item in output:
+ if isinstance(item, dict) and item.get('type') == 'message':
+ return lf.LMSample(
+ lf.Message.from_value(item, format='openai_responses_api'),
+ score=0.0,
+ )
+ raise ValueError('No message found in output.')
+
+ def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+ """Returns a LMSamplingResult from a JSON response."""
+ usage = json['usage']
+ return lf.LMSamplingResult(
+ samples=[self._parse_output(json['output'])],
+ usage=lf.LMSamplingUsage(
+ prompt_tokens=usage['input_tokens'],
+ completion_tokens=usage['output_tokens'],
+ total_tokens=usage['total_tokens'],
+ completion_tokens_details=usage.get(
+ 'output_tokens_details', None
+ ),
+ ),
+ )
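The module now exposes two REST-backed base classes: `OpenAIChatCompletionAPI` (renamed from `OpenAICompatible`) and the new `OpenAIResponsesAPI`. A hedged usage sketch; the endpoint URLs and model names are placeholders in the spirit of the unit tests below:

from langfun.core.llms import openai_compatible

chat_lm = openai_compatible.OpenAIChatCompletionAPI(
    api_endpoint='https://test-server/v1/chat/completions',  # placeholder
    model='test-model',  # placeholder
)
responses_lm = openai_compatible.OpenAIResponsesAPI(
    api_endpoint='https://test-server/v1/responses',  # placeholder
    model='test-model',  # placeholder
)
chat_lm('hello')       # sent as Chat Completion `messages`
responses_lm('hello')  # sent as Responses API `input` (plus `instructions`)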
langfun/core/llms/openai_compatible_test.py CHANGED
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
  response_format = ''
 
  choices = []
- for k in range(json['n']):
+ for k in range(json.get('n', 1)):
  if json.get('logprobs'):
  logprobs = dict(
  content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
  c['image_url']['url']
  for c in json['messages'][0]['content'] if c['type'] == 'image_url'
  ]
- for k in range(json['n']):
+ for k in range(json.get('n', 1)):
  choices.append(pg.Dict(
  message=pg.Dict(
  content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
  return response
 
 
- class OpenAIComptibleTest(unittest.TestCase):
+ def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
+ del url, kwargs
+ _ = json['input']
+
+ system_message = ''
+ if 'instructions' in json:
+ system_message = f' system={json["instructions"]}'
+
+ response_format = ''
+ if 'text' in json and 'format' in json['text']:
+ response_format = f' format={json["text"]["format"]["type"]}'
+
+ output = [
+ dict(
+ type='message',
+ content=[
+ dict(
+ type='output_text',
+ text=(
+ f'Sample 0 for message.{system_message}{response_format}'
+ )
+ )
+ ],
+ )
+ ]
+
+ response = requests.Response()
+ response.status_code = 200
+ response._content = pg.to_json_str(
+ dict(
+ output=output,
+ usage=dict(
+ input_tokens=100,
+ output_tokens=100,
+ total_tokens=200,
+ ),
+ )
+ ).encode()
+ return response
+
+
+ def mock_responses_request_vision(
+ url: str, json: dict[str, Any], **kwargs
+ ):
+ del url, kwargs
+ urls = [
+ c['image_url']
+ for c in json['input'][0]['content']
+ if c['type'] == 'input_image'
+ ]
+ output = [
+ pg.Dict(
+ type='message',
+ content=[
+ pg.Dict(
+ type='output_text',
+ text=f'Sample 0 for message: {"".join(urls)}',
+ )
+ ],
+ )
+ ]
+ response = requests.Response()
+ response.status_code = 200
+ response._content = pg.to_json_str(
+ dict(
+ output=output,
+ usage=dict(
+ input_tokens=100,
+ output_tokens=100,
+ total_tokens=200,
+ ),
+ )
+ ).encode()
+ return response
+
+
+ class OpenAIChatCompletionAPITest(unittest.TestCase):
  """Tests for OpenAI compatible language model."""
 
  def test_request_args(self):
  self.assertEqual(
- openai_compatible.OpenAICompatible(
+ openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server',
  model='test-model'
  )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
  ),
  dict(
  model='test-model',
- top_logprobs=None,
- n=1,
  temperature=1.0,
  stop=['\n'],
  seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
  def test_call_chat_completion(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model',
  )
  self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
  def test_call_chat_completion_with_logprobs(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model',
  )
  results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
  def mime_type(self) -> str:
  return 'image/png'
 
+ image = FakeImage.from_uri('https://fake/image')
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request_vision
- lm_1 = openai_compatible.OpenAICompatible(
+ lm_1 = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server',
  model='test-model1',
  )
- lm_2 = openai_compatible.OpenAICompatible(
+ lm_2 = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server',
  model='test-model2',
  )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
  self.assertEqual(
  lm(
  lf.UserMessage(
- 'hello <<[[image]]>>',
- image=FakeImage.from_uri('https://fake/image')
+ f'hello <<[[{image.id}]]>>',
+ referred_modalities=[image],
  ),
  sampling_options=lf.LMSamplingOptions(n=2)
  ),
  'Sample 0 for message: https://fake/image',
  )
 
- class TextOnlyModel(openai_compatible.OpenAICompatible):
+ class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
 
  class ModelInfo(lf.ModelInfo):
  input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
  with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
  lm_3(
  lf.UserMessage(
- 'hello <<[[image]]>>',
- image=FakeImage.from_uri('https://fake/image')
+ f'hello <<[[{image.id}]]>>',
+ referred_modalities=[image],
  ),
  )
 
  def test_sample_chat_completion(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model'
  )
  results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
  def test_sample_with_contextual_options(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model'
  )
  with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
  def test_call_with_system_message(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model'
  )
  self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
  def test_call_with_json_schema(self):
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_chat_completion_request
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model'
  )
  self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
 
  with mock.patch('requests.Session.post') as mock_request:
  mock_request.side_effect = mock_context_limit_error
- lm = openai_compatible.OpenAICompatible(
+ lm = openai_compatible.OpenAIChatCompletionAPI(
  api_endpoint='https://test-server', model='test-model'
  )
  with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
  lm(lf.UserMessage('hello'))
 
 
+ class OpenAIResponsesAPITest(unittest.TestCase):
+ """Tests for OpenAI compatible language model on Responses API."""
+
+ def test_request_args(self):
+ lm = openai_compatible.OpenAIResponsesAPI(
+ api_endpoint='https://test-server', model='test-model'
+ )
+ # Test valid args.
+ self.assertEqual(
+ lm._request_args(
+ lf.LMSamplingOptions(
+ temperature=1.0, stop=['\n'], n=1, random_seed=123
+ )
+ ),
+ dict(
+ model='test-model',
+ temperature=1.0,
+ stop=['\n'],
+ seed=123,
+ ),
+ )
+ # Test unsupported n.
+ with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
+ lm._request_args(lf.LMSamplingOptions(n=2))
+
+ # Test unsupported logprobs.
+ with self.assertRaisesRegex(
+ ValueError, 'logprobs is not supported on Responses API.'
+ ):
+ lm._request_args(lf.LMSamplingOptions(logprobs=True))
+
+ def test_call_responses(self):
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_responses_request
+ lm = openai_compatible.OpenAIResponsesAPI(
+ api_endpoint='https://test-server',
+ model='test-model',
+ )
+ self.assertEqual(lm('hello'), 'Sample 0 for message.')
+
+ def test_call_responses_vision(self):
+ class FakeImage(lf_modalities.Image):
+ @property
+ def mime_type(self) -> str:
+ return 'image/png'
+
+ image = FakeImage.from_uri('https://fake/image')
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_responses_request_vision
+ lm = openai_compatible.OpenAIResponsesAPI(
+ api_endpoint='https://test-server',
+ model='test-model1',
+ )
+ self.assertEqual(
+ lm(
+ lf.UserMessage(
+ f'hello <<[[{image.id}]]>>',
+ referred_modalities=[image],
+ )
+ ),
+ 'Sample 0 for message: https://fake/image',
+ )
+
+ def test_call_with_system_message(self):
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_responses_request
+ lm = openai_compatible.OpenAIResponsesAPI(
+ api_endpoint='https://test-server', model='test-model'
+ )
+ self.assertEqual(
+ lm(
+ lf.UserMessage(
+ 'hello',
+ system_message=lf.SystemMessage('hi'),
+ )
+ ),
+ 'Sample 0 for message. system=hi',
+ )
+
+ def test_call_with_json_schema(self):
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_responses_request
+ lm = openai_compatible.OpenAIResponsesAPI(
+ api_endpoint='https://test-server', model='test-model'
+ )
+ self.assertEqual(
+ lm(
+ lf.UserMessage(
+ 'hello',
+ json_schema={
+ 'type': 'object',
+ 'properties': {
+ 'name': {'type': 'string'},
+ },
+ 'required': ['name'],
+ 'title': 'Person',
+ },
+ )
+ ),
+ 'Sample 0 for message. format=json_schema',
+ )
+
+ # Test bad json schema.
+ with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+ lm(lf.UserMessage('hello', json_schema='foo'))
+
+ with self.assertRaisesRegex(
+ ValueError, 'The root of `json_schema` must have a `title` field'
+ ):
+ lm(lf.UserMessage('hello', json_schema={}))
+
+
  if __name__ == '__main__':
  unittest.main()
langfun/core/llms/openai_test.py CHANGED
@@ -61,8 +61,6 @@ class OpenAITest(unittest.TestCase):
  ),
  dict(
  model='gpt-4',
- top_logprobs=None,
- n=1,
  temperature=1.0,
  stop=['\n'],
  seed=123,
langfun/core/llms/rest.py CHANGED
@@ -22,7 +22,18 @@ import requests
 
 
  class REST(lf.LanguageModel):
- """REST-based language model."""
+ """Base class for language models accessed via REST APIs.
+
+ The `REST` class provides a foundation for implementing language models
+ that are accessed through RESTful endpoints. It handles the details of
+ making HTTP requests, managing sessions, and handling common errors like
+ timeouts and connection issues.
+
+ Subclasses need to implement the `request` and `result` methods to
+ convert Langfun messages to API-specific request formats and to parse
+ API responses back into `LMSamplingResult` objects. They also need to
+ provide the `api_endpoint` and can override `headers` for authentication.
+ """
 
  api_endpoint: Annotated[
  str,
@@ -98,7 +109,9 @@ class REST(lf.LanguageModel):
  raise lf.TemporaryLMError(str(e)) from e
  except (
  requests.exceptions.ConnectionError,
+ requests.exceptions.ChunkedEncodingError,
  ConnectionError,
+ ConnectionResetError,
  ) as e:
  error_message = str(e)
  if 'REJECTED_CLIENT_THROTTLED' in error_message:
@@ -107,6 +120,8 @@ class REST(lf.LanguageModel):
  raise lf.TemporaryLMError(error_message) from e
  if 'UNREACHABLE_ERROR' in error_message:
  raise lf.TemporaryLMError(error_message) from e
+ if 'Connection reset by peer' in error_message:
+ raise lf.TemporaryLMError(error_message) from e
  raise lf.LMError(error_message) from e
 
  def _error(self, status_code: int, content: str) -> lf.LMError:
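Per the expanded docstring, concrete models subclass `REST` and supply `request`/`result` plus an `api_endpoint`. A minimal hedged sketch; the endpoint, payload fields, and auth header are hypothetical placeholders, not part of langfun:

from typing import Any

import langfun.core as lf
from langfun.core.llms import rest


class MyJsonModel(rest.REST):
  """Hypothetical model served behind a JSON-over-HTTP endpoint."""

  api_endpoint: str = 'https://example.com/v1/generate'  # placeholder

  @property
  def headers(self) -> dict[str, Any]:
    # Assumed bearer-token auth; adjust to the target service.
    return {'Content-Type': 'application/json',
            'Authorization': 'Bearer <token>'}

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Hypothetical request payload.
    return dict(prompt=prompt.text, temperature=sampling_options.temperature)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Hypothetical response payload carrying a single `text` field.
    return lf.LMSamplingResult(
        samples=[lf.LMSample(lf.AIMessage(json['text']), score=0.0)]
    )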