langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (34)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +15 -6
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +25 -26
  14. langfun/core/llms/cache/in_memory.py +6 -0
  15. langfun/core/llms/cache/in_memory_test.py +5 -0
  16. langfun/core/llms/deepseek.py +261 -0
  17. langfun/core/llms/deepseek_test.py +438 -0
  18. langfun/core/llms/gemini.py +507 -0
  19. langfun/core/llms/gemini_test.py +195 -0
  20. langfun/core/llms/google_genai.py +46 -320
  21. langfun/core/llms/google_genai_test.py +9 -204
  22. langfun/core/llms/openai.py +5 -0
  23. langfun/core/llms/vertexai.py +31 -359
  24. langfun/core/llms/vertexai_test.py +6 -166
  25. langfun/core/structured/mapping.py +13 -13
  26. langfun/core/structured/mapping_test.py +2 -2
  27. langfun/core/structured/schema.py +16 -8
  28. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
  29. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
  30. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
  31. langfun/core/text_formatting.py +0 -168
  32. langfun/core/text_formatting_test.py +0 -65
  33. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
  34. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
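
Most of the churn is a consolidation: the new langfun/core/llms/gemini.py (+507 lines) introduces a shared REST client for Gemini models, which is why google_genai.py (-320 lines) and vertexai.py (-359 lines) collapse into thin wrappers around it. Based on the google_genai.py diff below, the resulting shape is roughly the following sketch (gemini.py itself is not shown in this diff, so treat the base-class interface as an assumption):

  import langfun.core as lf
  from langfun.core.llms import gemini  # new shared module in this release
  import pyglove as pg

  @lf.use_init_args(['model'])
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
  class GenAI(gemini.Gemini):  # previously: class GenAI(lf.LanguageModel)
    """Language models provided by Google GenAI."""
    # Request construction, modality handling and response parsing now live in
    # the shared gemini.Gemini base (inferred from the code deleted below);
    # subclasses only declare a model name and build an endpoint URL.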
langfun/core/llms/google_genai.py
@@ -1,4 +1,4 @@
- # Copyright 2024 The Langfun Authors
+ # Copyright 2025 The Langfun Authors
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,57 +11,21 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Gemini models exposed through Google Generative AI APIs."""
+ """Language models from Google GenAI."""

- import abc
- import functools
  import os
- from typing import Annotated, Any, Literal
+ from typing import Annotated, Literal

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import vertexai
+ from langfun.core.llms import gemini
  import pyglove as pg


- try:
-   import google.generativeai as genai  # pylint: disable=g-import-not-at-top
-   BlobDict = genai.types.BlobDict
-   GenerativeModel = genai.GenerativeModel
-   Completion = getattr(genai.types, 'Completion', Any)
-   ChatResponse = getattr(genai.types, 'ChatResponse', Any)
-   GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
-   GenerationConfig = genai.GenerationConfig
- except ImportError:
-   genai = None
-   BlobDict = Any
-   GenerativeModel = Any
-   Completion = Any
-   ChatResponse = Any
-   GenerationConfig = Any
-   GenerateContentResponse = Any
-
-
  @lf.use_init_args(['model'])
- class GenAI(lf.LanguageModel):
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+ class GenAI(gemini.Gemini):
    """Language models provided by Google GenAI."""

-   model: Annotated[
-       Literal[
-           'gemini-2.0-flash-thinking-exp-1219',
-           'gemini-2.0-flash-exp',
-           'gemini-exp-1206',
-           'gemini-exp-1114',
-           'gemini-1.5-pro-latest',
-           'gemini-1.5-flash-latest',
-           'gemini-pro',
-           'gemini-pro-vision',
-           'text-bison-001',
-           'chat-bison-001',
-       ],
-       'Model name.',
-   ]
-
    api_key: Annotated[
        str | None,
        (
@@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel):
        ),
    ] = None

-   supported_modalities: Annotated[
-       list[str],
-       'A list of MIME types for supported modalities'
-   ] = []
+   api_version: Annotated[
+       Literal['v1beta', 'v1alpha'],
+       'The API version to use.'
+   ] = 'v1beta'

-   # Set the default max concurrency to 8 workers.
-   max_concurrency = 8
-
-   def _on_bound(self):
-     super()._on_bound()
-     if genai is None:
-       raise RuntimeError(
-           'Please install "langfun[llm-google-genai]" to use '
-           'Google Generative AI models.'
-       )
-     self.__dict__.pop('_api_initialized', None)
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'GenAI({self.model})'

-   @functools.cached_property
-   def _api_initialized(self):
-     assert genai is not None
+   @property
+   def api_endpoint(self) -> str:
      api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
      if not api_key:
        raise ValueError(
@@ -100,306 +56,76 @@ class GenAI(lf.LanguageModel):
            'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
            'for more details.'
        )
-     genai.configure(api_key=api_key)
-     return True
-
-   @classmethod
-   def dir(cls) -> list[str]:
-     """Lists generative models."""
-     assert genai is not None
-     return [
-         m.name.lstrip('models/')
-         for m in genai.list_models()
-         if (
-             'generateContent' in m.supported_generation_methods
-             or 'generateText' in m.supported_generation_methods
-             or 'generateMessage' in m.supported_generation_methods
-         )
-     ]
-
-   @property
-   def model_id(self) -> str:
-     """Returns a string to identify the model."""
-     return self.model
-
-   @property
-   def resource_id(self) -> str:
-     """Returns a string to identify the resource for rate control."""
-     return self.model_id
-
-   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     """Creates generation config from langfun sampling options."""
-     return GenerationConfig(
-         candidate_count=options.n,
-         temperature=options.temperature,
-         top_p=options.top_p,
-         top_k=options.top_k,
-         max_output_tokens=options.max_tokens,
-         stop_sequences=options.stop,
+     return (
+         f'https://generativelanguage.googleapis.com/{self.api_version}'
+         f'/models/{self.model}:generateContent?'
+         f'key={api_key}'
      )

-   def _content_from_message(
-       self, prompt: lf.Message
-   ) -> list[str | BlobDict]:
-     """Gets Evergreen formatted content from langfun message."""
-     formatted = lf.UserMessage(prompt.text)
-     formatted.source = prompt
-
-     chunks = []
-     for lf_chunk in formatted.chunk():
-       if isinstance(lf_chunk, str):
-         chunks.append(lf_chunk)
-       elif isinstance(lf_chunk, lf_modalities.Mime):
-         try:
-           modalities = lf_chunk.make_compatible(
-               self.supported_modalities + ['text/plain']
-           )
-           if isinstance(modalities, lf_modalities.Mime):
-             modalities = [modalities]
-           for modality in modalities:
-             if modality.is_text:
-               chunk = modality.to_text()
-             else:
-               chunk = BlobDict(
-                   data=modality.to_bytes(),
-                   mime_type=modality.mime_type
-               )
-             chunks.append(chunk)
-         except lf.ModalityError as e:
-           raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-       else:
-         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-     return chunks
-
-   def _response_to_result(
-       self, response: GenerateContentResponse | pg.Dict
-   ) -> lf.LMSamplingResult:
-     """Parses generative response into message."""
-     samples = []
-     for candidate in response.candidates:
-       chunks = []
-       for part in candidate.content.parts:
-         # TODO(daiyip): support multi-modal parts when they are available via
-         # Gemini API.
-         if hasattr(part, 'text'):
-           chunks.append(part.text)
-       samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
-     return lf.LMSamplingResult(samples)
-
-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     assert self._api_initialized, 'Vertex AI API is not initialized.'
-     return self._parallel_execute_with_currency_control(
-         self._sample_single,
-         prompts,
-     )

-   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-     """Samples a single prompt."""
-     model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
-     input_content = self._content_from_message(prompt)
-     response = model.generate_content(
-         input_content,
-         generation_config=self._generation_config(self.sampling_options),
-     )
-     return self._response_to_result(response)
-
-
- class _LegacyGenerativeModel(pg.Object):
-   """Base for legacy GenAI generative model."""
-
-   model: str
-
-   def generate_content(
-       self,
-       input_content: list[str | BlobDict],
-       generation_config: GenerationConfig,
-   ) -> pg.Dict:
-     """Generate content."""
-     segments = []
-     for s in input_content:
-       if not isinstance(s, str):
-         raise ValueError(f'Unsupported modality: {s!r}')
-       segments.append(s)
-     return self.generate(' '.join(segments), generation_config)
-
-   @abc.abstractmethod
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
-     """Generate response based on prompt."""
-
-
- class _LegacyCompletionModel(_LegacyGenerativeModel):
-   """Legacy GenAI completion model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     completion: Completion = genai.generate_text(
-         model=f'models/{self.model}',
-         prompt=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-         max_output_tokens=generation_config.max_output_tokens,
-         stop_sequences=generation_config.stop_sequences,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
-             for c in completion.candidates
-         ]
-     )
-
-
- class _LegacyChatModel(_LegacyGenerativeModel):
-   """Legacy GenAI chat model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     response: ChatResponse = genai.chat(
-         model=f'models/{self.model}',
-         messages=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
-             for c in response.candidates
-         ]
-     )
-
-
- class _ModelHub:
-   """Google Generative AI model hub."""
-
-   def __init__(self):
-     self._model_cache = {}
-
-   def get(
-       self, model_name: str
-   ) -> GenerativeModel | _LegacyGenerativeModel:
-     """Gets a generative model by model id."""
-     assert genai is not None
-     model = self._model_cache.get(model_name, None)
-     if model is None:
-       model_info = genai.get_model(f'models/{model_name}')
-       if 'generateContent' in model_info.supported_generation_methods:
-         model = genai.GenerativeModel(model_name)
-       elif 'generateText' in model_info.supported_generation_methods:
-         model = _LegacyCompletionModel(model_name)
-       elif 'generateMessage' in model_info.supported_generation_methods:
-         model = _LegacyChatModel(model_name)
-       else:
-         raise ValueError(f'Unsupported model: {model_name!r}')
-       self._model_cache[model_name] = model
-     return model
-
-
- _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
-
-
- #
- # Public Gemini models.
- #
- class GeminiFlash2_0ThinkingExp(GenAI):  # pylint: disable=invalid-name
-   """Gemini 2.0 Flash Thinking Experimental model."""
+ class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""

+   api_version = 'v1alpha'
    model = 'gemini-2.0-flash-thinking-exp-1219'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )
+   timeout = None


  class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
-   """Gemini Experimental model launched on 12/06/2024."""
+   """Gemini Flash 2.0 model launched on 12/11/2024."""

    model = 'gemini-2.0-flash-exp'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 12/06/2024."""

    model = 'gemini-exp-1206'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 11/14/2024."""

    model = 'gemini-exp-1114'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
    """Gemini Pro latest model."""

    model = 'gemini-1.5-pro-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


- class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
-   """Gemini Flash latest model."""
+ class GeminiPro1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro latest model."""
+
+   model = 'gemini-1.5-pro-002'

-   model = 'gemini-1.5-flash-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )

+ class GeminiPro1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro latest model."""

- class GeminiPro(GenAI):
-   """Gemini Pro model."""
+   model = 'gemini-1.5-pro-001'

-   model = 'gemini-pro'
+
+ class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash latest model."""
+
+   model = 'gemini-1.5-flash-latest'


- class GeminiProVision(GenAI):
-   """Gemini Pro vision model."""
+ class GeminiFlash1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 002."""

-   model = 'gemini-pro-vision'
-   supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
+   model = 'gemini-1.5-flash-002'


- class Palm2(GenAI):
-   """PaLM2 model."""
+ class GeminiFlash1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 001."""

-   model = 'text-bison-001'
+   model = 'gemini-1.5-flash-001'


- class Palm2_IT(GenAI):  # pylint: disable=invalid-name
-   """PaLM2 instruction-tuned model."""
+ class GeminiPro1(GenAI):  # pylint: disable=invalid-name
+   """Gemini 1.0 Pro model."""

-   model = 'chat-bison-001'
+   model = 'gemini-1.0-pro'
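
Taken together, the hunks above rename most of the public classes and drop the PaLM2-era ones, while instantiation no longer configures an SDK; it only builds a URL. A hedged migration sketch follows; the class names and endpoint format are taken from the diff above, and the API key is a placeholder:

  from langfun.core.llms import google_genai

  # Removed in this release       -> nearest replacement in this diff
  # google_genai.GeminiPro        -> google_genai.GeminiPro1 ('gemini-1.0-pro')
  # google_genai.GeminiProVision  -> none; use a multimodal 1.5/2.0 class
  # google_genai.Palm2 / Palm2_IT -> none (the bison models are gone)
  # google_genai.GeminiFlash2_0ThinkingExp
  #   -> google_genai.GeminiFlash2_0ThinkingExp_20241219 (pins api_version='v1alpha')

  lm = google_genai.GeminiPro1_5(api_key='MY_API_KEY')  # placeholder key
  print(lm.model_id)      # GenAI(gemini-1.5-pro-latest)
  print(lm.api_endpoint)  # https://generativelanguage.googleapis.com/v1beta
                          #   /models/gemini-1.5-pro-latest:generateContent?key=MY_API_KEY

Note that the endpoint embeds the key as a query parameter, so api_endpoint values should be treated as secrets when logging.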
langfun/core/llms/google_genai_test.py
@@ -11,223 +11,28 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for Gemini models."""
+ """Tests for Google GenAI models."""

  import os
  import unittest
- from unittest import mock
-
- from google import generativeai as genai
- import langfun.core as lf
- from langfun.core import modalities as lf_modalities
  from langfun.core.llms import google_genai
- import pyglove as pg
-
-
- example_image = (
-     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
- )
-
-
- def mock_get_model(model_name, *args, **kwargs):
-   del args, kwargs
-   if 'gemini' in model_name:
-     method = 'generateContent'
-   elif 'chat' in model_name:
-     method = 'generateMessage'
-   else:
-     method = 'generateText'
-   return pg.Dict(supported_generation_methods=[method])
-
-
- def mock_generate_text(*, model, prompt, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-   )
-
-
- def mock_chat(*, model, messages, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-   )
-
-
- def mock_generate_content(content, generation_config, **kwargs):
-   del kwargs
-   c = generation_config
-   return genai.types.GenerateContentResponse(
-       done=True,
-       iterator=None,
-       chunks=[],
-       result=pg.Dict(
-           prompt_feedback=pg.Dict(block_reason=None),
-           candidates=[
-               pg.Dict(
-                   content=pg.Dict(
-                       parts=[
-                           pg.Dict(
-                               text=(
-                                   f'This is a response to {content[0]} with '
-                                   f'n={c.candidate_count}, '
-                                   f'temperature={c.temperature}, '
-                                   f'top_p={c.top_p}, '
-                                   f'top_k={c.top_k}, '
-                                   f'max_tokens={c.max_output_tokens}, '
-                                   f'stop={c.stop_sequences}.'
-                               )
-                           )
-                       ]
-                   ),
-               ),
-           ],
-       ),
-   )


  class GenAITest(unittest.TestCase):
-   """Tests for Google GenAI model."""
-
-   def test_content_from_message_text_only(self):
-     text = 'This is a beautiful day'
-     model = google_genai.GeminiPro()
-     chunks = model._content_from_message(lf.UserMessage(text))
-     self.assertEqual(chunks, [text])
-
-   def test_content_from_message_mm(self):
-     message = lf.UserMessage(
-         'This is an <<[[image]]>>, what is it?',
-         image=lf_modalities.Image.from_bytes(example_image),
-     )
+   """Tests for GenAI model."""

-     # Non-multimodal model.
-     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-       google_genai.GeminiPro()._content_from_message(message)
-
-     model = google_genai.GeminiProVision()
-     chunks = model._content_from_message(message)
-     self.maxDiff = None
-     self.assertEqual(
-         chunks,
-         [
-             'This is an',
-             genai.types.BlobDict(mime_type='image/png', data=example_image),
-             ', what is it?',
-         ],
-     )
-
-   def test_response_to_result_text_only(self):
-     response = genai.types.GenerateContentResponse(
-         done=True,
-         iterator=None,
-         chunks=[],
-         result=pg.Dict(
-             prompt_feedback=pg.Dict(block_reason=None),
-             candidates=[
-                 pg.Dict(
-                     content=pg.Dict(
-                         parts=[pg.Dict(text='This is response 1.')]
-                     ),
-                 ),
-                 pg.Dict(
-                     content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                 ),
-             ],
-         ),
-     )
-     model = google_genai.GeminiProVision()
-     result = model._response_to_result(response)
-     self.assertEqual(
-         result,
-         lf.LMSamplingResult([
-             lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-             lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-         ]),
-     )
-
-   def test_model_hub(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-
-     model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-     self.assertIsNotNone(model)
-     self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-     genai.get_model = orig_get_model
-
-   def test_api_key_check(self):
+   def test_basics(self):
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-       _ = google_genai.GeminiPro()._api_initialized
+       _ = google_genai.GeminiPro1_5().api_endpoint
+
+     self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)

-     self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
      os.environ['GOOGLE_API_KEY'] = 'abc'
-     self.assertTrue(google_genai.GeminiPro()._api_initialized)
+     lm = google_genai.GeminiPro1_5()
+     self.assertIsNotNone(lm.api_endpoint)
+     self.assertTrue(lm.model_id.startswith('GenAI('))
      del os.environ['GOOGLE_API_KEY']

-   def test_call(self):
-     with mock.patch(
-         'google.generativeai.GenerativeModel.generate_content',
-     ) as mock_generate:
-       orig_get_model = genai.get_model
-       genai.get_model = mock_get_model
-       mock_generate.side_effect = mock_generate_content
-
-       lm = google_genai.GeminiPro(api_key='test_key')
-       self.maxDiff = None
-       self.assertEqual(
-           lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-           (
-               'This is a response to hello with n=1, temperature=2.0, '
-               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-           ),
-       )
-       genai.get_model = orig_get_model
-
-   def test_call_with_legacy_completion_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_generate_text = getattr(genai, 'generate_text', None)
-     if orig_generate_text is not None:
-       genai.generate_text = mock_generate_text
-
-     lm = google_genai.Palm2(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/text-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-             "'max_output_tokens': None, 'stop_sequences': None}"
-         ),
-     )
-     genai.generate_text = orig_generate_text
-     genai.get_model = orig_get_model
-
-   def test_call_with_legacy_chat_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_chat = getattr(genai, 'chat', None)
-     if orig_chat is not None:
-       genai.chat = mock_chat
-
-     lm = google_genai.Palm2_IT(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/chat-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-         ),
-     )
-     genai.chat = orig_chat
-     genai.get_model = orig_get_model
-

  if __name__ == '__main__':
    unittest.main()
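
The slimmed-down test above covers the entire remaining public surface: API-key resolution and model_id. Since api_endpoint only formats a string, the same checks run offline; a quick interactive sketch (the 'abc' key is a dummy and no request is issued):

  import os
  from langfun.core.llms import google_genai

  os.environ.pop('GOOGLE_API_KEY', None)
  try:
    _ = google_genai.GeminiPro1_5().api_endpoint  # no key available anywhere
  except ValueError as e:
    print(e)  # asks to pass `api_key` or set the GOOGLE_API_KEY env variable

  lm = google_genai.GeminiPro1_5(api_key='abc')  # dummy key
  print(lm.model_id)  # GenAI(gemini-1.5-pro-latest)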