langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -4
- langfun/core/eval/matching.py +2 -2
- langfun/core/eval/scoring.py +6 -2
- langfun/core/eval/v2/checkpointing.py +106 -72
- langfun/core/eval/v2/checkpointing_test.py +108 -3
- langfun/core/eval/v2/eval_test_helper.py +56 -0
- langfun/core/eval/v2/evaluation.py +25 -4
- langfun/core/eval/v2/evaluation_test.py +11 -0
- langfun/core/eval/v2/example.py +11 -1
- langfun/core/eval/v2/example_test.py +16 -2
- langfun/core/eval/v2/experiment.py +83 -19
- langfun/core/eval/v2/experiment_test.py +121 -3
- langfun/core/eval/v2/reporting.py +67 -20
- langfun/core/eval/v2/reporting_test.py +119 -2
- langfun/core/eval/v2/runners.py +7 -4
- langfun/core/llms/__init__.py +23 -24
- langfun/core/llms/anthropic.py +12 -0
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -310
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +23 -37
- langfun/core/llms/vertexai.py +28 -348
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- langfun/core/repr_utils.py +0 -204
- langfun/core/repr_utils_test.py +0 -90
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
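The headline change in this release: Gemini REST plumbing is consolidated into a new shared `langfun/core/llms/gemini.py` (+507 lines), with `google_genai.py` and `vertexai.py` rewritten as thin layers on top of it (hence the large deletions in both). For orientation before the diffs below, here is a minimal usage sketch against the new GenAI classes; it assumes langfun from this release and a valid API key, and the placeholder key is illustrative only:

```python
import os
from langfun.core.llms import google_genai

# GenAI resolves the key from the constructor or the GOOGLE_API_KEY
# environment variable (see the new api_endpoint property in the
# google_genai.py diff below).
os.environ.setdefault('GOOGLE_API_KEY', '...')  # placeholder

lm = google_genai.GeminiFlash1_5_002()  # binds model='gemini-1.5-flash-002'
r = lm('hello', temperature=0.7, max_tokens=256)  # callable, per gemini_test.py
print(r.text)
```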
langfun/core/llms/gemini_test.py

@@ -0,0 +1,195 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Gemini API."""
+
+import base64
+from typing import Any
+import unittest
+from unittest import mock
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import gemini
+import pyglove as pg
+import requests
+
+
+example_image = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          ),
+                      },
+                      {
+                          'text': 'This is the thought.',
+                          'thought': True,
+                      }
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
+class GeminiTest(unittest.TestCase):
+  """Tests for Vertex model with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    image = lf_modalities.Image.from_bytes(example_image)
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?', image=image
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      gemini.Gemini(
+          'gemini-1.0-pro', api_endpoint=''
+      )._content_from_message(message)
+
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    content = model._content_from_message(message)
+    self.assertEqual(
+        content,
+        {
+            'role': 'user',
+            'parts': [
+                {'text': 'This is an'},
+                {
+                    'inlineData': {
+                        'data': base64.b64encode(example_image).decode(),
+                        'mimeType': 'image/png',
+                    }
+                },
+                {'text': ', what is it?'},
+            ],
+        },
+    )
+
+  def test_generation_config(self):
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )

+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.thought, 'This is the thought.')
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
+if __name__ == '__main__':
+  unittest.main()
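A detail these tests pin down: langfun's snake_case sampling options are serialized as camelCase keys in the REST `generationConfig` payload, which is what `mock_requests_post` reads back out of `json['generationConfig']`. A standalone sketch of just that mapping (`to_generation_config` is a hypothetical name, not langfun API):

```python
from typing import Any

# snake_case langfun option -> camelCase generationConfig key, as
# exercised by test_generation_config above.
_KEY_MAP = {
    'temperature': 'temperature',
    'top_p': 'topP',
    'top_k': 'topK',
    'max_tokens': 'maxOutputTokens',
    'stop': 'stopSequences',
}


def to_generation_config(options: dict[str, Any]) -> dict[str, Any]:
  return {_KEY_MAP[k]: v for k, v in options.items() if k in _KEY_MAP}


assert to_generation_config(
    {'temperature': 2.0, 'top_p': 1.0, 'top_k': 20, 'max_tokens': 1024}
) == {'temperature': 2.0, 'topP': 1.0, 'topK': 20, 'maxOutputTokens': 1024}
```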
langfun/core/llms/google_genai.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Langfun Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,56 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""Language models from Google GenAI."""
 
-import abc
-import functools
 import os
-from typing import Annotated,
+from typing import Annotated, Literal
 
 import langfun.core as lf
-from langfun.core import
-from langfun.core.llms import vertexai
+from langfun.core.llms import gemini
 import pyglove as pg
 
 
-try:
-  import google.generativeai as genai  # pylint: disable=g-import-not-at-top
-  BlobDict = genai.types.BlobDict
-  GenerativeModel = genai.GenerativeModel
-  Completion = getattr(genai.types, 'Completion', Any)
-  ChatResponse = getattr(genai.types, 'ChatResponse', Any)
-  GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
-  GenerationConfig = genai.GenerationConfig
-except ImportError:
-  genai = None
-  BlobDict = Any
-  GenerativeModel = Any
-  Completion = Any
-  ChatResponse = Any
-  GenerationConfig = Any
-  GenerateContentResponse = Any
-
-
 @lf.use_init_args(['model'])
-class GenAI(lf.LanguageModel):
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class GenAI(gemini.Gemini):
   """Language models provided by Google GenAI."""
 
-  model: Annotated[
-      Literal[
-          'gemini-2.0-flash-exp',
-          'gemini-exp-1206',
-          'gemini-exp-1114',
-          'gemini-1.5-pro-latest',
-          'gemini-1.5-flash-latest',
-          'gemini-pro',
-          'gemini-pro-vision',
-          'text-bison-001',
-          'chat-bison-001',
-      ],
-      'Model name.',
-  ]
-
   api_key: Annotated[
       str | None,
       (
@@ -70,26 +35,18 @@ class GenAI(lf.LanguageModel):
       ),
   ] = None
 
-
-
-'
-] =
+  api_version: Annotated[
+      Literal['v1beta', 'v1alpha'],
+      'The API version to use.'
+  ] = 'v1beta'
 
-
-
-
-
-    super()._on_bound()
-    if genai is None:
-      raise RuntimeError(
-          'Please install "langfun[llm-google-genai]" to use '
-          'Google Generative AI models.'
-      )
-    self.__dict__.pop('_api_initialized', None)
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'GenAI({self.model})'
 
-@
-def
-    assert genai is not None
+  @property
+  def api_endpoint(self) -> str:
     api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -99,296 +56,75 @@ class GenAI(lf.LanguageModel):
           'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
           'for more details.'
       )
-
-
-
-
-  def dir(cls) -> list[str]:
-    """Lists generative models."""
-    assert genai is not None
-    return [
-        m.name.lstrip('models/')
-        for m in genai.list_models()
-        if (
-            'generateContent' in m.supported_generation_methods
-            or 'generateText' in m.supported_generation_methods
-            or 'generateMessage' in m.supported_generation_methods
-        )
-    ]
-
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return self.model
-
-  @property
-  def resource_id(self) -> str:
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
-
-  def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    """Creates generation config from langfun sampling options."""
-    return GenerationConfig(
-        candidate_count=options.n,
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-    )
-
-  def _content_from_message(
-      self, prompt: lf.Message
-  ) -> list[str | BlobDict]:
-    """Gets Evergreen formatted content from langfun message."""
-    formatted = lf.UserMessage(prompt.text)
-    formatted.source = prompt
-
-    chunks = []
-    for lf_chunk in formatted.chunk():
-      if isinstance(lf_chunk, str):
-        chunks.append(lf_chunk)
-      elif isinstance(lf_chunk, lf_modalities.Mime):
-        try:
-          modalities = lf_chunk.make_compatible(
-              self.supported_modalities + ['text/plain']
-          )
-          if isinstance(modalities, lf_modalities.Mime):
-            modalities = [modalities]
-          for modality in modalities:
-            if modality.is_text:
-              chunk = modality.to_text()
-            else:
-              chunk = BlobDict(
-                  data=modality.to_bytes(),
-                  mime_type=modality.mime_type
-              )
-            chunks.append(chunk)
-        except lf.ModalityError as e:
-          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-      else:
-        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return chunks
-
-  def _response_to_result(
-      self, response: GenerateContentResponse | pg.Dict
-  ) -> lf.LMSamplingResult:
-    """Parses generative response into message."""
-    samples = []
-    for candidate in response.candidates:
-      chunks = []
-      for part in candidate.content.parts:
-        # TODO(daiyip): support multi-modal parts when they are available via
-        # Gemini API.
-        if hasattr(part, 'text'):
-          chunks.append(part.text)
-      samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
-    return lf.LMSamplingResult(samples)
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a single prompt."""
-    model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(self.sampling_options),
-    )
-    return self._response_to_result(response)
-
-
-class _LegacyGenerativeModel(pg.Object):
-  """Base for legacy GenAI generative model."""
-
-  model: str
-
-  def generate_content(
-      self,
-      input_content: list[str | BlobDict],
-      generation_config: GenerationConfig,
-  ) -> pg.Dict:
-    """Generate content."""
-    segments = []
-    for s in input_content:
-      if not isinstance(s, str):
-        raise ValueError(f'Unsupported modality: {s!r}')
-      segments.append(s)
-    return self.generate(' '.join(segments), generation_config)
-
-  @abc.abstractmethod
-  def generate(
-      self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
-    """Generate response based on prompt."""
-
-
-class _LegacyCompletionModel(_LegacyGenerativeModel):
-  """Legacy GenAI completion model."""
-
-  def generate(
-      self, prompt: str, generation_config: GenerationConfig
-  ) -> pg.Dict:
-    assert genai is not None
-    completion: Completion = genai.generate_text(
-        model=f'models/{self.model}',
-        prompt=prompt,
-        temperature=generation_config.temperature,
-        top_k=generation_config.top_k,
-        top_p=generation_config.top_p,
-        candidate_count=generation_config.candidate_count,
-        max_output_tokens=generation_config.max_output_tokens,
-        stop_sequences=generation_config.stop_sequences,
-    )
-    return pg.Dict(
-        candidates=[
-            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
-            for c in completion.candidates
-        ]
-    )
-
-
-class _LegacyChatModel(_LegacyGenerativeModel):
-  """Legacy GenAI chat model."""
-
-  def generate(
-      self, prompt: str, generation_config: GenerationConfig
-  ) -> pg.Dict:
-    assert genai is not None
-    response: ChatResponse = genai.chat(
-        model=f'models/{self.model}',
-        messages=prompt,
-        temperature=generation_config.temperature,
-        top_k=generation_config.top_k,
-        top_p=generation_config.top_p,
-        candidate_count=generation_config.candidate_count,
-    )
-    return pg.Dict(
-        candidates=[
-            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
-            for c in response.candidates
-        ]
+    return (
+        f'https://generativelanguage.googleapis.com/{self.api_version}'
+        f'/models/{self.model}:generateContent?'
+        f'key={api_key}'
     )
 
 
-class
-"""
-
-  def __init__(self):
-    self._model_cache = {}
-
-  def get(
-      self, model_name: str
-  ) -> GenerativeModel | _LegacyGenerativeModel:
-    """Gets a generative model by model id."""
-    assert genai is not None
-    model = self._model_cache.get(model_name, None)
-    if model is None:
-      model_info = genai.get_model(f'models/{model_name}')
-      if 'generateContent' in model_info.supported_generation_methods:
-        model = genai.GenerativeModel(model_name)
-      elif 'generateText' in model_info.supported_generation_methods:
-        model = _LegacyCompletionModel(model_name)
-      elif 'generateMessage' in model_info.supported_generation_methods:
-        model = _LegacyChatModel(model_name)
-      else:
-        raise ValueError(f'Unsupported model: {model_name!r}')
-      self._model_cache[model_name] = model
-    return model
-
+class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
 
-
-
-#
-# Public Gemini models.
-#
+  api_version = 'v1alpha'
+  model = 'gemini-2.0-flash-thinking-exp-1219'
 
 
 class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
-"""Gemini
+  """Gemini Flash 2.0 model launched on 12/11/2024."""
 
   model = 'gemini-2.0-flash-exp'
-  supported_modalities = (
-      vertexai.DOCUMENT_TYPES
-      + vertexai.IMAGE_TYPES
-      + vertexai.AUDIO_TYPES
-      + vertexai.VIDEO_TYPES
-  )
 
 
 class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
   """Gemini Experimental model launched on 12/06/2024."""
 
   model = 'gemini-exp-1206'
-  supported_modalities = (
-      vertexai.DOCUMENT_TYPES
-      + vertexai.IMAGE_TYPES
-      + vertexai.AUDIO_TYPES
-      + vertexai.VIDEO_TYPES
-  )
 
 
 class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
   """Gemini Experimental model launched on 11/14/2024."""
 
   model = 'gemini-exp-1114'
-  supported_modalities = (
-      vertexai.DOCUMENT_TYPES
-      + vertexai.IMAGE_TYPES
-      + vertexai.AUDIO_TYPES
-      + vertexai.VIDEO_TYPES
-  )
 
 
 class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
   """Gemini Pro latest model."""
 
   model = 'gemini-1.5-pro-latest'
-  supported_modalities = (
-      vertexai.DOCUMENT_TYPES
-      + vertexai.IMAGE_TYPES
-      + vertexai.AUDIO_TYPES
-      + vertexai.VIDEO_TYPES
-  )
 
 
-class
-"""Gemini
+class GeminiPro1_5_002(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
+
+  model = 'gemini-1.5-pro-002'
 
-  model = 'gemini-1.5-flash-latest'
-  supported_modalities = (
-      vertexai.DOCUMENT_TYPES
-      + vertexai.IMAGE_TYPES
-      + vertexai.AUDIO_TYPES
-      + vertexai.VIDEO_TYPES
-  )
 
+class GeminiPro1_5_001(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
 
-
-  """Gemini Pro model."""
+  model = 'gemini-1.5-pro-001'
 
-
+
+class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash latest model."""
+
+  model = 'gemini-1.5-flash-latest'
 
 
-class
-"""Gemini
+class GeminiFlash1_5_002(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 1.5 model stable version 002."""
 
-model = 'gemini-
-  supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
+  model = 'gemini-1.5-flash-002'
 
 
-class
-"""
+class GeminiFlash1_5_001(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 1.5 model stable version 001."""
 
-model = '
+  model = 'gemini-1.5-flash-001'
 
 
-class
-"""
+class GeminiPro1(GenAI):  # pylint: disable=invalid-name
+  """Gemini 1.0 Pro model."""
 
-model = '
+  model = 'gemini-1.0-pro'