langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/llms/gemini_test.py CHANGED
@@ -177,6 +177,58 @@ class GeminiTest(unittest.TestCase):
         ),
     )
 
+    # Add test for thinkingConfig with thinking_level.
+    actual = model._generation_config(
+        lf.UserMessage('hi'),
+        lf.LMSamplingOptions(
+            thinking_level='high',
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=None,
+            topP=None,
+            topK=40,
+            maxOutputTokens=None,
+            stopSequences=None,
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            thinkingConfig={'thinkingLevel': 'high'},
+        ),
+    )
+
+    # Add test for thinkingConfig with both max_thinking_tokens and
+    # thinking_level.
+    actual = model._generation_config(
+        lf.UserMessage('hi'),
+        lf.LMSamplingOptions(
+            max_thinking_tokens=100,
+            thinking_level='low',
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=None,
+            topP=None,
+            topK=40,
+            maxOutputTokens=None,
+            stopSequences=None,
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            thinkingConfig={
+                'includeThoughts': True,
+                'thinkingBudget': 100,
+                'thinkingLevel': 'low',
+            },
+        ),
+    )
+
     with self.assertRaisesRegex(
         ValueError, '`json_schema` must be a dict, got'
     ):
@@ -185,6 +237,32 @@ class GeminiTest(unittest.TestCase):
           lf.LMSamplingOptions(),
       )
 
+  def test_media_resolution_for_gemini3(self):
+    model = gemini.Gemini('gemini-3-pro-preview', api_endpoint='')
+    config = model._generation_config(
+        lf.UserMessage('hi'),
+        lf.LMSamplingOptions(),
+    )
+    self.assertEqual(config.get('mediaResolution'), 'MEDIA_RESOLUTION_HIGH')
+
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    config = model._generation_config(
+        lf.UserMessage('hi'),
+        lf.LMSamplingOptions(),
+    )
+    self.assertIsNone(config.get('mediaResolution'))
+
+  def test_request_tool_config(self):
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    request = model.request(
+        lf.UserMessage('hi'),
+        lf.LMSamplingOptions(),
+    )
+    self.assertEqual(
+        request.get('toolConfig'),
+        {'functionCallingConfig': {'mode': 'NONE'}},
+    )
+
   def test_call_model(self):
     with mock.patch('requests.Session.post') as mock_generate:
       mock_generate.side_effect = mock_requests_post
@@ -225,6 +303,38 @@ class GeminiTest(unittest.TestCase):
     ):
       lm('hello')
 
+  def test_call_model_with_max_tokens_error(self):
+    def mock_requests_post_error(*args, **kwargs):
+      del args, kwargs
+      response = requests.Response()
+      response.status_code = 200
+      response._content = pg.to_json_str({
+          'candidates': [
+              {
+                  'finishReason': 'MAX_TOKENS',
+                  'content': {
+                      'parts': [
+                          {
+                              'text': 'This is'
+                          }
+                      ]
+                  }
+              },
+          ],
+          'usageMetadata': {
+              'promptTokenCount': 3,
+              'candidatesTokenCount': 4,
+          }
+      }).encode()
+      return response
+
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post_error
+      lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+      m = lm('hello')
+      self.assertEqual(m.metadata.finish_reason, 'MAX_TOKENS')
+      self.assertEqual(m.text, 'This is')
+
   def test_call_model_with_system_message(self):
     with mock.patch('requests.Session.post') as mock_generate:
       mock_generate.side_effect = mock_requests_post
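
The new `test_call_model_with_max_tokens_error` above indicates that a `MAX_TOKENS` finish no longer raises; the partial text is returned and the finish reason is surfaced through message metadata. A minimal caller-side sketch of that pattern (the model alias comes from the GenAI docstring below; the tiny token budget and prompt are arbitrary examples):

```python
import langfun as lf

# Sketch: detect truncation via metadata instead of catching an exception,
# per the MAX_TOKENS test above. Token budget and prompt are made up.
lm = lf.llms.Gemini15Flash(
    sampling_options=lf.LMSamplingOptions(max_tokens=16)
)
r = lm('Write a long story about a lighthouse keeper.')
if r.metadata.finish_reason == 'MAX_TOKENS':
  print('Response was truncated:', r.text)
```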
langfun/core/llms/google_genai.py CHANGED
@@ -25,7 +25,35 @@ import pyglove as pg
 @lf.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
 class GenAI(gemini.Gemini):
-  """Language models provided by Google GenAI."""
+  """Google GenAI models.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Gemini 1.5 Flash using API key from environment variable
+  # 'GOOGLE_API_KEY'.
+  lm = lf.llms.Gemini15Flash()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The Google API key can be specified in following ways:
+
+  1. At model instantiation:
+
+  ```python
+  lm = lf.llms.Gemini15Flash(api_key='MY_API_KEY')
+  ```
+  2. via environment variable `GOOGLE_API_KEY`.
+
+  **References:**
+
+  * https://ai.google.dev/docs
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
@@ -87,9 +115,14 @@ class GenAI(gemini.Gemini):
 
 # pylint: disable=invalid-name
 
+
 #
 # Experimental models.
 #
+class Gemini3ProPreview(GenAI):
+  """Gemini 3 Pro Preview model."""
+
+  model = 'gemini-3-pro-preview'
 
 
 class Gemini25FlashImagePreview(GenAI):
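
The new `Gemini3ProPreview` class pairs with the `thinking_level` and `max_thinking_tokens` options exercised in gemini_test.py above. A minimal sketch of how they might be combined (this assumes the class is re-exported under `lf.llms` like the other GenAI models; the call itself is not taken from this diff):

```python
import langfun as lf

# Sketch only: thinking_level maps to thinkingConfig.thinkingLevel and
# max_thinking_tokens to thinkingConfig.thinkingBudget, per the tests above.
lm = lf.llms.Gemini3ProPreview(
    sampling_options=lf.LMSamplingOptions(
        thinking_level='high',
        max_thinking_tokens=8192,
    )
)
r = lm('Prove that the square root of 2 is irrational.')
print(r)
```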
langfun/core/llms/groq.py CHANGED
@@ -259,10 +259,35 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class Groq(openai_compatible.OpenAICompatible):
-  """Groq LLMs through REST APIs (OpenAI compatible).
+class Groq(openai_compatible.OpenAIChatCompletionAPI):
+  """Groq models.
 
-  See https://platform.openai.com/docs/api-reference/chat
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Llama 3.3 70B on Groq using API key from environment variable
+  # 'GROQ_API_KEY'.
+  lm = lf.llms.GroqLlama33_70B_Versatile()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The Groq API key can be specified in following ways:
+
+  1. At model instantiation:
+
+  ```python
+  lm = lf.llms.GroqLlama33_70B_Versatile(api_key='MY_API_KEY')
+  ```
+  2. via environment variable `GROQ_API_KEY`.
+
+  **References:**
+
+  * https://console.groq.com/docs
   """
 
   model: pg.typing.Annotated[
langfun/core/llms/llama_cpp.py CHANGED
@@ -20,11 +20,30 @@ import pyglove as pg
 
 @pg.use_init_args(['url', 'model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class LlamaCppRemote(openai_compatible.OpenAICompatible):
-  """The remote LLaMA C++ model.
+class LlamaCppRemote(openai_compatible.OpenAIChatCompletionAPI):
+  """LLaMA C++ models served via a remote server.
 
-  The Remote LLaMA C++ models can be launched via
-  https://github.com/ggerganov/llama.cpp/tree/master/examples/server
+  This class provides an interface to interact with language models
+  hosted on a LLaMA C++ server, which is compatible with the OpenAI
+  Chat Completions API format.
+
+  **Quick Start:**
+
+  Assuming a LLaMA C++ server is running at `http://localhost:8080`,
+  you can interact with it as follows:
+
+  ```python
+  import langfun as lf
+
+  # If model name is not specified, it will use server's default.
+  lm = lf.llms.LlamaCppRemote(url='http://localhost:8080')
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **References:**
+
+  * https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
   url: Annotated[
       str,
langfun/core/llms/openai.py CHANGED
@@ -49,6 +49,75 @@ class OpenAIModelInfo(lf.ModelInfo):
 #
 
 SUPPORTED_MODELS = [
+    # GPT-5 models
+    OpenAIModelInfo(
+        model_id='gpt-5.1',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='GPT 5.1 model (latest stable).',
+        url='https://platform.openai.com/docs/models/gpt-5.1',
+        input_modalities=OpenAIModelInfo.INPUT_IMAGE_TYPES,
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=400_000,
+            max_output_tokens=128_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_cached_input_tokens=0.13,
+            cost_per_1m_input_tokens=1.25,
+            cost_per_1m_output_tokens=10.0,
+        ),
+        # Tier 5 rate limits.
+        rate_limits=lf.ModelInfo.RateLimits(
+            max_requests_per_minute=15_000,
+            max_tokens_per_minute=40_000_000,
+        ),
+    ),
+    OpenAIModelInfo(
+        model_id='gpt-5',
+        alias_for='gpt-5-2025-08-07',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='GPT 5 model (latest stable).',
+        url='https://platform.openai.com/docs/models/gpt-5',
+        input_modalities=OpenAIModelInfo.INPUT_IMAGE_TYPES,
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=400_000,
+            max_output_tokens=128_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_cached_input_tokens=0.125,
+            cost_per_1m_input_tokens=1.25,
+            cost_per_1m_output_tokens=10.0,
+        ),
+        # Tier 5 rate limits.
+        rate_limits=lf.ModelInfo.RateLimits(
+            max_requests_per_minute=15_000,
+            max_tokens_per_minute=40_000_000,
+        ),
+    ),
+    OpenAIModelInfo(
+        model_id='gpt-5-mini',
+        alias_for='gpt-5-mini-2025-08-07',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='GPT 5 mini model (latest stable).',
+        url='https://platform.openai.com/docs/models/gpt-5-mini',
+        input_modalities=OpenAIModelInfo.INPUT_IMAGE_TYPES,
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=400_000,
+            max_output_tokens=128_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_cached_input_tokens=0.025,
+            cost_per_1m_input_tokens=0.25,
+            cost_per_1m_output_tokens=2.0,
+        ),
+        # Tier 5 rate limits.
+        rate_limits=lf.ModelInfo.RateLimits(
+            max_requests_per_minute=180_000_000,
+            max_tokens_per_minute=30_000_000,
+        ),
+    ),
     # GPT-4.1 models
     OpenAIModelInfo(
         model_id='gpt-4.1',
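
For a rough sense of the new price points, per-call cost follows directly from the per-million-token rates listed above (the token counts below are invented for illustration, and billing cached input at the cached rate is an assumption about how the rates combine):

```python
# Illustrative arithmetic only, using the gpt-5 rates above:
# $1.25 / 1M input tokens, $0.125 / 1M cached input tokens, $10.00 / 1M output tokens.
input_tokens = 12_000          # of which 4_000 hit the prompt cache
cached_input_tokens = 4_000
output_tokens = 1_500

cost = (
    (input_tokens - cached_input_tokens) * 1.25 / 1e6
    + cached_input_tokens * 0.125 / 1e6
    + output_tokens * 10.0 / 1e6
)
print(f'${cost:.4f}')  # $0.0255
```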
@@ -984,8 +1053,36 @@ _SUPPORTED_MODELS_BY_MODEL_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class OpenAI(openai_compatible.OpenAICompatible):
-  """OpenAI model."""
+class OpenAI(openai_compatible.OpenAIResponsesAPI):
+  """OpenAI models.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call GPT-4o using API key from environment variable 'OPENAI_API_KEY'.
+  lm = lf.llms.Gpt4o()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The OpenAI API key can be specified in following ways:
+
+  1. At model instantiation:
+
+  ```python
+  lm = lf.llms.Gpt4o(api_key='MY_API_KEY')
+  ```
+  2. via environment variable `OPENAI_API_KEY`.
+
+  **References:**
+
+  * https://platform.openai.com/docs/models
+  * https://platform.openai.com/docs/api-reference
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
@@ -994,7 +1091,12 @@ class OpenAI(openai_compatible.OpenAICompatible):
       'The name of the model to use.',
   ]
 
-  api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
+  # Disable message storage by default.
+  sampling_options = lf.LMSamplingOptions(
+      extras={'store': False}
+  )
+
+  api_endpoint: str = 'https://api.openai.com/v1/responses'
 
   api_key: Annotated[
       str | None,
@@ -1069,6 +1171,21 @@ class OpenAI(openai_compatible.OpenAICompatible):
     return super()._request_args(options)
 
 
+class Gpt51(OpenAI):
+  """GPT-5.1."""
+  model = 'gpt-5.1'
+
+
+class Gpt5(OpenAI):
+  """GPT-5."""
+  model = 'gpt-5'
+
+
+class Gpt5Mini(OpenAI):
+  """GPT-5 mini."""
+  model = 'gpt-5-mini'
+
+
 class Gpt41(OpenAI):
   """GPT-4.1."""
   model = 'gpt-4.1'
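
With these aliases defined, the new models can be instantiated the same way as `Gpt4o` in the docstring above. A minimal sketch (only the class names and the Responses API / `store: False` defaults come from this diff; the `lf.llms` re-export is assumed):

```python
import langfun as lf

# Sketch: the new GPT-5 aliases, assuming they are re-exported under lf.llms
# like the existing GPT-4 aliases. Requests go to the Responses API endpoint
# and carry 'store': False by default, per the changes above.
lm = lf.llms.Gpt5Mini()   # or lf.llms.Gpt5() / lf.llms.Gpt51()
r = lm('Who are you?')
print(r)
```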
langfun/core/llms/openai_compatible.py CHANGED
@@ -23,8 +23,18 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class OpenAICompatible(rest.REST):
-  """Base for OpenAI compatible models."""
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base class for models compatible with OpenAI's Chat Completion API.
+
+  This class provides a common interface for language models that adhere to
+  the OpenAI Chat Completion API format, which is used by providers like
+  Groq, DeepSeek, and others. It standardizes request formatting and
+  response parsing for these models.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/chat
+  """
 
   model: Annotated[
       str, 'The name of the model to use.',
@@ -42,12 +52,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args = dict(
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
@@ -62,6 +74,8 @@ class OpenAICompatible(rest.REST):
       args['seed'] = options.random_seed
     if options.reasoning_effort is not None:
       args['reasoning_effort'] = options.reasoning_effort
+    if options.extras:
+      args.update(options.extras)
     return args
 
   def request(
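
Together, these two changes mean default-valued options (`n=1`, unset `top_logprobs`) no longer appear in the request body, and provider-specific fields can be passed through via `extras`. A rough illustration of the resulting argument mapping (option values are arbitrary; the `temperature` key name is assumed from the surrounding code):

```python
import langfun as lf

# Rough illustration of _request_args() output after this change, for a model
# named 'my-model'. Only fields shown in the diff are used here.
options = lf.LMSamplingOptions(
    temperature=0.7,
    random_seed=1,
    extras={'store': False},  # passed through verbatim by args.update(options.extras)
)
# Expected args (n=1 and unset top_logprobs are now omitted):
# {'model': 'my-model', 'temperature': 0.7, 'seed': 1, 'store': False}
```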
@@ -72,27 +86,13 @@ class OpenAICompatible(rest.REST):
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
       request_args.update(
           response_format=dict(
               type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
+              json_schema=output_schema,
           )
       )
       prompt.metadata.formatted_text = (
@@ -118,17 +118,43 @@ class OpenAICompatible(rest.REST):
       assert isinstance(system_message, lf.SystemMessage), type(system_message)
       messages.append(
           system_message.as_format(
-              'openai', chunk_preprocessor=modality_check
+              'openai_chat_completion_api', chunk_preprocessor=modality_check
           )
       )
     messages.append(
-        prompt.as_format('openai', chunk_preprocessor=modality_check)
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
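
The extraction of `_structure_output_schema()` does not change the wire format: given a prompt whose metadata carries a `json_schema`, the Chat Completion request still gains the same `response_format` block. A small sketch of that shape (the schema itself is a made-up example; the wrapping comes from the code above):

```python
# Sketch: structured-output payload produced by request() above.
json_schema = {
    'title': 'Itinerary',
    'type': 'object',
    'properties': {'destination': {'type': 'string'}},
}
response_format = {
    'type': 'json_schema',
    'json_schema': {
        'schema': json_schema,         # returned by _structure_output_schema()
        'name': json_schema['title'],  # 'Itinerary'
        'strict': True,
    },
}
```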
@@ -144,7 +170,10 @@ class OpenAICompatible(rest.REST):
           for t in choice_logprobs['content']
       ]
     return lf.LMSample(
-        lf.Message.from_value(choice['message'], format='openai'),
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -169,3 +198,95 @@ class OpenAICompatible(rest.REST):
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base class for models compatible with OpenAI's Responses API.
+
+  This class provides a common interface for language models that adhere to
+  the new OpenAI Responses API format. It standardizes request formatting
+  and response parsing for these models, including handling instructions
+  (system messages) and structured outputs.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/responses
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )
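
To make the new parsing path concrete, below is a minimal response payload that `result()` above would accept, with the fields it reads called out (values are fabricated; the layout of the `message` item follows OpenAI's Responses API rather than anything shown in this diff):

```python
# Minimal illustrative Responses API payload for result() above.
response_json = {
    'output': [
        {'type': 'reasoning', 'summary': []},  # skipped by _parse_output()
        {
            'type': 'message',                 # first 'message' item is parsed
            'role': 'assistant',
            'content': [{'type': 'output_text', 'text': 'Hello!'}],
        },
    ],
    'usage': {
        'input_tokens': 3,
        'output_tokens': 5,
        'total_tokens': 8,
        'output_tokens_details': {'reasoning_tokens': 0},
    },
}
# result(response_json) yields an lf.LMSamplingResult with one sample wrapping
# the 'message' item (via the 'openai_responses_api' format) and usage copied
# from input_tokens / output_tokens / total_tokens.
```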