langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of langfun has been flagged as potentially problematic.
- langfun/core/concurrent_test.py +1 -0
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +134 -30
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +15 -6
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +7 -1
- langfun/core/llms/anthropic.py +130 -0
- langfun/core/llms/cache/base.py +3 -1
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/deepseek.py +1 -1
- langfun/core/llms/gemini.py +2 -5
- langfun/core/llms/groq.py +1 -1
- langfun/core/llms/llama_cpp.py +1 -1
- langfun/core/llms/openai.py +7 -2
- langfun/core/llms/openai_compatible.py +136 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/vertexai.py +12 -2
- langfun/core/message.py +78 -44
- langfun/core/message_test.py +56 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/mime.py +9 -0
- langfun/core/modality.py +104 -27
- langfun/core/modality_test.py +42 -12
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +2 -7
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/mapping.py +4 -13
- langfun/core/structured/querying.py +13 -11
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/template.py +39 -13
- langfun/core/template_test.py +83 -17
- langfun/env/event_handlers/metric_writer_test.py +3 -3
- langfun/env/load_balancers_test.py +2 -2
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +44 -44
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/language_model_test.py
CHANGED

@@ -656,11 +656,17 @@ class LanguageModelTest(unittest.TestCase):
 
     string_io = io.StringIO()
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
+    image = Image()
     with contextlib.redirect_stdout(string_io):
      self.assertEqual(
-          lm(…
-          …
-          …
+          lm(
+              message_lib.UserMessage(
+                  f'hi <<[[{image.id}]]>>',
+                  referred_modalities=[image],
+              ),
+              debug=True
+          ),
+          f'hi <<[[{image.id}]]>>'
      )
 
     debug_info = string_io.getvalue()
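Note: the updated test exercises the new message-level modality API. A minimal sketch of the pattern, assuming only what the diff shows (an `lf.Modality` subclass, its `id` property, `referred_modalities`, and the `<<[[...]]>>` placeholder syntax):

    import langfun as lf

    # Toy modality fixture; any lf.Modality subclass with to_bytes() works.
    class CustomModality(lf.Modality):
      content: str

      def to_bytes(self) -> bytes:
        return self.content.encode()

    image = CustomModality('foo')
    message = lf.UserMessage(
        f'hi <<[[{image.id}]]>>',      # inline reference by modality id
        referred_modalities=[image],   # attach the referenced object
    )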
langfun/core/llms/__init__.py
CHANGED

@@ -30,7 +30,8 @@ from langfun.core.llms.compositional import RandomChoice
 
 # Base models by request/response protocol.
 from langfun.core.llms.rest import REST
-from langfun.core.llms.openai_compatible import OpenAICompatible
+from langfun.core.llms.openai_compatible import OpenAIChatCompletionAPI
+from langfun.core.llms.openai_compatible import OpenAIResponsesAPI
 from langfun.core.llms.gemini import Gemini
 from langfun.core.llms.anthropic import Anthropic
 
@@ -151,6 +152,9 @@ from langfun.core.llms.openai import Gpt35
 
 # Anthropic models.
 
+from langfun.core.llms.anthropic import Claude45
+from langfun.core.llms.anthropic import Claude45Haiku_20251001
+from langfun.core.llms.anthropic import Claude45Sonnet_20250929
 from langfun.core.llms.anthropic import Claude4
 from langfun.core.llms.anthropic import Claude4Sonnet_20250514
 from langfun.core.llms.anthropic import Claude4Opus_20250514
@@ -168,6 +172,8 @@ from langfun.core.llms.anthropic import Claude3Haiku
 from langfun.core.llms.anthropic import Claude3Haiku_20240307
 
 from langfun.core.llms.vertexai import VertexAIAnthropic
+from langfun.core.llms.vertexai import VertexAIClaude45Haiku_20251001
+from langfun.core.llms.vertexai import VertexAIClaude45Sonnet_20250929
 from langfun.core.llms.vertexai import VertexAIClaude4Opus_20250514
 from langfun.core.llms.vertexai import VertexAIClaude4Sonnet_20250514
 from langfun.core.llms.vertexai import VertexAIClaude37Sonnet_20250219
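Note: the newly exported aliases are usable directly from `langfun.core.llms`. A sketch (assuming `ANTHROPIC_API_KEY` is set for the first two, and Vertex AI credentials plus placeholder project/location values for the last):

    from langfun.core import llms

    sonnet = llms.Claude45Sonnet_20250929()
    haiku = llms.Claude45Haiku_20251001()
    vertex_sonnet = llms.VertexAIClaude45Sonnet_20250929(
        project='my-project', location='us-east5'
    )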
langfun/core/llms/anthropic.py
CHANGED

@@ -59,6 +59,60 @@ class AnthropicModelInfo(lf.ModelInfo):
 
 
 SUPPORTED_MODELS = [
+    AnthropicModelInfo(
+        model_id='claude-haiku-4-5-20251001',
+        provider='Anthropic',
+        in_service=True,
+        description='Claude 4.5 Haiku model (10/15/2025).',
+        release_date=datetime.datetime(2025, 10, 15),
+        input_modalities=(
+            AnthropicModelInfo.INPUT_IMAGE_TYPES
+            + AnthropicModelInfo.INPUT_DOC_TYPES
+        ),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=200_000,
+            max_output_tokens=64_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_cached_input_tokens=0.1,
+            cost_per_1m_input_tokens=1,
+            cost_per_1m_output_tokens=5,
+        ),
+        rate_limits=AnthropicModelInfo.RateLimits(
+            # Tier 4 rate limits
+            max_requests_per_minute=4000,
+            max_input_tokens_per_minute=4_000_000,
+            max_output_tokens_per_minute=800_000,
+        ),
+    ),
+    AnthropicModelInfo(
+        model_id='claude-sonnet-4-5-20250929',
+        provider='Anthropic',
+        in_service=True,
+        description='Claude 4.5 Sonnet model (9/29/2025).',
+        release_date=datetime.datetime(2025, 9, 29),
+        input_modalities=(
+            AnthropicModelInfo.INPUT_IMAGE_TYPES
+            + AnthropicModelInfo.INPUT_DOC_TYPES
+        ),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=200_000,
+            max_output_tokens=64_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_cached_input_tokens=0.3,
+            cost_per_1m_input_tokens=3,
+            cost_per_1m_output_tokens=15,
+        ),
+        rate_limits=AnthropicModelInfo.RateLimits(
+            # Tier 4 rate limits
+            # This rate limit is a total limit that applies to combined traffic
+            # across both Sonnet 4 and Sonnet 4.5.
+            max_requests_per_minute=4000,
+            max_input_tokens_per_minute=2_000_000,
+            max_output_tokens_per_minute=400_000,
+        ),
+    ),
     AnthropicModelInfo(
         model_id='claude-4-opus-20250514',
         provider='Anthropic',
@@ -190,6 +244,62 @@ SUPPORTED_MODELS = [
             max_output_tokens_per_minute=80_000,
         ),
     ),
+    AnthropicModelInfo(
+        model_id='claude-haiku-4-5@20251001',
+        alias_for='claude-haiku-4-5-20251001',
+        provider='VertexAI',
+        in_service=True,
+        description='Claude 4.5 Haiku model served on VertexAI (10/15/2025).',
+        release_date=datetime.datetime(2025, 10, 15),
+        input_modalities=(
+            AnthropicModelInfo.INPUT_IMAGE_TYPES
+            + AnthropicModelInfo.INPUT_DOC_TYPES
+        ),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=200_000,
+            max_output_tokens=64_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            # For global endpoint
+            cost_per_1m_cached_input_tokens=0.1,
+            cost_per_1m_input_tokens=1,
+            cost_per_1m_output_tokens=5,
+        ),
+        rate_limits=AnthropicModelInfo.RateLimits(
+            # For global endpoint
+            max_requests_per_minute=2500,
+            max_input_tokens_per_minute=200_000,
+            max_output_tokens_per_minute=0,
+        ),
+    ),
+    AnthropicModelInfo(
+        model_id='claude-sonnet-4-5@20250929',
+        alias_for='claude-sonnet-4-5-20250929',
+        provider='VertexAI',
+        in_service=True,
+        description='Claude 4.5 Sonnet model (9/29/2025).',
+        release_date=datetime.datetime(2025, 9, 29),
+        input_modalities=(
+            AnthropicModelInfo.INPUT_IMAGE_TYPES
+            + AnthropicModelInfo.INPUT_DOC_TYPES
+        ),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=200_000,
+            max_output_tokens=64_000,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            # For global endpoint
+            cost_per_1m_cached_input_tokens=0.3,
+            cost_per_1m_input_tokens=3,
+            cost_per_1m_output_tokens=15,
+        ),
+        rate_limits=AnthropicModelInfo.RateLimits(
+            # For global endpoint
+            max_requests_per_minute=1500,
+            max_input_tokens_per_minute=200_000,
+            max_output_tokens_per_minute=0,
+        ),
+    ),
     AnthropicModelInfo(
         model_id='claude-opus-4@20250514',
         alias_for='claude-opus-4-20250514',
@@ -658,6 +768,8 @@ class Anthropic(rest.REST):
     args.pop('temperature', None)
     args.pop('top_k', None)
     args.pop('top_p', None)
+    if options.extras:
+      args.update(options.extras)
     return args
 
   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
@@ -679,6 +791,24 @@ class Anthropic(rest.REST):
     return super()._error(status_code, content)
 
 
+class Claude45(Anthropic):
+  """Base class for Claude 4.5 models."""
+
+
+# pylint: disable=invalid-name
+class Claude45Haiku_20251001(Claude45):
+  """Claude 4.5 Haiku model 20251001."""
+
+  model = 'claude-haiku-4-5-20251001'
+
+
+# pylint: disable=invalid-name
+class Claude45Sonnet_20250929(Claude45):
+  """Claude 4.5 Sonnet model 20250929."""
+
+  model = 'claude-sonnet-4-5-20250929'
+
+
 class Claude4(Anthropic):
   """Base class for Claude 4 models."""
 
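Note: with `options.extras` now merged into the Anthropic request body verbatim, provider-specific fields can be passed without waiting for a langfun release. A sketch; the `thinking` payload is illustrative, not something this diff prescribes:

    import langfun as lf

    lm = lf.llms.Claude45Haiku_20251001(
        sampling_options=lf.LMSamplingOptions(
            # Merged as-is into the request body by _request_args.
            extras={'thinking': {'type': 'enabled', 'budget_tokens': 1024}}
        )
    )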
langfun/core/llms/cache/base.py
CHANGED

@@ -121,4 +121,6 @@ class LMCacheBase(lf.LMCache):
 
 def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
   """Default key for LM cache."""
-  …
+  # prompt text already contains the modality id for referenced modality
+  # objects, so no need to include them in the key.
+  return (prompt.text, lm.sampling_options.cache_key(), seed)
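Note: the simplified key is just (prompt text, sampling-options key, seed); modality identity is already carried by the `<<[[id]]>>` placeholders embedded in the text. A sketch using langfun's fake Echo model:

    import langfun as lf
    from langfun.core.llms import fake
    from langfun.core.llms.cache import base as cache_base

    lm = fake.Echo()
    key = cache_base.default_key(lm, lf.UserMessage('hi'), 0)
    # key == ('hi', lm.sampling_options.cache_key(), 0)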
langfun/core/llms/cache/in_memory_test.py
CHANGED

@@ -175,18 +175,28 @@ class InMemoryLMCacheTest(unittest.TestCase):
 
     cache = in_memory.InMemory()
     lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
-    …
-    …
+    image_foo = CustomModality('foo')
+    image_bar = CustomModality('bar')
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_foo.id}]]>>', referred_modalities=[image_foo]
+        )
+    )
+    lm(
+        lf.UserMessage(
+            f'hi <<[[{image_bar.id}]]>>', referred_modalities=[image_bar]
+        )
+    )
     self.assertEqual(
         list(cache.keys()),
         [
            (
-                'hi <<[[…
+                f'hi <<[[{image_foo.id}]]>>',
                (None, None, 1, 40, None, None),
                0,
            ),
            (
-                'hi <<[[…
+                f'hi <<[[{image_bar.id}]]>>',
                (None, None, 1, 40, None, None),
                0,
            ),
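Note: the reworked test pins down the cache contract: the same rendered text means a hit, and distinct modalities produce distinct texts (via their ids), hence distinct keys. A minimal usage sketch:

    import langfun as lf
    from langfun.core.llms import fake
    from langfun.core.llms.cache import in_memory

    cache = in_memory.InMemory()
    lm = fake.StaticSequence(['1', '2'], cache=cache)
    assert lm('hello').text == '1'
    assert lm('hello').text == '1'   # second call is served from the cache
    assert len(list(cache.keys())) == 1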
langfun/core/llms/deepseek.py
CHANGED

@@ -93,7 +93,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 # DeepSeek API uses an API format compatible with OpenAI.
 # Reference: https://api-docs.deepseek.com/
 @lf.use_init_args(['model'])
-class DeepSeek(openai_compatible.OpenAICompatible):
+class DeepSeek(openai_compatible.OpenAIChatCompletionAPI):
   """DeepSeek model."""
 
   model: pg.typing.Annotated[
langfun/core/llms/gemini.py
CHANGED

@@ -752,11 +752,8 @@ class Gemini(rest.REST):
         prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
     )
     request['contents'] = contents
-    …
-    …
-    # metadata_gemini_tools=[{'google_search': {}}]
-    if tools := prompt.metadata.get('gemini_tools'):
-      request['tools'] = tools
+    if sampling_options.extras:
+      request.update(sampling_options.extras)
     return request
 
   def _generation_config(
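Note: the `metadata_gemini_tools` hook is removed; tools (and any other request-level field) now travel through `sampling_options.extras`, which is merged into the request body verbatim. A sketch of the equivalent call, assuming the key name matches the REST field the removed code populated ('tools'):

    import langfun as lf

    options = lf.LMSamplingOptions(
        extras={'tools': [{'google_search': {}}]}
    )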
langfun/core/llms/groq.py
CHANGED

@@ -259,7 +259,7 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class Groq(openai_compatible.OpenAICompatible):
+class Groq(openai_compatible.OpenAIChatCompletionAPI):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
langfun/core/llms/llama_cpp.py
CHANGED

@@ -20,7 +20,7 @@ import pyglove as pg
 
 @pg.use_init_args(['url', 'model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class LlamaCppRemote(openai_compatible.OpenAICompatible):
+class LlamaCppRemote(openai_compatible.OpenAIChatCompletionAPI):
   """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
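Note: DeepSeek, Groq and LlamaCppRemote are mechanical renames onto the new `OpenAIChatCompletionAPI` base. A hypothetical subclass for a self-hosted OpenAI-compatible server follows the same pattern (endpoint and model name below are made up):

    from langfun.core.llms import openai_compatible

    class MyLocalServer(openai_compatible.OpenAIChatCompletionAPI):
      """OpenAI-compatible model served from a local endpoint."""

      model = 'my-local-model'
      api_endpoint: str = 'http://localhost:8080/v1/chat/completions'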
langfun/core/llms/openai.py
CHANGED

@@ -1031,7 +1031,7 @@ _SUPPORTED_MODELS_BY_MODEL_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class OpenAI(openai_compatible.OpenAICompatible):
+class OpenAI(openai_compatible.OpenAIResponsesAPI):
   """OpenAI model."""
 
   model: pg.typing.Annotated[
@@ -1041,7 +1041,12 @@ class OpenAI(openai_compatible.OpenAICompatible):
       'The name of the model to use.',
   ]
 
-  api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
+  # Disable message storage by default.
+  sampling_options = lf.LMSamplingOptions(
+      extras={'store': False}
+  )
+
+  api_endpoint: str = 'https://api.openai.com/v1/responses'
 
   api_key: Annotated[
       str | None,
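Note: `OpenAI` now targets the Responses API and opts out of server-side response storage by default (`store: False`). Re-enabling storage is just an extras override; 'gpt-4o' below is only an illustrative model id:

    import langfun as lf

    lm = lf.llms.OpenAI(
        model='gpt-4o',
        sampling_options=lf.LMSamplingOptions(extras={'store': True}),
    )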
langfun/core/llms/openai_compatible.py
CHANGED

@@ -23,8 +23,13 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class OpenAICompatible(rest.REST):
-  """Base for OpenAI compatible models."""
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base for OpenAI compatible models based on ChatCompletion API.
+
+  See https://platform.openai.com/docs/api-reference/chat
+  As of 2025-10-23, OpenAI is migrating from ChatCompletion API to Responses
+  API.
+  """
 
   model: Annotated[
       str, 'The name of the model to use.',
@@ -42,12 +47,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args = dict(
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
@@ -62,6 +69,8 @@ class OpenAICompatible(rest.REST):
       args['seed'] = options.random_seed
     if options.reasoning_effort is not None:
       args['reasoning_effort'] = options.reasoning_effort
+    if options.extras:
+      args.update(options.extras)
     return args
 
   def request(
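Note: after the rewrite, defaults are no longer sent unconditionally: 'n' is emitted only when it differs from 1, 'top_logprobs' only when set, and `extras` is merged last so it can override anything. Roughly (a sketch; the exact contents depend on which options are set):

    import langfun as lf

    options = lf.LMSamplingOptions(temperature=0.7, extras={'stream': False})
    # _request_args(options) -> {'model': <model>, 'temperature': 0.7,
    #                            'stream': False, ...}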
@@ -72,27 +81,13 @@
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
      request_args.update(
          response_format=dict(
              type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
+              json_schema=output_schema,
          )
      )
      prompt.metadata.formatted_text = (
@@ -118,17 +113,43 @@
     assert isinstance(system_message, lf.SystemMessage), type(system_message)
     messages.append(
         system_message.as_format(
-            'openai', chunk_preprocessor=modality_check
+            'openai_chat_completion_api', chunk_preprocessor=modality_check
         )
     )
     messages.append(
-        prompt.as_format('openai', chunk_preprocessor=modality_check)
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
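Note: schema validation now lives in `_structure_output_schema`, shared by both API flavors. A sketch of the metadata contract (the schema below is illustrative; only the 'title' field is required by the validation):

    import langfun as lf

    schema = {
        'title': 'Answer',
        'type': 'object',
        'properties': {'result': {'type': 'string'}},
        'required': ['result'],
    }
    prompt = lf.UserMessage('Extract the answer.', json_schema=schema)
    # ChatCompletion API: request['response_format'] = {
    #     'type': 'json_schema',
    #     'json_schema': {'schema': schema, 'name': 'Answer', 'strict': True}}
    # Responses API: request['text'] = {'format': {...same dict...,
    #     'type': 'json_schema'}}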
@@ -144,7 +165,10 @@
           for t in choice_logprobs['content']
       ]
     return lf.LMSample(
-        lf.Message.from_value(choice['message']),
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -169,3 +193,88 @@
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base for OpenAI compatible models based on Responses API.
+
+  https://platform.openai.com/docs/api-reference/responses/create
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )