langfun 0.1.2.dev202510200805__py3-none-any.whl → 0.1.2.dev202511160804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langfun might be problematic.
- langfun/core/__init__.py +1 -0
- langfun/core/agentic/action.py +107 -12
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +25 -0
- langfun/core/async_support.py +32 -3
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +1 -0
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +9 -2
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +1 -0
- langfun/core/eval/v2/checkpointing.py +39 -5
- langfun/core/eval/v2/checkpointing_test.py +1 -1
- langfun/core/eval/v2/eval_test_helper.py +97 -1
- langfun/core/eval/v2/evaluation.py +88 -16
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +45 -39
- langfun/core/eval/v2/example_test.py +3 -3
- langfun/core/eval/v2/experiment.py +51 -8
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +30 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +20 -6
- langfun/core/eval/v2/runners/__init__.py +26 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +22 -124
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +79 -0
- langfun/core/eval/v2/runners/parallel.py +100 -0
- langfun/core/eval/v2/runners/parallel_test.py +98 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +175 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +103 -16
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +7 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +14 -9
- langfun/core/llms/google_genai.py +29 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +36 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +12 -1
- langfun/core/llms/vertexai.py +51 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/client.py +77 -22
- langfun/core/mcp/client_test.py +8 -35
- langfun/core/mcp/session.py +94 -29
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/tool.py +151 -22
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +19 -1
- langfun/core/modalities/mime.py +62 -3
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +215 -142
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/structured/schema/__init__.py +48 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +8 -2
- langfun/env/base_environment.py +320 -128
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +92 -15
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +84 -361
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +1 -1
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +95 -98
- langfun/env/event_handlers/event_logger_test.py +21 -21
- langfun/env/event_handlers/metric_writer.py +225 -140
- langfun/env/event_handlers/metric_writer_test.py +23 -6
- langfun/env/interface.py +854 -40
- langfun/env/interface_test.py +112 -2
- langfun/env/load_balancers_test.py +23 -2
- langfun/env/test_utils.py +126 -84
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/METADATA +1 -1
- langfun-0.1.2.dev202511160804.dist-info/RECORD +211 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun/env/base_test.py +0 -1481
- langfun/env/event_handlers/base.py +0 -350
- langfun-0.1.2.dev202510200805.dist-info/RECORD +0 -195
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510200805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/top_level.txt +0 -0
langfun/core/llms/deepseek.py
CHANGED
@@ -93,8 +93,36 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 # DeepSeek API uses an API format compatible with OpenAI.
 # Reference: https://api-docs.deepseek.com/
 @lf.use_init_args(['model'])
-class DeepSeek(openai_compatible.
-  """DeepSeek
+class DeepSeek(openai_compatible.OpenAIChatCompletionAPI):
+  """DeepSeek models.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call DeepSeek-V3 using API key from environment variable
+  # 'DEEPSEEK_API_KEY'.
+  lm = lf.llms.DeepSeekV3()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The DeepSeek API key can be specified in following ways:
+
+  1. At model instantiation:
+
+     ```python
+     lm = lf.llms.DeepSeekV3(api_key='MY_API_KEY')
+     ```
+  2. via environment variable `DEEPSEEK_API_KEY`.
+
+  **References:**
+
+  * https://api-docs.deepseek.com/
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
langfun/core/llms/fake.py
CHANGED
@@ -20,7 +20,38 @@ import langfun.core as lf
 
 
 class Fake(lf.LanguageModel):
-  """
+  """Base class for fake language models, used for testing.
+
+  Fake models simulate the behavior of real language models but return
+  pre-defined responses, making them useful for testing prompts,
+  data processing logic, and agent behavior without incurring API costs
+  or relying on external services.
+
+  Langfun provides several fake models:
+  * `lf.llms.Echo`: Echoes the prompt back as the response.
+  * `lf.llms.StaticResponse`: Returns a fixed, pre-defined response for
+    any prompt.
+  * `lf.llms.StaticMapping`: Returns responses based on a prompt-to-response
+    dictionary.
+  * `lf.llms.StaticSequence`: Returns responses from a pre-defined sequence
+    in order.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  # Use Echo model for testing
+  lm = lf.llms.Echo()
+  response = lm('hello')
+  assert response.text == 'hello'
+
+  # Use StaticResponse model
+  lm = lf.llms.StaticResponse('world')
+  response = lm('hello')
+  assert response.text == 'world'
+  ```
+  """
 
   def _score(self, prompt: lf.Message | list[lf.Message],
              completions: list[lf.Message]):
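The docstring above demonstrates `Echo` and `StaticResponse`; `StaticSequence` is the usual choice when a test makes several LM calls. A minimal sketch, assuming the list of canned responses is its first init argument, as the docstring's ordering description suggests:

```python
import langfun as lf

# StaticSequence hands out pre-defined responses in order, one per call.
lm = lf.llms.StaticSequence(['first reply', 'second reply'])
assert lm('question 1').text == 'first reply'
assert lm('question 2').text == 'second reply'
```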
langfun/core/llms/gemini.py
CHANGED
@@ -696,7 +696,15 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 @pg.use_init_args(['model'])
 class Gemini(rest.REST):
-  """
+  """Base class for Gemini models served on Google GenAI and Vertex AI.
+
+  This class implements the Gemini API protocol, shared by
+  `lf.llms.GoogleGenAI` and `lf.llms.VertexAI`, providing common request
+  formatting and response parsing for Gemini models.
+
+  It is not intended to be used directly. Please use `lf.llms.GoogleGenAI` or
+  `lf.llms.VertexAI` instead.
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
@@ -752,11 +760,8 @@ class Gemini(rest.REST):
         prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
     )
     request['contents'] = contents
-
-
-    # metadata_gemini_tools=[{'google_search': {}}]
-    if tools := prompt.metadata.get('gemini_tools'):
-      request['tools'] = tools
+    if sampling_options.extras:
+      request.update(sampling_options.extras)
     return request
 
   def _generation_config(
@@ -833,9 +838,9 @@ class Gemini(rest.REST):
     )
 
   def _error(self, status_code: int, content: str) -> lf.LMError:
-    if (
-
-
+    if status_code == 400 and (
+        b'exceeds the maximum number of tokens' in content
+        or b'Reduce the input token count and try again.' in content
     ):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
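The request hunk above drops the `gemini_tools` metadata path: provider-specific request fields now travel in `sampling_options.extras`, which is merged verbatim into the REST request. A minimal sketch of passing the Google Search tool this way, assuming extras keys map one-to-one onto Gemini request fields (which is what `request.update(sampling_options.extras)` does); the model class is illustrative:

```python
import langfun as lf

# Previously: metadata_gemini_tools=[{'google_search': {}}].
# Now the same payload rides along in sampling_options.extras and is
# merged as a top-level request field ('tools').
lm = lf.llms.Gemini15Flash(
    sampling_options=lf.LMSamplingOptions(
        extras={'tools': [{'google_search': {}}]}
    )
)
print(lm('Who won the most recent F1 race?'))
```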
langfun/core/llms/google_genai.py
CHANGED
@@ -25,7 +25,35 @@ import pyglove as pg
 @lf.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
 class GenAI(gemini.Gemini):
-  """
+  """Google GenAI models.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Gemini 1.5 Flash using API key from environment variable
+  # 'GOOGLE_API_KEY'.
+  lm = lf.llms.Gemini15Flash()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The Google API key can be specified in following ways:
+
+  1. At model instantiation:
+
+     ```python
+     lm = lf.llms.Gemini15Flash(api_key='MY_API_KEY')
+     ```
+  2. via environment variable `GOOGLE_API_KEY`.
+
+  **References:**
+
+  * https://ai.google.dev/docs
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
langfun/core/llms/groq.py
CHANGED
@@ -259,10 +259,35 @@ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class Groq(openai_compatible.
-  """Groq
+class Groq(openai_compatible.OpenAIChatCompletionAPI):
+  """Groq models.
 
-
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call Llama 3.3 70B on Groq using API key from environment variable
+  # 'GROQ_API_KEY'.
+  lm = lf.llms.GroqLlama33_70B_Versatile()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The Groq API key can be specified in following ways:
+
+  1. At model instantiation:
+
+     ```python
+     lm = lf.llms.GroqLlama33_70B_Versatile(api_key='MY_API_KEY')
+     ```
+  2. via environment variable `GROQ_API_KEY`.
+
+  **References:**
+
+  * https://console.groq.com/docs
   """
 
   model: pg.typing.Annotated[
langfun/core/llms/llama_cpp.py
CHANGED
@@ -20,11 +20,30 @@ import pyglove as pg
 
 @pg.use_init_args(['url', 'model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class LlamaCppRemote(openai_compatible.
-  """
+class LlamaCppRemote(openai_compatible.OpenAIChatCompletionAPI):
+  """LLaMA C++ models served via a remote server.
 
-
-
+  This class provides an interface to interact with language models
+  hosted on a LLaMA C++ server, which is compatible with the OpenAI
+  Chat Completions API format.
+
+  **Quick Start:**
+
+  Assuming a LLaMA C++ server is running at `http://localhost:8080`,
+  you can interact with it as follows:
+
+  ```python
+  import langfun as lf
+
+  # If model name is not specified, it will use server's default.
+  lm = lf.llms.LlamaCppRemote(url='http://localhost:8080')
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **References:**
+
+  * https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
   url: Annotated[
       str,
langfun/core/llms/openai.py
CHANGED
@@ -1031,8 +1031,36 @@ _SUPPORTED_MODELS_BY_MODEL_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
-class OpenAI(openai_compatible.
-  """OpenAI
+class OpenAI(openai_compatible.OpenAIResponsesAPI):
+  """OpenAI models.
+
+  **Quick Start:**
+
+  ```python
+  import langfun as lf
+
+  # Call GPT-4o using API key from environment variable 'OPENAI_API_KEY'.
+  lm = lf.llms.Gpt4o()
+  r = lm('Who are you?')
+  print(r)
+  ```
+
+  **Setting up API key:**
+
+  The OpenAI API key can be specified in following ways:
+
+  1. At model instantiation:
+
+     ```python
+     lm = lf.llms.Gpt4o(api_key='MY_API_KEY')
+     ```
+  2. via environment variable `OPENAI_API_KEY`.
+
+  **References:**
+
+  * https://platform.openai.com/docs/models
+  * https://platform.openai.com/docs/api-reference
+  """
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
@@ -1041,7 +1069,12 @@ class OpenAI(openai_compatible.OpenAICompatible):
       'The name of the model to use.',
   ]
 
-
+  # Disable message storage by default.
+  sampling_options = lf.LMSamplingOptions(
+      extras={'store': False}
+  )
+
+  api_endpoint: str = 'https://api.openai.com/v1/responses'
 
   api_key: Annotated[
       str | None,
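With this change, `OpenAI` rides on the Responses API and sends `store: False` by default via `sampling_options.extras`, which `_request_args` merges into every request (see the openai_compatible.py hunks below). A sketch of opting back into message storage, assuming a constructor-level `sampling_options` replaces the class default:

```python
import langfun as lf

# Default: OpenAI models now request `store: False` (no message storage).
lm = lf.llms.Gpt4o()

# Opting back in: supply extras that override the class default.
lm_storing = lf.llms.Gpt4o(
    sampling_options=lf.LMSamplingOptions(extras={'store': True})
)
```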
langfun/core/llms/openai_compatible.py
CHANGED
@@ -23,8 +23,18 @@ import pyglove as pg
 
 
 @lf.use_init_args(['api_endpoint', 'model'])
-class
-  """Base for
+class OpenAIChatCompletionAPI(rest.REST):
+  """Base class for models compatible with OpenAI's Chat Completion API.
+
+  This class provides a common interface for language models that adhere to
+  the OpenAI Chat Completion API format, which is used by providers like
+  Groq, DeepSeek, and others. It standardizes request formatting and
+  response parsing for these models.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/chat
+  """
 
   model: Annotated[
       str, 'The name of the model to use.',
@@ -42,12 +52,14 @@ class OpenAICompatible(rest.REST):
     # Reference:
     # https://platform.openai.com/docs/api-reference/completions/create
     # NOTE(daiyip): options.top_k is not applicable.
-    args =
-
-      top_logprobs=options.top_logprobs,
-    )
+    args = {}
+
     if self.model:
       args['model'] = self.model
+    if options.n != 1:
+      args['n'] = options.n
+    if options.top_logprobs is not None:
+      args['top_logprobs'] = options.top_logprobs
     if options.logprobs:
       args['logprobs'] = options.logprobs
     if options.temperature is not None:
@@ -62,6 +74,8 @@ class OpenAICompatible(rest.REST):
       args['seed'] = options.random_seed
     if options.reasoning_effort is not None:
      args['reasoning_effort'] = options.reasoning_effort
+    if options.extras:
+      args.update(options.extras)
     return args
 
   def request(
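`_request_args` now emits `n` and `top_logprobs` only when they deviate from the defaults, and merges `options.extras` last, so extras can both add provider-specific fields and override computed ones. A sketch against a Chat Completion-style model; the extra key is illustrative and must be valid for the target provider:

```python
import langfun as lf

# extras is applied after all standard arguments, so it can add fields
# the options object has no first-class slot for.
lm = lf.llms.DeepSeekV3(
    sampling_options=lf.LMSamplingOptions(
        temperature=0.7,
        extras={'frequency_penalty': 0.5},
    )
)
```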
@@ -72,27 +86,13 @@ class OpenAICompatible(rest.REST):
     """Returns the JSON input for a message."""
     request_args = self._request_args(sampling_options)
 
-    #
-
-
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
       request_args.update(
           response_format=dict(
               type='json_schema',
-              json_schema=
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
+              json_schema=output_schema,
           )
       )
       prompt.metadata.formatted_text = (
@@ -118,17 +118,43 @@ class OpenAICompatible(rest.REST):
     assert isinstance(system_message, lf.SystemMessage), type(system_message)
     messages.append(
         system_message.as_format(
-            '
+            'openai_chat_completion_api', chunk_preprocessor=modality_check
         )
     )
     messages.append(
-        prompt.as_format(
+        prompt.as_format(
+            'openai_chat_completion_api',
+            chunk_preprocessor=modality_check
+        )
     )
     request = dict()
     request.update(request_args)
     request['messages'] = messages
     return request
 
+  def _structure_output_schema(
+      self, prompt: lf.Message
+  ) -> dict[str, Any] | None:
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      if 'title' not in json_schema:
+        raise ValueError(
+            f'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      return dict(
+          schema=json_schema,
+          name=json_schema['title'],
+          strict=True,
+      )
+    return None
+
   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
     # Reference:
     # https://platform.openai.com/docs/api-reference/chat/object
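The new `_structure_output_schema` helper centralizes what `request` used to do inline: it reads `json_schema` from the prompt's metadata, validates that it is a dict whose root carries a `title`, and wraps it with `name` and `strict` for the provider. A sketch of supplying a schema directly on a message; attaching it via the `metadata` field is an assumption about the calling convention:

```python
import langfun as lf

# The schema root must be a dict with a `title`, which becomes the
# provider-facing schema name (see `name=json_schema['title']` above).
prompt = lf.UserMessage(
    'Extract city and population from: "Tokyo has 37M people."',
    metadata=dict(
        json_schema={
            'title': 'CityInfo',
            'type': 'object',
            'properties': {
                'city': {'type': 'string'},
                'population': {'type': 'integer'},
            },
        }
    ),
)
```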
@@ -144,7 +170,10 @@ class OpenAICompatible(rest.REST):
           for t in choice_logprobs['content']
       ]
     return lf.LMSample(
-        lf.Message.from_value(
+        lf.Message.from_value(
+            choice['message'],
+            format='openai_chat_completion_api'
+        ),
         score=0.0,
         logprobs=logprobs,
     )
@@ -169,3 +198,95 @@ class OpenAICompatible(rest.REST):
         or (status_code == 400 and b'string_above_max_length' in content)):
       return lf.ContextLimitError(f'{status_code}: {content}')
     return super()._error(status_code, content)
+
+
+class OpenAIResponsesAPI(OpenAIChatCompletionAPI):
+  """Base class for models compatible with OpenAI's Responses API.
+
+  This class provides a common interface for language models that adhere to
+  the new OpenAI Responses API format. It standardizes request formatting
+  and response parsing for these models, including handling instructions
+  (system messages) and structured outputs.
+
+  **References:**
+
+  * https://platform.openai.com/docs/api-reference/responses
+  """
+
+  def _request_args(
+      self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    if options.logprobs:
+      raise ValueError('logprobs is not supported on Responses API.')
+    if options.n != 1:
+      raise ValueError('n must be 1 for Responses API.')
+    return super()._request_args(options)
+
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Handle structured output.
+    output_schema = self._structure_output_schema(prompt)
+    if output_schema is not None:
+      output_schema['type'] = 'json_schema'
+      request_args.update(text=dict(format=output_schema))
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['text'], json_indent=2)
+      )
+
+    request = dict()
+    request.update(request_args)
+
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      assert isinstance(system_message, lf.SystemMessage), type(system_message)
+      request['instructions'] = system_message.text
+
+    # Prepare input.
+    def modality_check(chunk: str | lf.Modality) -> Any:
+      if (isinstance(chunk, lf_modalities.Mime)
+          and not self.supports_input(chunk.mime_type)):
+        raise ValueError(
+            f'Unsupported modality: {chunk!r}.'
+        )
+      return chunk
+
+    request['input'] = [
+        prompt.as_format(
+            'openai_responses_api',
+            chunk_preprocessor=modality_check
+        )
+    ]
+    return request
+
+  def _parse_output(self, output: dict[str, Any]) -> lf.LMSample:
+    for item in output:
+      if isinstance(item, dict) and item.get('type') == 'message':
+        return lf.LMSample(
+            lf.Message.from_value(item, format='openai_responses_api'),
+            score=0.0,
+        )
+    raise ValueError('No message found in output.')
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    """Returns a LMSamplingResult from a JSON response."""
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_output(json['output'])],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['input_tokens'],
+            completion_tokens=usage['output_tokens'],
+            total_tokens=usage['total_tokens'],
+            completion_tokens_details=usage.get(
+                'output_tokens_details', None
+            ),
+        ),
+    )