langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/reporting.py +7 -2
- langfun/core/language_model.py +4 -1
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +21 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +26 -357
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/RECORD +18 -16
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/reporting.py
CHANGED
@@ -51,6 +51,7 @@ class HtmlReporter(experiment_lib.Plugin):
     self._update_thread = None
     self._stop_update = False
     self._stop_update_experiment_ids = set()
+    self._experiment_index_lock = None

   def on_run_start(
       self,
@@ -61,6 +62,9 @@ class HtmlReporter(experiment_lib.Plugin):
     self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
     self._stop_update = False
     self._stop_update_experiment_ids = set()
+    self._experiment_index_lock = {
+        leaf.id: threading.Lock() for leaf in root.leaf_nodes
+    }
     self._update_thread = threading.Thread(
         target=self._update_thread_func, args=(runner,)
     )
@@ -170,7 +174,8 @@ class HtmlReporter(experiment_lib.Plugin):
               card_view=False,
           ),
       )
-      html.save(index_html_path)
+      with self._experiment_index_lock[experiment.id]:
+        html.save(index_html_path)
       experiment.info(
           f'Updated {index_html_path!r} in {t.elapse:.2f} seconds.',
       )
@@ -185,11 +190,11 @@ class HtmlReporter(experiment_lib.Plugin):
         time.time() - self._last_experiment_report_time[experiment.id]
         > self.experiment_report_interval
     ):
+      self._last_experiment_report_time[experiment.id] = time.time()
       if background:
         runner.background_run(_save)
       else:
        _save()
-      self._last_experiment_report_time[experiment.id] = time.time()

   def _save_example_html(
       self, runner: Runner, experiment: Experiment, example: Example
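The new `_experiment_index_lock` dict gives each leaf experiment its own `threading.Lock`, so the background update thread and on-demand saves can no longer write the same index.html concurrently. A minimal sketch of that pattern, with hypothetical names (`index_locks`, `save_index_html`) standing in for the HtmlReporter internals:

import threading

# One lock per experiment id: writes to the same index file are serialized,
# while different experiments can still be saved in parallel.
index_locks = {exp_id: threading.Lock() for exp_id in ('exp-1', 'exp-2')}


def save_index_html(exp_id: str, html_text: str, path: str) -> None:
  # Mirrors `with self._experiment_index_lock[experiment.id]:` in the diff.
  with index_locks[exp_id]:
    with open(path, 'w') as f:
      f.write(html_text)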
langfun/core/language_model.py
CHANGED
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
   def __init__(self, *args, **kwargs) -> None:
     """Overrides __init__ to pass through **kwargs to sampling options."""

-    sampling_options = kwargs.pop(
+    sampling_options = kwargs.pop(
+        'sampling_options',
+        pg.clone(self.__schema__.fields['sampling_options'].default_value)
+    )
     sampling_options_delta = {}

     for k, v in kwargs.items():
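The change seeds `sampling_options` from a clone of the field's declared default, which a subclass may override, instead of a freshly constructed default, before folding per-field keyword arguments into it. A rough sketch of that merge pattern, using plain dataclasses rather than the real pyglove machinery (all names here are illustrative):

import copy
import dataclasses


@dataclasses.dataclass
class SamplingOptions:
  temperature: float = 0.0
  top_k: int | None = 40


class Model:
  default_options = SamplingOptions()

  def __init__(self, **kwargs):
    # Clone the (possibly subclass-overridden) default, mirroring
    # pg.clone(self.__schema__.fields['sampling_options'].default_value).
    options = copy.deepcopy(type(self).default_options)
    for k in list(kwargs):
      if hasattr(options, k):
        setattr(options, k, kwargs.pop(k))
    self.sampling_options = options


class Child(Model):
  default_options = SamplingOptions(temperature=0.5, top_k=20)


assert Child(top_k=10).sampling_options.temperature == 0.5  # Subclass default kept.
assert Child(top_k=10).sampling_options.top_k == 10         # Keyword override applied.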
langfun/core/language_model_test.py
CHANGED
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.sampling_options.top_k, 2)
     self.assertEqual(lm.max_attempts, 2)

+  def test_subclassing(self):
+
+    class ChildModel(lm_lib.LanguageModel):
+
+      sampling_options = lm_lib.LMSamplingOptions(
+          temperature=0.5, top_k=20
+      )
+
+      def _sample(self, *args, **kwargs):
+        pass
+
+    lm = ChildModel(top_k=10)
+    self.assertEqual(lm.sampling_options.temperature, 0.5)
+    self.assertEqual(lm.sampling_options.top_k, 10)
+
   def test_sample(self):
     lm = MockModel(top_k=1)
     self.assertEqual(
langfun/core/llms/__init__.py
CHANGED
@@ -32,16 +32,30 @@ from langfun.core.llms.rest import REST

 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
-from langfun.core.llms.google_genai import
+from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219
 from langfun.core.llms.google_genai import GeminiFlash2_0Exp
-from langfun.core.llms.google_genai import GeminiExp_20241114
 from langfun.core.llms.google_genai import GeminiExp_20241206
-from langfun.core.llms.google_genai import
+from langfun.core.llms.google_genai import GeminiExp_20241114
 from langfun.core.llms.google_genai import GeminiPro1_5
-from langfun.core.llms.google_genai import
-from langfun.core.llms.google_genai import
-from langfun.core.llms.google_genai import
-from langfun.core.llms.google_genai import
+from langfun.core.llms.google_genai import GeminiPro1_5_002
+from langfun.core.llms.google_genai import GeminiPro1_5_001
+from langfun.core.llms.google_genai import GeminiFlash1_5
+from langfun.core.llms.google_genai import GeminiFlash1_5_002
+from langfun.core.llms.google_genai import GeminiFlash1_5_001
+from langfun.core.llms.google_genai import GeminiPro1
+
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
+from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
+from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
+from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
+from langfun.core.llms.vertexai import VertexAIGeminiPro1

 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
@@ -124,25 +138,6 @@ from langfun.core.llms.groq import GroqGemma_7B_IT
 from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

-from langfun.core.llms.vertexai import VertexAI
-from langfun.core.llms.vertexai import VertexAIGemini2_0
-from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
-from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp
-from langfun.core.llms.vertexai import VertexAIGemini1_5
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
-from langfun.core.llms.vertexai import VertexAIGeminiPro1
-from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
-from langfun.core.llms.vertexai import VertexAIEndpoint
-
-
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote

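With this reshuffle, the Google GenAI and Vertex AI aliases for the same Gemini models sit side by side in `langfun.core.llms`. A quick sanity check of the new surface (just referencing the classes; actually sampling from them requires GenAI or Vertex AI credentials):

import langfun.core.llms as llms

print(llms.GeminiFlash2_0Exp)          # Gemini 2.0 Flash via the Google GenAI API.
print(llms.VertexAIGeminiFlash2_0Exp)  # The same model served through Vertex AI.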
langfun/core/llms/cache/in_memory.py
CHANGED
@@ -15,6 +15,7 @@

 import collections
 import contextlib
+import json
 from typing import Annotated, Any, Iterator
 import langfun.core as lf
 from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
           "Creating a new cache as cache file '%s' does not exist.",
           self.filename,
       )
+    except json.JSONDecodeError:
+      pg.logging.warning(
+          "Creating a new cache as cache file '%s' is corrupted.",
+          self.filename,
+      )

   def model_ids(self) -> list[str]:
     """Returns the model ids of cached queires."""
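The added `except json.JSONDecodeError` branch means a corrupted cache file is treated like a missing one: a warning is logged and an empty cache is used. A small standalone sketch of that load-or-recreate behavior (a hypothetical `load_cache` helper, not the InMemory API):

import json
import logging


def load_cache(filename: str) -> dict:
  """Returns cached entries, or an empty cache when the file is missing or corrupted."""
  try:
    with open(filename) as f:
      return json.load(f)
  except FileNotFoundError:
    logging.warning("Creating a new cache as cache file '%s' does not exist.", filename)
  except json.JSONDecodeError:
    logging.warning("Creating a new cache as cache file '%s' is corrupted.", filename)
  return {}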
langfun/core/llms/cache/in_memory_test.py
CHANGED
@@ -295,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(cache2.stats.num_updates, 2)
     cache2.save()

+    # Corrupted file.
+    pg.io.writefile(path, 'bad_content')
+    cache3 = in_memory.InMemory(path)
+    self.assertEqual(len(cache3), 0)
+

 class LmCacheTest(unittest.TestCase):

langfun/core/llms/gemini.py
ADDED
@@ -0,0 +1,507 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""
+
+import base64
+from typing import Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
+import pyglove as pg
+
+# Supported modalities.
+
+IMAGE_TYPES = [
+    'image/png',
+    'image/jpeg',
+    'image/webp',
+    'image/heic',
+    'image/heif',
+]
+
+AUDIO_TYPES = [
+    'audio/aac',
+    'audio/flac',
+    'audio/mp3',
+    'audio/m4a',
+    'audio/mpeg',
+    'audio/mpga',
+    'audio/mp4',
+    'audio/opus',
+    'audio/pcm',
+    'audio/wav',
+    'audio/webm',
+]
+
+VIDEO_TYPES = [
+    'video/mov',
+    'video/mpeg',
+    'video/mpegps',
+    'video/mpg',
+    'video/mp4',
+    'video/webm',
+    'video/wmv',
+    'video/x-flv',
+    'video/3gpp',
+    'video/quicktime',
+]
+
+DOCUMENT_TYPES = [
+    'application/pdf',
+    'text/plain',
+    'text/csv',
+    'text/html',
+    'text/xml',
+    'text/x-script.python',
+    'application/json',
+]
+
+TEXT_ONLY = []
+
+ALL_MODALITIES = (
+    IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
+)
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # For automatically rate control and cost estimation, we explicitly register
+    # supported models here. This may be inconvenient for new models, but it
+    # helps us to keep track of the models and their pricing.
+    # Models and RPM are from
+    # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
+    # Pricing in US dollars, from https://ai.google.dev/pricing
+    # as of 2025-01-03.
+    # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
+    # adding new models.
+    # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
+    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
+        latest_update='2024-12-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-2.0-flash-exp': pg.Dict(
+        latest_update='2024-12-11',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1206': pg.Dict(
+        latest_update='2024-12-06',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'learnlm-1.5-pro-experimental': pg.Dict(
+        latest_update='2024-11-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1114': pg.Dict(
+        latest_update='2024-11-14',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-1.5-flash-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-8b': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-flash-8b-001': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-pro-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.0-pro': pg.Dict(
+        in_service=False,
+        supported_modalities=TEXT_ONLY,
+        rpm_free=15,
+        tpm_free=32_000,
+        rpm_paid=360,
+        tpm_paid=120_000,
+        cost_per_1m_input_tokens_up_to_128k=0.5,
+        cost_per_1m_output_tokens_up_to_128k=1.5,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0.5,
+        cost_per_1m_output_tokens_longer_than_128k=1.5,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+}
+
+
+@pg.use_init_args(['model'])
+class Gemini(rest.REST):
+  """Language models provided by Google GenAI."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  @property
+  def supported_modalities(self) -> list[str]:
+    """Returns the list of supported modalities."""
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid
+        ),
+        tokens_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
+        ),
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
+    if num_input_tokens < 128_000:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
+      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
+    else:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
+      cost_per_1m_output_tokens = (
+          entry.cost_per_1m_output_tokens_longer_than_128k
+      )
+    return (
+        cost_per_1m_input_tokens * num_input_tokens
+        + cost_per_1m_output_tokens * num_output_tokens
+    ) / 1000_1000
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @classmethod
+  def dir(cls):
+    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    thought_chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        if part.get('thought'):
+          thought_chunks.append(text_part)
+        else:
+          chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    message = lf.AIMessage.from_chunks(chunks)
+    if thought_chunks:
+      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
+    return message
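For reference, `_content_from_message` above turns a langfun prompt into Gemini's contents/parts structure: plain text chunks become `{'text': ...}` parts, and each supported binary modality becomes a base64-encoded `inlineData` part. An illustrative payload for a prompt with one text chunk and one image (values are made up):

contents = [{
    'role': 'user',
    'parts': [
        {'text': 'Describe this image.'},
        {
            'inlineData': {
                'data': '<base64-encoded image bytes>',
                'mimeType': 'image/png',
            },
        },
    ],
}]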