langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/llms/fake.py
CHANGED
```diff
@@ -21,9 +21,13 @@ import langfun.core as lf
 class Fake(lf.LanguageModel):
   """The base class for all fake language models."""
 
-  def _score(self, prompt: lf.Message, completions: list[lf.Message]):
+  def _score(self, prompt: lf.Message | list[lf.Message],
+             completions: list[lf.Message]):
     return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
 
+  def _tokenize(self, prompt: lf.Message) -> list[tuple[str | bytes, int]]:
+    return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     results = []
     for prompt in prompts:
@@ -57,12 +61,12 @@ class StaticResponse(Fake):
   """Language model that always gives the same canned response."""
 
   response: Annotated[
-      str,
+      str | lf.Message,
       'A canned response that will be returned regardless of the prompt.'
   ]
 
   def _response_from(self, prompt: lf.Message) -> lf.Message:
-    return lf.AIMessage(self.response)
+    return lf.AIMessage.from_value(self.response)
 
 
 @lf.use_init_args(['mapping'])
@@ -70,12 +74,12 @@ class StaticMapping(Fake):
   """A static mapping from prompt to response."""
 
   mapping: Annotated[
-      dict[str, str],
+      dict[str, str | lf.Message],
       'A mapping from prompt to response.'
   ]
 
   def _response_from(self, prompt: lf.Message) -> lf.Message:
-    return lf.AIMessage(self.mapping[prompt])
+    return lf.AIMessage.from_value(self.mapping[prompt])
 
 
 @lf.use_init_args(['sequence'])
@@ -83,7 +87,7 @@ class StaticSequence(Fake):
   """A static sequence of responses to use."""
 
   sequence: Annotated[
-      list[str],
+      list[str | lf.Message],
       'A sequence of strings as the response.'
   ]
 
@@ -92,6 +96,6 @@ class StaticSequence(Fake):
     self._pos = 0
 
   def _response_from(self, prompt: lf.Message) -> lf.Message:
-    r = lf.AIMessage(self.sequence[self._pos])
+    r = lf.AIMessage.from_value(self.sequence[self._pos])
     self._pos += 1
     return r
```
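The net effect: the fake models' canned responses are widened from `str` to `str | lf.Message` (hence the switch from `lf.AIMessage(...)` to `lf.AIMessage.from_value(...)`), and the base `Fake` class gains a whitespace-splitting `_tokenize`. A minimal usage sketch, assuming langfun's callable LM interface (not code from this diff):

```python
# Sketch only: canned responses may now be full lf.Message objects.
import langfun.core as lf
from langfun.core.llms import fake

lm = fake.StaticSequence([
    'a plain string response',
    lf.AIMessage('a full message response'),  # Coerced via from_value().
])
print(lm('hi'))        # -> 'a plain string response'
print(lm('hi again'))  # -> 'a full message response'
print(lm.tokenize('hello world'))  # -> [('hello', 0), ('world', 1)]
```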
langfun/core/llms/fake_test.py
CHANGED
```diff
@@ -34,6 +34,7 @@ class EchoTest(unittest.TestCase):
           'hi',
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(2, 2, 4),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -62,6 +63,13 @@
         [lf.LMScoringResult(0.0), lf.LMScoringResult(-1.0)],
     )
 
+  def test_tokenize(self):
+    lm = fakelm.Echo()
+    self.assertEqual(
+        lm.tokenize('hi'),
+        [('hi', 0)]
+    )
+
 
 class StaticResponseTest(unittest.TestCase):
 
@@ -78,6 +86,7 @@ class StaticResponseTest(unittest.TestCase):
           canned_response,
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(2, 38, 40),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -99,6 +108,7 @@ class StaticResponseTest(unittest.TestCase):
           canned_response,
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(15, 38, 53),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -143,6 +153,7 @@ class StaticMappingTest(unittest.TestCase):
           'Hello',
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(2, 5, 7),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -159,6 +170,7 @@
           'I am fine, how about you?',
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(12, 25, 37),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -192,6 +204,7 @@ class StaticSequenceTest(unittest.TestCase):
           'Hello',
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(2, 5, 7),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
@@ -208,6 +221,7 @@ class StaticSequenceTest(unittest.TestCase):
           'I am fine, how about you?',
           score=1.0,
           logprobs=None,
+          is_cached=False,
           usage=lf.LMSamplingUsage(12, 25, 37),
           tags=[lf.Message.TAG_LM_RESPONSE],
       ),
```
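Every expected message in these tests now carries `is_cached=False`: this release records on each sampled message whether it was served from an LM cache. A hedged sketch of the intended round trip, assuming the flag flips to `True` on a cache hit (`in_memory.InMemory` is the package's in-memory LM cache, per the file list above):

```python
# Sketch only: sample the same prompt twice through a cached fake LM.
from langfun.core.llms import fake
from langfun.core.llms.cache import in_memory

lm = fake.Echo(cache=in_memory.InMemory())
first = lm('hi')   # Computed by the model; first.metadata.is_cached == False.
second = lm('hi')  # Replayed from the cache; is_cached == True (assumed).
```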
langfun/core/llms/gemini.py
ADDED
```python
# Copyright 2025 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""

import base64
from typing import Any

import langfun.core as lf
from langfun.core import modalities as lf_modalities
from langfun.core.llms import rest
import pyglove as pg

# Supported modalities.

IMAGE_TYPES = [
    'image/png',
    'image/jpeg',
    'image/webp',
    'image/heic',
    'image/heif',
]

AUDIO_TYPES = [
    'audio/aac',
    'audio/flac',
    'audio/mp3',
    'audio/m4a',
    'audio/mpeg',
    'audio/mpga',
    'audio/mp4',
    'audio/opus',
    'audio/pcm',
    'audio/wav',
    'audio/webm',
]

VIDEO_TYPES = [
    'video/mov',
    'video/mpeg',
    'video/mpegps',
    'video/mpg',
    'video/mp4',
    'video/webm',
    'video/wmv',
    'video/x-flv',
    'video/3gpp',
    'video/quicktime',
]

DOCUMENT_TYPES = [
    'application/pdf',
    'text/plain',
    'text/csv',
    'text/html',
    'text/xml',
    'text/x-script.python',
    'application/json',
]

TEXT_ONLY = []

ALL_MODALITIES = (
    IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
)

SUPPORTED_MODELS_AND_SETTINGS = {
    # For automatic rate control and cost estimation, we explicitly register
    # supported models here. This may be inconvenient for new models, but it
    # helps us to keep track of the models and their pricing.
    # Models and RPM are from
    # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
    # Pricing in US dollars, from https://ai.google.dev/pricing
    # as of 2025-01-03.
    # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
    # adding new models.
    # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
        latest_update='2024-12-19',
        experimental=True,
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=10,
        tpm_free=4_000_000,
        rpm_paid=0,
        tpm_paid=0,
        cost_per_1m_input_tokens_up_to_128k=0,
        cost_per_1m_output_tokens_up_to_128k=0,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0,
        cost_per_1m_output_tokens_longer_than_128k=0,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
    'gemini-2.0-flash-exp': pg.Dict(
        latest_update='2024-12-11',
        experimental=True,
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=10,
        tpm_free=4_000_000,
        rpm_paid=0,
        tpm_paid=0,
        cost_per_1m_input_tokens_up_to_128k=0,
        cost_per_1m_output_tokens_up_to_128k=0,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0,
        cost_per_1m_output_tokens_longer_than_128k=0,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
    'gemini-exp-1206': pg.Dict(
        latest_update='2024-12-06',
        experimental=True,
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=10,
        tpm_free=4_000_000,
        rpm_paid=0,
        tpm_paid=0,
        cost_per_1m_input_tokens_up_to_128k=0,
        cost_per_1m_output_tokens_up_to_128k=0,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0,
        cost_per_1m_output_tokens_longer_than_128k=0,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
    'learnlm-1.5-pro-experimental': pg.Dict(
        latest_update='2024-11-19',
        experimental=True,
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=10,
        tpm_free=4_000_000,
        rpm_paid=0,
        tpm_paid=0,
        cost_per_1m_input_tokens_up_to_128k=0,
        cost_per_1m_output_tokens_up_to_128k=0,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0,
        cost_per_1m_output_tokens_longer_than_128k=0,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
    'gemini-exp-1114': pg.Dict(
        latest_update='2024-11-14',
        experimental=True,
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=10,
        tpm_free=4_000_000,
        rpm_paid=0,
        tpm_paid=0,
        cost_per_1m_input_tokens_up_to_128k=0,
        cost_per_1m_output_tokens_up_to_128k=0,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0,
        cost_per_1m_output_tokens_longer_than_128k=0,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
    'gemini-1.5-flash-latest': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=2000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.075,
        cost_per_1m_output_tokens_up_to_128k=0.3,
        cost_per_1m_cached_tokens_up_to_128k=0.01875,
        cost_per_1m_input_tokens_longer_than_128k=0.15,
        cost_per_1m_output_tokens_longer_than_128k=0.6,
        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
    ),
    'gemini-1.5-flash': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=2000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.075,
        cost_per_1m_output_tokens_up_to_128k=0.3,
        cost_per_1m_cached_tokens_up_to_128k=0.01875,
        cost_per_1m_input_tokens_longer_than_128k=0.15,
        cost_per_1m_output_tokens_longer_than_128k=0.6,
        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
    ),
    'gemini-1.5-flash-001': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=2000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.075,
        cost_per_1m_output_tokens_up_to_128k=0.3,
        cost_per_1m_cached_tokens_up_to_128k=0.01875,
        cost_per_1m_input_tokens_longer_than_128k=0.15,
        cost_per_1m_output_tokens_longer_than_128k=0.6,
        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
    ),
    'gemini-1.5-flash-002': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=2000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.075,
        cost_per_1m_output_tokens_up_to_128k=0.3,
        cost_per_1m_cached_tokens_up_to_128k=0.01875,
        cost_per_1m_input_tokens_longer_than_128k=0.15,
        cost_per_1m_output_tokens_longer_than_128k=0.6,
        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
    ),
    'gemini-1.5-flash-8b': pg.Dict(
        latest_update='2024-10-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=4000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.0375,
        cost_per_1m_output_tokens_up_to_128k=0.15,
        cost_per_1m_cached_tokens_up_to_128k=0.01,
        cost_per_1m_input_tokens_longer_than_128k=0.075,
        cost_per_1m_output_tokens_longer_than_128k=0.3,
        cost_per_1m_cached_tokens_longer_than_128k=0.02,
    ),
    'gemini-1.5-flash-8b-001': pg.Dict(
        latest_update='2024-10-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=15,
        tpm_free=1_000_000,
        rpm_paid=4000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=0.0375,
        cost_per_1m_output_tokens_up_to_128k=0.15,
        cost_per_1m_cached_tokens_up_to_128k=0.01,
        cost_per_1m_input_tokens_longer_than_128k=0.075,
        cost_per_1m_output_tokens_longer_than_128k=0.3,
        cost_per_1m_cached_tokens_longer_than_128k=0.02,
    ),
    'gemini-1.5-pro-latest': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=2,
        tpm_free=32_000,
        rpm_paid=1000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=1.25,
        cost_per_1m_output_tokens_up_to_128k=5.00,
        cost_per_1m_cached_tokens_up_to_128k=0.3125,
        cost_per_1m_input_tokens_longer_than_128k=2.5,
        cost_per_1m_output_tokens_longer_than_128k=10.00,
        cost_per_1m_cached_tokens_longer_than_128k=0.625,
    ),
    'gemini-1.5-pro': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=2,
        tpm_free=32_000,
        rpm_paid=1000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=1.25,
        cost_per_1m_output_tokens_up_to_128k=5.00,
        cost_per_1m_cached_tokens_up_to_128k=0.3125,
        cost_per_1m_input_tokens_longer_than_128k=2.5,
        cost_per_1m_output_tokens_longer_than_128k=10.00,
        cost_per_1m_cached_tokens_longer_than_128k=0.625,
    ),
    'gemini-1.5-pro-001': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=2,
        tpm_free=32_000,
        rpm_paid=1000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=1.25,
        cost_per_1m_output_tokens_up_to_128k=5.00,
        cost_per_1m_cached_tokens_up_to_128k=0.3125,
        cost_per_1m_input_tokens_longer_than_128k=2.5,
        cost_per_1m_output_tokens_longer_than_128k=10.00,
        cost_per_1m_cached_tokens_longer_than_128k=0.625,
    ),
    'gemini-1.5-pro-002': pg.Dict(
        latest_update='2024-09-30',
        in_service=True,
        supported_modalities=ALL_MODALITIES,
        rpm_free=2,
        tpm_free=32_000,
        rpm_paid=1000,
        tpm_paid=4_000_000,
        cost_per_1m_input_tokens_up_to_128k=1.25,
        cost_per_1m_output_tokens_up_to_128k=5.00,
        cost_per_1m_cached_tokens_up_to_128k=0.3125,
        cost_per_1m_input_tokens_longer_than_128k=2.5,
        cost_per_1m_output_tokens_longer_than_128k=10.00,
        cost_per_1m_cached_tokens_longer_than_128k=0.625,
    ),
    'gemini-1.0-pro': pg.Dict(
        in_service=False,
        supported_modalities=TEXT_ONLY,
        rpm_free=15,
        tpm_free=32_000,
        rpm_paid=360,
        tpm_paid=120_000,
        cost_per_1m_input_tokens_up_to_128k=0.5,
        cost_per_1m_output_tokens_up_to_128k=1.5,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0.5,
        cost_per_1m_output_tokens_longer_than_128k=1.5,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
}
```
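This registry drives the rate control and cost estimation implemented by the class below; `pg.Dict` entries support attribute access, so limits and prices read like fields. For example (a sketch mirroring the `max_concurrency` logic that follows):

```python
# Sketch only: reading a registered model's limits from the table above.
entry = SUPPORTED_MODELS_AND_SETTINGS['gemini-1.5-flash']
requests_per_min = max(entry.rpm_free, entry.rpm_paid)  # -> 2000
tokens_per_min = max(entry.tpm_free, entry.tpm_paid)    # -> 4_000_000
```

The listing continues with the `Gemini` class: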
```python
@pg.use_init_args(['model'])
class Gemini(rest.REST):
  """Language models provided by Google GenAI."""

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  @property
  def supported_modalities(self) -> list[str]:
    """Returns the list of supported modalities."""
    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities

  @property
  def max_concurrency(self) -> int:
    """Returns the maximum number of concurrent requests."""
    return self.rate_to_max_concurrency(
        requests_per_min=max(
            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid
        ),
        tokens_per_min=max(
            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
        ),
    )

  def estimate_cost(
      self,
      num_input_tokens: int,
      num_output_tokens: int
  ) -> float | None:
    """Estimate the cost based on usage."""
    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    if num_input_tokens < 128_000:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
    else:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
      cost_per_1m_output_tokens = (
          entry.cost_per_1m_output_tokens_longer_than_128k
      )
    return (
        cost_per_1m_input_tokens * num_input_tokens
        + cost_per_1m_output_tokens * num_output_tokens
    ) / 1000_000
```
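`estimate_cost` picks a price tier from the input-token count (the 128k threshold) and scales the per-1M-token prices. A worked example with the `gemini-1.5-pro` prices registered above:

```python
# 200k input tokens falls in the "longer than 128k" tier for gemini-1.5-pro:
# 2.5 USD per 1M input tokens, 10.00 USD per 1M output tokens.
num_input_tokens, num_output_tokens = 200_000, 1_000
cost = (2.5 * num_input_tokens + 10.00 * num_output_tokens) / 1_000_000
assert round(cost, 2) == 0.51  # 0.50 USD for input + 0.01 USD for output.
```

The listing continues: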
```python
  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @classmethod
  def dir(cls):
    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

  @property
  def headers(self):
    return {
        'Content-Type': 'application/json; charset=utf-8',
    }

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    request = dict(
        generationConfig=self._generation_config(prompt, sampling_options)
    )
    request['contents'] = [self._content_from_message(prompt)]
    return request

  def _generation_config(
      self, prompt: lf.Message, options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Returns a dict as generation config for prompt and LMSamplingOptions."""
    config = dict(
        temperature=options.temperature,
        maxOutputTokens=options.max_tokens,
        candidateCount=options.n,
        topK=options.top_k,
        topP=options.top_p,
        stopSequences=options.stop,
        seed=options.random_seed,
        responseLogprobs=options.logprobs,
        logprobs=options.top_logprobs,
    )

    if json_schema := prompt.metadata.get('json_schema'):
      if not isinstance(json_schema, dict):
        raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
        )
      json_schema = pg.to_json(json_schema)
      config['responseSchema'] = json_schema
      config['responseMimeType'] = 'application/json'
      prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
    return config

  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
    """Gets generation content from langfun message."""
    parts = []
    for lf_chunk in prompt.chunk():
      if isinstance(lf_chunk, str):
        parts.append({'text': lf_chunk})
      elif isinstance(lf_chunk, lf_modalities.Mime):
        try:
          modalities = lf_chunk.make_compatible(
              self.supported_modalities + ['text/plain']
          )
          if isinstance(modalities, lf_modalities.Mime):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
              parts.append({'text': modality.to_text()})
            else:
              parts.append({
                  'inlineData': {
                      'data': base64.b64encode(modality.to_bytes()).decode(),
                      'mimeType': modality.mime_type,
                  }
              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
      else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
    return dict(role='user', parts=parts)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    messages = [
        self._message_from_content_parts(candidate['content']['parts'])
        for candidate in json['candidates']
    ]
    usage = json['usageMetadata']
    input_tokens = usage['promptTokenCount']
    output_tokens = usage['candidatesTokenCount']
    return lf.LMSamplingResult(
        [lf.LMSample(message) for message in messages],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            estimated_cost=self.estimate_cost(
                num_input_tokens=input_tokens,
                num_output_tokens=output_tokens,
            ),
        ),
    )

  def _message_from_content_parts(
      self, parts: list[dict[str, Any]]
  ) -> lf.Message:
    """Converts Vertex AI's content parts protocol to message."""
    chunks = []
    thought_chunks = []
    for part in parts:
      if text_part := part.get('text'):
        if part.get('thought'):
          thought_chunks.append(text_part)
        else:
          chunks.append(text_part)
      else:
        raise ValueError(f'Unsupported part: {part}')
    message = lf.AIMessage.from_chunks(chunks)
    if thought_chunks:
      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
    return message
```
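`request()` assembles the GenerateContent-style REST payload from `_generation_config` and `_content_from_message`. For a plain-text prompt the payload has the shape below; the field names come from the code above, while the values are illustrative placeholders, not defaults taken from the package:

```python
# Illustrative payload shape only.
payload = {
    'generationConfig': {
        'temperature': 0.0,
        'maxOutputTokens': 8192,
        'candidateCount': 1,
        'topK': 40,
        'topP': 0.95,
        'stopSequences': [],
        'seed': None,
        'responseLogprobs': False,
        'logprobs': None,
    },
    'contents': [
        {'role': 'user', 'parts': [{'text': 'What is 1 + 1?'}]},
    ],
}
```

Non-text chunks are attached as base64 `inlineData` parts, and `result()` maps the response's `candidates` and `usageMetadata` back into an `lf.LMSamplingResult`, including the estimated cost.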