langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -5
- langfun/core/coding/python/correction.py +4 -3
- langfun/core/coding/python/errors.py +10 -9
- langfun/core/coding/python/execution.py +23 -12
- langfun/core/coding/python/execution_test.py +21 -2
- langfun/core/coding/python/generation.py +18 -9
- langfun/core/concurrent.py +2 -3
- langfun/core/console.py +8 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/reporting.py +15 -6
- langfun/core/language_model.py +7 -4
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +25 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/deepseek.py +261 -0
- langfun/core/llms/deepseek_test.py +438 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +31 -359
- langfun/core/llms/vertexai_test.py +6 -166
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -11,57 +11,21 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""
|
14
|
+
"""Language models from Google GenAI."""
|
15
15
|
|
16
|
-
import abc
|
17
|
-
import functools
|
18
16
|
import os
|
19
|
-
from typing import Annotated,
|
17
|
+
from typing import Annotated, Literal
|
20
18
|
|
21
19
|
import langfun.core as lf
|
22
|
-
from langfun.core import
|
23
|
-
from langfun.core.llms import vertexai
|
20
|
+
from langfun.core.llms import gemini
|
24
21
|
import pyglove as pg
|
25
22
|
|
26
23
|
|
27
|
-
try:
|
28
|
-
import google.generativeai as genai # pylint: disable=g-import-not-at-top
|
29
|
-
BlobDict = genai.types.BlobDict
|
30
|
-
GenerativeModel = genai.GenerativeModel
|
31
|
-
Completion = getattr(genai.types, 'Completion', Any)
|
32
|
-
ChatResponse = getattr(genai.types, 'ChatResponse', Any)
|
33
|
-
GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
|
34
|
-
GenerationConfig = genai.GenerationConfig
|
35
|
-
except ImportError:
|
36
|
-
genai = None
|
37
|
-
BlobDict = Any
|
38
|
-
GenerativeModel = Any
|
39
|
-
Completion = Any
|
40
|
-
ChatResponse = Any
|
41
|
-
GenerationConfig = Any
|
42
|
-
GenerateContentResponse = Any
|
43
|
-
|
44
|
-
|
45
24
|
@lf.use_init_args(['model'])
|
46
|
-
|
25
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
26
|
+
class GenAI(gemini.Gemini):
|
47
27
|
"""Language models provided by Google GenAI."""
|
48
28
|
|
49
|
-
model: Annotated[
|
50
|
-
Literal[
|
51
|
-
'gemini-2.0-flash-thinking-exp-1219',
|
52
|
-
'gemini-2.0-flash-exp',
|
53
|
-
'gemini-exp-1206',
|
54
|
-
'gemini-exp-1114',
|
55
|
-
'gemini-1.5-pro-latest',
|
56
|
-
'gemini-1.5-flash-latest',
|
57
|
-
'gemini-pro',
|
58
|
-
'gemini-pro-vision',
|
59
|
-
'text-bison-001',
|
60
|
-
'chat-bison-001',
|
61
|
-
],
|
62
|
-
'Model name.',
|
63
|
-
]
|
64
|
-
|
65
29
|
api_key: Annotated[
|
66
30
|
str | None,
|
67
31
|
(
|
@@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel):
|
|
71
35
|
),
|
72
36
|
] = None
|
73
37
|
|
74
|
-
|
75
|
-
|
76
|
-
'
|
77
|
-
] =
|
38
|
+
api_version: Annotated[
|
39
|
+
Literal['v1beta', 'v1alpha'],
|
40
|
+
'The API version to use.'
|
41
|
+
] = 'v1beta'
|
78
42
|
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
super()._on_bound()
|
84
|
-
if genai is None:
|
85
|
-
raise RuntimeError(
|
86
|
-
'Please install "langfun[llm-google-genai]" to use '
|
87
|
-
'Google Generative AI models.'
|
88
|
-
)
|
89
|
-
self.__dict__.pop('_api_initialized', None)
|
43
|
+
@property
|
44
|
+
def model_id(self) -> str:
|
45
|
+
"""Returns a string to identify the model."""
|
46
|
+
return f'GenAI({self.model})'
|
90
47
|
|
91
|
-
@
|
92
|
-
def
|
93
|
-
assert genai is not None
|
48
|
+
@property
|
49
|
+
def api_endpoint(self) -> str:
|
94
50
|
api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
|
95
51
|
if not api_key:
|
96
52
|
raise ValueError(
|
@@ -100,306 +56,76 @@ class GenAI(lf.LanguageModel):
|
|
100
56
|
'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
|
101
57
|
'for more details.'
|
102
58
|
)
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
def dir(cls) -> list[str]:
|
108
|
-
"""Lists generative models."""
|
109
|
-
assert genai is not None
|
110
|
-
return [
|
111
|
-
m.name.lstrip('models/')
|
112
|
-
for m in genai.list_models()
|
113
|
-
if (
|
114
|
-
'generateContent' in m.supported_generation_methods
|
115
|
-
or 'generateText' in m.supported_generation_methods
|
116
|
-
or 'generateMessage' in m.supported_generation_methods
|
117
|
-
)
|
118
|
-
]
|
119
|
-
|
120
|
-
@property
|
121
|
-
def model_id(self) -> str:
|
122
|
-
"""Returns a string to identify the model."""
|
123
|
-
return self.model
|
124
|
-
|
125
|
-
@property
|
126
|
-
def resource_id(self) -> str:
|
127
|
-
"""Returns a string to identify the resource for rate control."""
|
128
|
-
return self.model_id
|
129
|
-
|
130
|
-
def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
131
|
-
"""Creates generation config from langfun sampling options."""
|
132
|
-
return GenerationConfig(
|
133
|
-
candidate_count=options.n,
|
134
|
-
temperature=options.temperature,
|
135
|
-
top_p=options.top_p,
|
136
|
-
top_k=options.top_k,
|
137
|
-
max_output_tokens=options.max_tokens,
|
138
|
-
stop_sequences=options.stop,
|
59
|
+
return (
|
60
|
+
f'https://generativelanguage.googleapis.com/{self.api_version}'
|
61
|
+
f'/models/{self.model}:generateContent?'
|
62
|
+
f'key={api_key}'
|
139
63
|
)
|
140
64
|
|
141
|
-
def _content_from_message(
|
142
|
-
self, prompt: lf.Message
|
143
|
-
) -> list[str | BlobDict]:
|
144
|
-
"""Gets Evergreen formatted content from langfun message."""
|
145
|
-
formatted = lf.UserMessage(prompt.text)
|
146
|
-
formatted.source = prompt
|
147
|
-
|
148
|
-
chunks = []
|
149
|
-
for lf_chunk in formatted.chunk():
|
150
|
-
if isinstance(lf_chunk, str):
|
151
|
-
chunks.append(lf_chunk)
|
152
|
-
elif isinstance(lf_chunk, lf_modalities.Mime):
|
153
|
-
try:
|
154
|
-
modalities = lf_chunk.make_compatible(
|
155
|
-
self.supported_modalities + ['text/plain']
|
156
|
-
)
|
157
|
-
if isinstance(modalities, lf_modalities.Mime):
|
158
|
-
modalities = [modalities]
|
159
|
-
for modality in modalities:
|
160
|
-
if modality.is_text:
|
161
|
-
chunk = modality.to_text()
|
162
|
-
else:
|
163
|
-
chunk = BlobDict(
|
164
|
-
data=modality.to_bytes(),
|
165
|
-
mime_type=modality.mime_type
|
166
|
-
)
|
167
|
-
chunks.append(chunk)
|
168
|
-
except lf.ModalityError as e:
|
169
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
|
170
|
-
else:
|
171
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
|
172
|
-
return chunks
|
173
|
-
|
174
|
-
def _response_to_result(
|
175
|
-
self, response: GenerateContentResponse | pg.Dict
|
176
|
-
) -> lf.LMSamplingResult:
|
177
|
-
"""Parses generative response into message."""
|
178
|
-
samples = []
|
179
|
-
for candidate in response.candidates:
|
180
|
-
chunks = []
|
181
|
-
for part in candidate.content.parts:
|
182
|
-
# TODO(daiyip): support multi-modal parts when they are available via
|
183
|
-
# Gemini API.
|
184
|
-
if hasattr(part, 'text'):
|
185
|
-
chunks.append(part.text)
|
186
|
-
samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
|
187
|
-
return lf.LMSamplingResult(samples)
|
188
|
-
|
189
|
-
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
190
|
-
assert self._api_initialized, 'Vertex AI API is not initialized.'
|
191
|
-
return self._parallel_execute_with_currency_control(
|
192
|
-
self._sample_single,
|
193
|
-
prompts,
|
194
|
-
)
|
195
65
|
|
196
|
-
|
197
|
-
|
198
|
-
model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
|
199
|
-
input_content = self._content_from_message(prompt)
|
200
|
-
response = model.generate_content(
|
201
|
-
input_content,
|
202
|
-
generation_config=self._generation_config(self.sampling_options),
|
203
|
-
)
|
204
|
-
return self._response_to_result(response)
|
205
|
-
|
206
|
-
|
207
|
-
class _LegacyGenerativeModel(pg.Object):
|
208
|
-
"""Base for legacy GenAI generative model."""
|
209
|
-
|
210
|
-
model: str
|
211
|
-
|
212
|
-
def generate_content(
|
213
|
-
self,
|
214
|
-
input_content: list[str | BlobDict],
|
215
|
-
generation_config: GenerationConfig,
|
216
|
-
) -> pg.Dict:
|
217
|
-
"""Generate content."""
|
218
|
-
segments = []
|
219
|
-
for s in input_content:
|
220
|
-
if not isinstance(s, str):
|
221
|
-
raise ValueError(f'Unsupported modality: {s!r}')
|
222
|
-
segments.append(s)
|
223
|
-
return self.generate(' '.join(segments), generation_config)
|
224
|
-
|
225
|
-
@abc.abstractmethod
|
226
|
-
def generate(
|
227
|
-
self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
|
228
|
-
"""Generate response based on prompt."""
|
229
|
-
|
230
|
-
|
231
|
-
class _LegacyCompletionModel(_LegacyGenerativeModel):
|
232
|
-
"""Legacy GenAI completion model."""
|
233
|
-
|
234
|
-
def generate(
|
235
|
-
self, prompt: str, generation_config: GenerationConfig
|
236
|
-
) -> pg.Dict:
|
237
|
-
assert genai is not None
|
238
|
-
completion: Completion = genai.generate_text(
|
239
|
-
model=f'models/{self.model}',
|
240
|
-
prompt=prompt,
|
241
|
-
temperature=generation_config.temperature,
|
242
|
-
top_k=generation_config.top_k,
|
243
|
-
top_p=generation_config.top_p,
|
244
|
-
candidate_count=generation_config.candidate_count,
|
245
|
-
max_output_tokens=generation_config.max_output_tokens,
|
246
|
-
stop_sequences=generation_config.stop_sequences,
|
247
|
-
)
|
248
|
-
return pg.Dict(
|
249
|
-
candidates=[
|
250
|
-
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
|
251
|
-
for c in completion.candidates
|
252
|
-
]
|
253
|
-
)
|
254
|
-
|
255
|
-
|
256
|
-
class _LegacyChatModel(_LegacyGenerativeModel):
|
257
|
-
"""Legacy GenAI chat model."""
|
258
|
-
|
259
|
-
def generate(
|
260
|
-
self, prompt: str, generation_config: GenerationConfig
|
261
|
-
) -> pg.Dict:
|
262
|
-
assert genai is not None
|
263
|
-
response: ChatResponse = genai.chat(
|
264
|
-
model=f'models/{self.model}',
|
265
|
-
messages=prompt,
|
266
|
-
temperature=generation_config.temperature,
|
267
|
-
top_k=generation_config.top_k,
|
268
|
-
top_p=generation_config.top_p,
|
269
|
-
candidate_count=generation_config.candidate_count,
|
270
|
-
)
|
271
|
-
return pg.Dict(
|
272
|
-
candidates=[
|
273
|
-
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
|
274
|
-
for c in response.candidates
|
275
|
-
]
|
276
|
-
)
|
277
|
-
|
278
|
-
|
279
|
-
class _ModelHub:
|
280
|
-
"""Google Generative AI model hub."""
|
281
|
-
|
282
|
-
def __init__(self):
|
283
|
-
self._model_cache = {}
|
284
|
-
|
285
|
-
def get(
|
286
|
-
self, model_name: str
|
287
|
-
) -> GenerativeModel | _LegacyGenerativeModel:
|
288
|
-
"""Gets a generative model by model id."""
|
289
|
-
assert genai is not None
|
290
|
-
model = self._model_cache.get(model_name, None)
|
291
|
-
if model is None:
|
292
|
-
model_info = genai.get_model(f'models/{model_name}')
|
293
|
-
if 'generateContent' in model_info.supported_generation_methods:
|
294
|
-
model = genai.GenerativeModel(model_name)
|
295
|
-
elif 'generateText' in model_info.supported_generation_methods:
|
296
|
-
model = _LegacyCompletionModel(model_name)
|
297
|
-
elif 'generateMessage' in model_info.supported_generation_methods:
|
298
|
-
model = _LegacyChatModel(model_name)
|
299
|
-
else:
|
300
|
-
raise ValueError(f'Unsupported model: {model_name!r}')
|
301
|
-
self._model_cache[model_name] = model
|
302
|
-
return model
|
303
|
-
|
304
|
-
|
305
|
-
_GOOGLE_GENAI_MODEL_HUB = _ModelHub()
|
306
|
-
|
307
|
-
|
308
|
-
#
|
309
|
-
# Public Gemini models.
|
310
|
-
#
|
311
|
-
class GeminiFlash2_0ThinkingExp(GenAI): # pylint: disable=invalid-name
|
312
|
-
"""Gemini 2.0 Flash Thinking Experimental model."""
|
66
|
+
class GeminiFlash2_0ThinkingExp_20241219(GenAI): # pylint: disable=invalid-name
|
67
|
+
"""Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
313
68
|
|
69
|
+
api_version = 'v1alpha'
|
314
70
|
model = 'gemini-2.0-flash-thinking-exp-1219'
|
315
|
-
|
316
|
-
vertexai.DOCUMENT_TYPES
|
317
|
-
+ vertexai.IMAGE_TYPES
|
318
|
-
+ vertexai.AUDIO_TYPES
|
319
|
-
+ vertexai.VIDEO_TYPES
|
320
|
-
)
|
71
|
+
timeout = None
|
321
72
|
|
322
73
|
|
323
74
|
class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name
|
324
|
-
"""Gemini
|
75
|
+
"""Gemini Flash 2.0 model launched on 12/11/2024."""
|
325
76
|
|
326
77
|
model = 'gemini-2.0-flash-exp'
|
327
|
-
supported_modalities = (
|
328
|
-
vertexai.DOCUMENT_TYPES
|
329
|
-
+ vertexai.IMAGE_TYPES
|
330
|
-
+ vertexai.AUDIO_TYPES
|
331
|
-
+ vertexai.VIDEO_TYPES
|
332
|
-
)
|
333
78
|
|
334
79
|
|
335
80
|
class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name
|
336
81
|
"""Gemini Experimental model launched on 12/06/2024."""
|
337
82
|
|
338
83
|
model = 'gemini-exp-1206'
|
339
|
-
supported_modalities = (
|
340
|
-
vertexai.DOCUMENT_TYPES
|
341
|
-
+ vertexai.IMAGE_TYPES
|
342
|
-
+ vertexai.AUDIO_TYPES
|
343
|
-
+ vertexai.VIDEO_TYPES
|
344
|
-
)
|
345
84
|
|
346
85
|
|
347
86
|
class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name
|
348
87
|
"""Gemini Experimental model launched on 11/14/2024."""
|
349
88
|
|
350
89
|
model = 'gemini-exp-1114'
|
351
|
-
supported_modalities = (
|
352
|
-
vertexai.DOCUMENT_TYPES
|
353
|
-
+ vertexai.IMAGE_TYPES
|
354
|
-
+ vertexai.AUDIO_TYPES
|
355
|
-
+ vertexai.VIDEO_TYPES
|
356
|
-
)
|
357
90
|
|
358
91
|
|
359
92
|
class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
|
360
93
|
"""Gemini Pro latest model."""
|
361
94
|
|
362
95
|
model = 'gemini-1.5-pro-latest'
|
363
|
-
supported_modalities = (
|
364
|
-
vertexai.DOCUMENT_TYPES
|
365
|
-
+ vertexai.IMAGE_TYPES
|
366
|
-
+ vertexai.AUDIO_TYPES
|
367
|
-
+ vertexai.VIDEO_TYPES
|
368
|
-
)
|
369
96
|
|
370
97
|
|
371
|
-
class
|
372
|
-
"""Gemini
|
98
|
+
class GeminiPro1_5_002(GenAI): # pylint: disable=invalid-name
|
99
|
+
"""Gemini Pro latest model."""
|
100
|
+
|
101
|
+
model = 'gemini-1.5-pro-002'
|
373
102
|
|
374
|
-
model = 'gemini-1.5-flash-latest'
|
375
|
-
supported_modalities = (
|
376
|
-
vertexai.DOCUMENT_TYPES
|
377
|
-
+ vertexai.IMAGE_TYPES
|
378
|
-
+ vertexai.AUDIO_TYPES
|
379
|
-
+ vertexai.VIDEO_TYPES
|
380
|
-
)
|
381
103
|
|
104
|
+
class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name
|
105
|
+
"""Gemini Pro latest model."""
|
382
106
|
|
383
|
-
|
384
|
-
"""Gemini Pro model."""
|
107
|
+
model = 'gemini-1.5-pro-001'
|
385
108
|
|
386
|
-
|
109
|
+
|
110
|
+
class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
|
111
|
+
"""Gemini Flash latest model."""
|
112
|
+
|
113
|
+
model = 'gemini-1.5-flash-latest'
|
387
114
|
|
388
115
|
|
389
|
-
class
|
390
|
-
"""Gemini
|
116
|
+
class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name
|
117
|
+
"""Gemini Flash 1.5 model stable version 002."""
|
391
118
|
|
392
|
-
model = 'gemini-
|
393
|
-
supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
|
119
|
+
model = 'gemini-1.5-flash-002'
|
394
120
|
|
395
121
|
|
396
|
-
class
|
397
|
-
"""
|
122
|
+
class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name
|
123
|
+
"""Gemini Flash 1.5 model stable version 001."""
|
398
124
|
|
399
|
-
model = '
|
125
|
+
model = 'gemini-1.5-flash-001'
|
400
126
|
|
401
127
|
|
402
|
-
class
|
403
|
-
"""
|
128
|
+
class GeminiPro1(GenAI): # pylint: disable=invalid-name
|
129
|
+
"""Gemini 1.0 Pro model."""
|
404
130
|
|
405
|
-
model = '
|
131
|
+
model = 'gemini-1.0-pro'
|
@@ -11,223 +11,28 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for
|
14
|
+
"""Tests for Google GenAI models."""
|
15
15
|
|
16
16
|
import os
|
17
17
|
import unittest
|
18
|
-
from unittest import mock
|
19
|
-
|
20
|
-
from google import generativeai as genai
|
21
|
-
import langfun.core as lf
|
22
|
-
from langfun.core import modalities as lf_modalities
|
23
18
|
from langfun.core.llms import google_genai
|
24
|
-
import pyglove as pg
|
25
|
-
|
26
|
-
|
27
|
-
example_image = (
|
28
|
-
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
29
|
-
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
30
|
-
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
31
|
-
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
32
|
-
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
33
|
-
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
34
|
-
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
35
|
-
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
|
-
def mock_get_model(model_name, *args, **kwargs):
|
40
|
-
del args, kwargs
|
41
|
-
if 'gemini' in model_name:
|
42
|
-
method = 'generateContent'
|
43
|
-
elif 'chat' in model_name:
|
44
|
-
method = 'generateMessage'
|
45
|
-
else:
|
46
|
-
method = 'generateText'
|
47
|
-
return pg.Dict(supported_generation_methods=[method])
|
48
|
-
|
49
|
-
|
50
|
-
def mock_generate_text(*, model, prompt, **kwargs):
|
51
|
-
return pg.Dict(
|
52
|
-
candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
|
53
|
-
)
|
54
|
-
|
55
|
-
|
56
|
-
def mock_chat(*, model, messages, **kwargs):
|
57
|
-
return pg.Dict(
|
58
|
-
candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
|
59
|
-
)
|
60
|
-
|
61
|
-
|
62
|
-
def mock_generate_content(content, generation_config, **kwargs):
|
63
|
-
del kwargs
|
64
|
-
c = generation_config
|
65
|
-
return genai.types.GenerateContentResponse(
|
66
|
-
done=True,
|
67
|
-
iterator=None,
|
68
|
-
chunks=[],
|
69
|
-
result=pg.Dict(
|
70
|
-
prompt_feedback=pg.Dict(block_reason=None),
|
71
|
-
candidates=[
|
72
|
-
pg.Dict(
|
73
|
-
content=pg.Dict(
|
74
|
-
parts=[
|
75
|
-
pg.Dict(
|
76
|
-
text=(
|
77
|
-
f'This is a response to {content[0]} with '
|
78
|
-
f'n={c.candidate_count}, '
|
79
|
-
f'temperature={c.temperature}, '
|
80
|
-
f'top_p={c.top_p}, '
|
81
|
-
f'top_k={c.top_k}, '
|
82
|
-
f'max_tokens={c.max_output_tokens}, '
|
83
|
-
f'stop={c.stop_sequences}.'
|
84
|
-
)
|
85
|
-
)
|
86
|
-
]
|
87
|
-
),
|
88
|
-
),
|
89
|
-
],
|
90
|
-
),
|
91
|
-
)
|
92
19
|
|
93
20
|
|
94
21
|
class GenAITest(unittest.TestCase):
|
95
|
-
"""Tests for
|
96
|
-
|
97
|
-
def test_content_from_message_text_only(self):
|
98
|
-
text = 'This is a beautiful day'
|
99
|
-
model = google_genai.GeminiPro()
|
100
|
-
chunks = model._content_from_message(lf.UserMessage(text))
|
101
|
-
self.assertEqual(chunks, [text])
|
102
|
-
|
103
|
-
def test_content_from_message_mm(self):
|
104
|
-
message = lf.UserMessage(
|
105
|
-
'This is an <<[[image]]>>, what is it?',
|
106
|
-
image=lf_modalities.Image.from_bytes(example_image),
|
107
|
-
)
|
22
|
+
"""Tests for GenAI model."""
|
108
23
|
|
109
|
-
|
110
|
-
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
111
|
-
google_genai.GeminiPro()._content_from_message(message)
|
112
|
-
|
113
|
-
model = google_genai.GeminiProVision()
|
114
|
-
chunks = model._content_from_message(message)
|
115
|
-
self.maxDiff = None
|
116
|
-
self.assertEqual(
|
117
|
-
chunks,
|
118
|
-
[
|
119
|
-
'This is an',
|
120
|
-
genai.types.BlobDict(mime_type='image/png', data=example_image),
|
121
|
-
', what is it?',
|
122
|
-
],
|
123
|
-
)
|
124
|
-
|
125
|
-
def test_response_to_result_text_only(self):
|
126
|
-
response = genai.types.GenerateContentResponse(
|
127
|
-
done=True,
|
128
|
-
iterator=None,
|
129
|
-
chunks=[],
|
130
|
-
result=pg.Dict(
|
131
|
-
prompt_feedback=pg.Dict(block_reason=None),
|
132
|
-
candidates=[
|
133
|
-
pg.Dict(
|
134
|
-
content=pg.Dict(
|
135
|
-
parts=[pg.Dict(text='This is response 1.')]
|
136
|
-
),
|
137
|
-
),
|
138
|
-
pg.Dict(
|
139
|
-
content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
|
140
|
-
),
|
141
|
-
],
|
142
|
-
),
|
143
|
-
)
|
144
|
-
model = google_genai.GeminiProVision()
|
145
|
-
result = model._response_to_result(response)
|
146
|
-
self.assertEqual(
|
147
|
-
result,
|
148
|
-
lf.LMSamplingResult([
|
149
|
-
lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
|
150
|
-
lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
|
151
|
-
]),
|
152
|
-
)
|
153
|
-
|
154
|
-
def test_model_hub(self):
|
155
|
-
orig_get_model = genai.get_model
|
156
|
-
genai.get_model = mock_get_model
|
157
|
-
|
158
|
-
model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
|
159
|
-
self.assertIsNotNone(model)
|
160
|
-
self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
|
161
|
-
|
162
|
-
genai.get_model = orig_get_model
|
163
|
-
|
164
|
-
def test_api_key_check(self):
|
24
|
+
def test_basics(self):
|
165
25
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
166
|
-
_ = google_genai.
|
26
|
+
_ = google_genai.GeminiPro1_5().api_endpoint
|
27
|
+
|
28
|
+
self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
|
167
29
|
|
168
|
-
self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
|
169
30
|
os.environ['GOOGLE_API_KEY'] = 'abc'
|
170
|
-
|
31
|
+
lm = google_genai.GeminiPro1_5()
|
32
|
+
self.assertIsNotNone(lm.api_endpoint)
|
33
|
+
self.assertTrue(lm.model_id.startswith('GenAI('))
|
171
34
|
del os.environ['GOOGLE_API_KEY']
|
172
35
|
|
173
|
-
def test_call(self):
|
174
|
-
with mock.patch(
|
175
|
-
'google.generativeai.GenerativeModel.generate_content',
|
176
|
-
) as mock_generate:
|
177
|
-
orig_get_model = genai.get_model
|
178
|
-
genai.get_model = mock_get_model
|
179
|
-
mock_generate.side_effect = mock_generate_content
|
180
|
-
|
181
|
-
lm = google_genai.GeminiPro(api_key='test_key')
|
182
|
-
self.maxDiff = None
|
183
|
-
self.assertEqual(
|
184
|
-
lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
|
185
|
-
(
|
186
|
-
'This is a response to hello with n=1, temperature=2.0, '
|
187
|
-
'top_p=None, top_k=20, max_tokens=1024, stop=None.'
|
188
|
-
),
|
189
|
-
)
|
190
|
-
genai.get_model = orig_get_model
|
191
|
-
|
192
|
-
def test_call_with_legacy_completion_model(self):
|
193
|
-
orig_get_model = genai.get_model
|
194
|
-
genai.get_model = mock_get_model
|
195
|
-
orig_generate_text = getattr(genai, 'generate_text', None)
|
196
|
-
if orig_generate_text is not None:
|
197
|
-
genai.generate_text = mock_generate_text
|
198
|
-
|
199
|
-
lm = google_genai.Palm2(api_key='test_key')
|
200
|
-
self.maxDiff = None
|
201
|
-
self.assertEqual(
|
202
|
-
lm('hello', temperature=2.0, top_k=20).text,
|
203
|
-
(
|
204
|
-
"hello to models/text-bison-001 with {'temperature': 2.0, "
|
205
|
-
"'top_k': 20, 'top_p': None, 'candidate_count': 1, "
|
206
|
-
"'max_output_tokens': None, 'stop_sequences': None}"
|
207
|
-
),
|
208
|
-
)
|
209
|
-
genai.generate_text = orig_generate_text
|
210
|
-
genai.get_model = orig_get_model
|
211
|
-
|
212
|
-
def test_call_with_legacy_chat_model(self):
|
213
|
-
orig_get_model = genai.get_model
|
214
|
-
genai.get_model = mock_get_model
|
215
|
-
orig_chat = getattr(genai, 'chat', None)
|
216
|
-
if orig_chat is not None:
|
217
|
-
genai.chat = mock_chat
|
218
|
-
|
219
|
-
lm = google_genai.Palm2_IT(api_key='test_key')
|
220
|
-
self.maxDiff = None
|
221
|
-
self.assertEqual(
|
222
|
-
lm('hello', temperature=2.0, top_k=20).text,
|
223
|
-
(
|
224
|
-
"hello to models/chat-bison-001 with {'temperature': 2.0, "
|
225
|
-
"'top_k': 20, 'top_p': None, 'candidate_count': 1}"
|
226
|
-
),
|
227
|
-
)
|
228
|
-
genai.chat = orig_chat
|
229
|
-
genai.get_model = orig_get_model
|
230
|
-
|
231
36
|
|
232
37
|
if __name__ == '__main__':
|
233
38
|
unittest.main()
|