langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145):
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Langfun Authors
1
+ # Copyright 2025 The Langfun Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -11,33 +11,21 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Gemini models exposed through Google Generative AI APIs."""
14
+ """Language models from Google GenAI."""
15
15
 
16
- import abc
17
- import functools
18
16
  import os
19
- from typing import Annotated, Any, Literal
17
+ from typing import Annotated, Literal
20
18
 
21
- import google.generativeai as genai
22
19
  import langfun.core as lf
23
- from langfun.core import modalities as lf_modalities
20
+ from langfun.core.llms import gemini
24
21
  import pyglove as pg
25
22
 
26
23
 
27
24
  @lf.use_init_args(['model'])
28
- class GenAI(lf.LanguageModel):
25
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
26
+ class GenAI(gemini.Gemini):
29
27
  """Language models provided by Google GenAI."""
30
28
 
31
- model: Annotated[
32
- Literal[
33
- 'gemini-pro',
34
- 'gemini-pro-vision',
35
- 'text-bison-001',
36
- 'chat-bison-001',
37
- ],
38
- 'Model name.',
39
- ]
40
-
41
29
  api_key: Annotated[
42
30
  str | None,
43
31
  (
@@ -47,19 +35,18 @@ class GenAI(lf.LanguageModel):
47
35
  ),
48
36
  ] = None
49
37
 
50
- multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
51
- False
52
- )
53
-
54
- # Set the default max concurrency to 8 workers.
55
- max_concurrency = 8
38
+ api_version: Annotated[
39
+ Literal['v1beta', 'v1alpha'],
40
+ 'The API version to use.'
41
+ ] = 'v1beta'
56
42
 
57
- def _on_bound(self):
58
- super()._on_bound()
59
- self.__dict__.pop('_api_initialized', None)
43
+ @property
44
+ def model_id(self) -> str:
45
+ """Returns a string to identify the model."""
46
+ return f'GenAI({self.model})'
60
47
 
61
- @functools.cached_property
62
- def _api_initialized(self):
48
+ @property
49
+ def api_endpoint(self) -> str:
63
50
  api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
64
51
  if not api_key:
65
52
  raise ValueError(
@@ -69,219 +56,76 @@ class GenAI(lf.LanguageModel):
69
56
  'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
70
57
  'for more details.'
71
58
  )
72
- genai.configure(api_key=api_key)
73
- return True
74
-
75
- @classmethod
76
- def dir(cls) -> list[str]:
77
- """Lists generative models."""
78
- return [
79
- m.name.lstrip('models/')
80
- for m in genai.list_models()
81
- if (
82
- 'generateContent' in m.supported_generation_methods
83
- or 'generateText' in m.supported_generation_methods
84
- or 'generateMessage' in m.supported_generation_methods
85
- )
86
- ]
59
+ return (
60
+ f'https://generativelanguage.googleapis.com/{self.api_version}'
61
+ f'/models/{self.model}:generateContent?'
62
+ f'key={api_key}'
63
+ )
87
64
 
88
- @property
89
- def model_id(self) -> str:
90
- """Returns a string to identify the model."""
91
- return self.model
92
65
 
93
- @property
94
- def resource_id(self) -> str:
95
- """Returns a string to identify the resource for rate control."""
96
- return self.model_id
97
-
98
- def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
99
- """Creates generation config from langfun sampling options."""
100
- return genai.GenerationConfig(
101
- candidate_count=options.n,
102
- temperature=options.temperature,
103
- top_p=options.top_p,
104
- top_k=options.top_k,
105
- max_output_tokens=options.max_tokens,
106
- stop_sequences=options.stop,
107
- )
66
+ class GeminiFlash2_0ThinkingExp_20241219(GenAI): # pylint: disable=invalid-name
67
+ """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
108
68
 
109
- def _content_from_message(
110
- self, prompt: lf.Message
111
- ) -> list[str | genai.types.BlobDict]:
112
- """Gets Evergreen formatted content from langfun message."""
113
- formatted = lf.UserMessage(prompt.text)
114
- formatted.source = prompt
115
-
116
- chunks = []
117
- for lf_chunk in formatted.chunk():
118
- if isinstance(lf_chunk, str):
119
- chunk = lf_chunk
120
- elif self.multimodal and isinstance(lf_chunk, lf_modalities.MimeType):
121
- chunk = genai.types.BlobDict(
122
- data=lf_chunk.to_bytes(), mime_type=lf_chunk.mime_type
123
- )
124
- else:
125
- raise ValueError(f'Unsupported modality: {lf_chunk!r}')
126
- chunks.append(chunk)
127
- return chunks
128
-
129
- def _response_to_result(
130
- self, response: genai.types.GenerateContentResponse | pg.Dict
131
- ) -> lf.LMSamplingResult:
132
- """Parses generative response into message."""
133
- samples = []
134
- for candidate in response.candidates:
135
- chunks = []
136
- for part in candidate.content.parts:
137
- # TODO(daiyip): support multi-modal parts when they are available via
138
- # Gemini API.
139
- if hasattr(part, 'text'):
140
- chunks.append(part.text)
141
- samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
142
- return lf.LMSamplingResult(samples)
143
-
144
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
145
- assert self._api_initialized, 'Vertex AI API is not initialized.'
146
- return self._parallel_execute_with_currency_control(
147
- self._sample_single,
148
- prompts,
149
- )
69
+ api_version = 'v1alpha'
70
+ model = 'gemini-2.0-flash-thinking-exp-1219'
71
+ timeout = None
150
72
 
151
- def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
152
- """Samples a single prompt."""
153
- model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
154
- input_content = self._content_from_message(prompt)
155
- response = model.generate_content(
156
- input_content,
157
- generation_config=self._generation_config(self.sampling_options),
158
- )
159
- return self._response_to_result(response)
160
-
161
-
162
- class _LegacyGenerativeModel(pg.Object):
163
- """Base for legacy GenAI generative model."""
164
-
165
- model: str
166
-
167
- def generate_content(
168
- self,
169
- input_content: list[str | genai.types.BlobDict],
170
- generation_config: genai.GenerationConfig,
171
- ) -> pg.Dict:
172
- """Generate content."""
173
- segments = []
174
- for s in input_content:
175
- if not isinstance(s, str):
176
- raise ValueError(f'Unsupported modality: {s!r}')
177
- segments.append(s)
178
- return self.generate(' '.join(segments), generation_config)
179
-
180
- @abc.abstractmethod
181
- def generate(
182
- self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
183
- """Generate response based on prompt."""
184
-
185
-
186
- class _LegacyCompletionModel(_LegacyGenerativeModel):
187
- """Legacy GenAI completion model."""
188
-
189
- def generate(
190
- self, prompt: str, generation_config: genai.GenerationConfig
191
- ) -> pg.Dict:
192
- completion: genai.types.Completion = genai.generate_text(
193
- model=f'models/{self.model}',
194
- prompt=prompt,
195
- temperature=generation_config.temperature,
196
- top_k=generation_config.top_k,
197
- top_p=generation_config.top_p,
198
- candidate_count=generation_config.candidate_count,
199
- max_output_tokens=generation_config.max_output_tokens,
200
- stop_sequences=generation_config.stop_sequences,
201
- )
202
- return pg.Dict(
203
- candidates=[
204
- pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
205
- for c in completion.candidates
206
- ]
207
- )
208
73
 
74
+ class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name
75
+ """Gemini Flash 2.0 model launched on 12/11/2024."""
209
76
 
210
- class _LegacyChatModel(_LegacyGenerativeModel):
211
- """Legacy GenAI chat model."""
77
+ model = 'gemini-2.0-flash-exp'
212
78
 
213
- def generate(
214
- self, prompt: str, generation_config: genai.GenerationConfig
215
- ) -> pg.Dict:
216
- response: genai.types.ChatResponse = genai.chat(
217
- model=f'models/{self.model}',
218
- messages=prompt,
219
- temperature=generation_config.temperature,
220
- top_k=generation_config.top_k,
221
- top_p=generation_config.top_p,
222
- candidate_count=generation_config.candidate_count,
223
- )
224
- return pg.Dict(
225
- candidates=[
226
- pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
227
- for c in response.candidates
228
- ]
229
- )
230
79
 
80
+ class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name
81
+ """Gemini Experimental model launched on 12/06/2024."""
231
82
 
232
- class _ModelHub:
233
- """Google Generative AI model hub."""
83
+ model = 'gemini-exp-1206'
234
84
 
235
- def __init__(self):
236
- self._model_cache = {}
237
85
 
238
- def get(
239
- self, model_name: str
240
- ) -> genai.GenerativeModel | _LegacyGenerativeModel:
241
- """Gets a generative model by model id."""
242
- model = self._model_cache.get(model_name, None)
243
- if model is None:
244
- model_info = genai.get_model(f'models/{model_name}')
245
- if 'generateContent' in model_info.supported_generation_methods:
246
- model = genai.GenerativeModel(model_name)
247
- elif 'generateText' in model_info.supported_generation_methods:
248
- model = _LegacyCompletionModel(model_name)
249
- elif 'generateMessage' in model_info.supported_generation_methods:
250
- model = _LegacyChatModel(model_name)
251
- else:
252
- raise ValueError(f'Unsupported model: {model_name!r}')
253
- self._model_cache[model_name] = model
254
- return model
86
+ class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name
87
+ """Gemini Experimental model launched on 11/14/2024."""
255
88
 
89
+ model = 'gemini-exp-1114'
256
90
 
257
- _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
258
91
 
92
+ class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
93
+ """Gemini Pro latest model."""
259
94
 
260
- #
261
- # Public Gemini models.
262
- #
95
+ model = 'gemini-1.5-pro-latest'
96
+
97
+
98
+ class GeminiPro1_5_002(GenAI): # pylint: disable=invalid-name
99
+ """Gemini Pro latest model."""
100
+
101
+ model = 'gemini-1.5-pro-002'
102
+
103
+
104
+ class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name
105
+ """Gemini Pro latest model."""
106
+
107
+ model = 'gemini-1.5-pro-001'
263
108
 
264
109
 
265
- class GeminiPro(GenAI):
266
- """Gemini Pro model."""
110
+ class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
111
+ """Gemini Flash latest model."""
267
112
 
268
- model = 'gemini-pro'
113
+ model = 'gemini-1.5-flash-latest'
269
114
 
270
115
 
271
- class GeminiProVision(GenAI):
272
- """Gemini Pro vision model."""
116
+ class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name
117
+ """Gemini Flash 1.5 model stable version 002."""
273
118
 
274
- model = 'gemini-pro-vision'
275
- multimodal = True
119
+ model = 'gemini-1.5-flash-002'
276
120
 
277
121
 
278
- class Palm2(GenAI):
279
- """PaLM2 model."""
122
+ class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name
123
+ """Gemini Flash 1.5 model stable version 001."""
280
124
 
281
- model = 'text-bison-001'
125
+ model = 'gemini-1.5-flash-001'
282
126
 
283
127
 
284
- class Palm2_IT(GenAI): # pylint: disable=invalid-name
285
- """PaLM2 instruction-tuned model."""
128
+ class GeminiPro1(GenAI): # pylint: disable=invalid-name
129
+ """Gemini 1.0 Pro model."""
286
130
 
287
- model = 'chat-bison-001'
131
+ model = 'gemini-1.0-pro'
@@ -11,216 +11,28 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for Gemini models."""
14
+ """Tests for Google GenAI models."""
15
15
 
16
16
  import os
17
17
  import unittest
18
- from unittest import mock
19
-
20
- from google import generativeai as genai
21
- import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
18
  from langfun.core.llms import google_genai
24
- import pyglove as pg
25
-
26
-
27
- example_image = (
28
- b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
29
- b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
30
- b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
31
- b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
32
- b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
33
- b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
34
- b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
35
- b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
36
- )
37
-
38
-
39
- def mock_get_model(model_name, *args, **kwargs):
40
- del args, kwargs
41
- if 'gemini' in model_name:
42
- method = 'generateContent'
43
- elif 'chat' in model_name:
44
- method = 'generateMessage'
45
- else:
46
- method = 'generateText'
47
- return pg.Dict(supported_generation_methods=[method])
48
-
49
-
50
- def mock_generate_text(*, model, prompt, **kwargs):
51
- return pg.Dict(
52
- candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
53
- )
54
-
55
-
56
- def mock_chat(*, model, messages, **kwargs):
57
- return pg.Dict(
58
- candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
59
- )
60
-
61
-
62
- def mock_generate_content(content, generation_config, **kwargs):
63
- del kwargs
64
- c = generation_config
65
- return genai.types.GenerateContentResponse(
66
- done=True,
67
- iterator=None,
68
- chunks=[],
69
- result=pg.Dict(
70
- prompt_feedback=pg.Dict(block_reason=None),
71
- candidates=[
72
- pg.Dict(
73
- content=pg.Dict(
74
- parts=[
75
- pg.Dict(
76
- text=(
77
- f'This is a response to {content[0]} with '
78
- f'n={c.candidate_count}, '
79
- f'temperature={c.temperature}, '
80
- f'top_p={c.top_p}, '
81
- f'top_k={c.top_k}, '
82
- f'max_tokens={c.max_output_tokens}, '
83
- f'stop={c.stop_sequences}.'
84
- )
85
- )
86
- ]
87
- ),
88
- ),
89
- ],
90
- ),
91
- )
92
19
 
93
20
 
94
21
  class GenAITest(unittest.TestCase):
95
- """Tests for Google GenAI model."""
96
-
97
- def test_content_from_message_text_only(self):
98
- text = 'This is a beautiful day'
99
- model = google_genai.GeminiPro()
100
- chunks = model._content_from_message(lf.UserMessage(text))
101
- self.assertEqual(chunks, [text])
22
+ """Tests for GenAI model."""
102
23
 
103
- def test_content_from_message_mm(self):
104
- message = lf.UserMessage(
105
- 'This is an {{image}}, what is it?',
106
- image=lf_modalities.Image.from_bytes(example_image),
107
- )
108
-
109
- # Non-multimodal model.
110
- with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
111
- google_genai.GeminiPro()._content_from_message(message)
112
-
113
- model = google_genai.GeminiProVision()
114
- chunks = model._content_from_message(message)
115
- self.maxDiff = None
116
- self.assertEqual(
117
- chunks,
118
- [
119
- 'This is an',
120
- genai.types.BlobDict(mime_type='image/png', data=example_image),
121
- ', what is it?',
122
- ],
123
- )
124
-
125
- def test_response_to_result_text_only(self):
126
- response = genai.types.GenerateContentResponse(
127
- done=True,
128
- iterator=None,
129
- chunks=[],
130
- result=pg.Dict(
131
- prompt_feedback=pg.Dict(block_reason=None),
132
- candidates=[
133
- pg.Dict(
134
- content=pg.Dict(
135
- parts=[pg.Dict(text='This is response 1.')]
136
- ),
137
- ),
138
- pg.Dict(
139
- content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
140
- ),
141
- ],
142
- ),
143
- )
144
- model = google_genai.GeminiProVision()
145
- result = model._response_to_result(response)
146
- self.assertEqual(
147
- result,
148
- lf.LMSamplingResult([
149
- lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
150
- lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
151
- ]),
152
- )
153
-
154
- def test_model_hub(self):
155
- model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
156
- self.assertIsNotNone(model)
157
- self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
158
-
159
- def test_api_key_check(self):
24
+ def test_basics(self):
160
25
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
161
- _ = google_genai.GeminiPro()._api_initialized
26
+ _ = google_genai.GeminiPro1_5().api_endpoint
27
+
28
+ self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
162
29
 
163
- self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
164
30
  os.environ['GOOGLE_API_KEY'] = 'abc'
165
- self.assertTrue(google_genai.GeminiPro()._api_initialized)
31
+ lm = google_genai.GeminiPro1_5()
32
+ self.assertIsNotNone(lm.api_endpoint)
33
+ self.assertTrue(lm.model_id.startswith('GenAI('))
166
34
  del os.environ['GOOGLE_API_KEY']
167
35
 
168
- def test_call(self):
169
- with mock.patch(
170
- 'google.generativeai.generative_models.GenerativeModel.generate_content'
171
- ) as mock_generate:
172
- orig_get_model = genai.get_model
173
- genai.get_model = mock_get_model
174
- mock_generate.side_effect = mock_generate_content
175
-
176
- lm = google_genai.GeminiPro(api_key='test_key')
177
- self.maxDiff = None
178
- self.assertEqual(
179
- lm('hello', temperature=2.0, top_k=20).text,
180
- (
181
- 'This is a response to hello with n=1, temperature=2.0, '
182
- 'top_p=None, top_k=20, max_tokens=1024, stop=None.'
183
- ),
184
- )
185
- genai.get_model = orig_get_model
186
-
187
- def test_call_with_legacy_completion_model(self):
188
- orig_get_model = genai.get_model
189
- genai.get_model = mock_get_model
190
- orig_generate_text = genai.generate_text
191
- genai.generate_text = mock_generate_text
192
-
193
- lm = google_genai.Palm2(api_key='test_key')
194
- self.maxDiff = None
195
- self.assertEqual(
196
- lm('hello', temperature=2.0, top_k=20).text,
197
- (
198
- "hello to models/text-bison-001 with {'temperature': 2.0, "
199
- "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
200
- "'max_output_tokens': 1024, 'stop_sequences': None}"
201
- ),
202
- )
203
- genai.get_model = orig_get_model
204
- genai.generate_text = orig_generate_text
205
-
206
- def test_call_with_legacy_chat_model(self):
207
- orig_get_model = genai.get_model
208
- genai.get_model = mock_get_model
209
- orig_chat = genai.chat
210
- genai.chat = mock_chat
211
-
212
- lm = google_genai.Palm2_IT(api_key='test_key')
213
- self.maxDiff = None
214
- self.assertEqual(
215
- lm('hello', temperature=2.0, top_k=20).text,
216
- (
217
- "hello to models/chat-bison-001 with {'temperature': 2.0, "
218
- "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
219
- ),
220
- )
221
- genai.get_model = orig_get_model
222
- genai.chat = orig_chat
223
-
224
36
 
225
37
  if __name__ == '__main__':
226
38
  unittest.main()