langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,195 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for Gemini API."""
+
+ import base64
+ from typing import Any
+ import unittest
+ from unittest import mock
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import gemini
+ import pyglove as pg
+ import requests
+
+
+ example_image = (
+     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+ )
+
+
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+   del url, kwargs
+   c = pg.Dict(json['generationConfig'])
+   content = json['contents'][0]['parts'][0]['text']
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str({
+       'candidates': [
+           {
+               'content': {
+                   'role': 'model',
+                   'parts': [
+                       {
+                           'text': (
+                               f'This is a response to {content} with '
+                               f'temperature={c.temperature}, '
+                               f'top_p={c.topP}, '
+                               f'top_k={c.topK}, '
+                               f'max_tokens={c.maxOutputTokens}, '
+                               f'stop={"".join(c.stopSequences)}.'
+                           ),
+                       },
+                       {
+                           'text': 'This is the thought.',
+                           'thought': True,
+                       }
+                   ],
+               },
+           },
+       ],
+       'usageMetadata': {
+           'promptTokenCount': 3,
+           'candidatesTokenCount': 4,
+       }
+   }).encode()
+   return response
+
+
+ class GeminiTest(unittest.TestCase):
+   """Tests for the Gemini REST API client."""
+
+   def test_content_from_message_text_only(self):
+     text = 'This is a beautiful day'
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     chunks = model._content_from_message(lf.UserMessage(text))
+     self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+   def test_content_from_message_mm(self):
+     image = lf_modalities.Image.from_bytes(example_image)
+     message = lf.UserMessage(
+         'This is an <<[[image]]>>, what is it?', image=image
+     )
+
+     # Non-multimodal model.
+     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+       gemini.Gemini(
+           'gemini-1.0-pro', api_endpoint=''
+       )._content_from_message(message)
+
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     content = model._content_from_message(message)
+     self.assertEqual(
+         content,
+         {
+             'role': 'user',
+             'parts': [
+                 {'text': 'This is an'},
+                 {
+                     'inlineData': {
+                         'data': base64.b64encode(example_image).decode(),
+                         'mimeType': 'image/png',
+                     }
+                 },
+                 {'text': ', what is it?'},
+             ],
+         },
+     )
+
+   def test_generation_config(self):
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     json_schema = {
+         'type': 'object',
+         'properties': {
+             'name': {'type': 'string'},
+         },
+         'required': ['name'],
+         'title': 'Person',
+     }
+     actual = model._generation_config(
+         lf.UserMessage('hi', json_schema=json_schema),
+         lf.LMSamplingOptions(
+             temperature=2.0,
+             top_p=1.0,
+             top_k=20,
+             max_tokens=1024,
+             stop=['\n'],
+         ),
+     )
+     self.assertEqual(
+         actual,
+         dict(
+             candidateCount=1,
+             temperature=2.0,
+             topP=1.0,
+             topK=20,
+             maxOutputTokens=1024,
+             stopSequences=['\n'],
+             responseLogprobs=False,
+             logprobs=None,
+             seed=None,
+             responseMimeType='application/json',
+             responseSchema={
+                 'type': 'object',
+                 'properties': {
+                     'name': {'type': 'string'}
+                 },
+                 'required': ['name'],
+                 'title': 'Person',
+             }
+         ),
+     )
+     with self.assertRaisesRegex(
+         ValueError, '`json_schema` must be a dict, got'
+     ):
+       model._generation_config(
+           lf.UserMessage('hi', json_schema='not a dict'),
+           lf.LMSamplingOptions(),
+       )
+
+   def test_call_model(self):
+     with mock.patch('requests.Session.post') as mock_generate:
+       mock_generate.side_effect = mock_requests_post
+
+       lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+       r = lm(
+           'hello',
+           temperature=2.0,
+           top_p=1.0,
+           top_k=20,
+           max_tokens=1024,
+           stop='\n',
+       )
+       self.assertEqual(
+           r.text,
+           (
+               'This is a response to hello with temperature=2.0, '
+               'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+           ),
+       )
+       self.assertEqual(r.metadata.thought, 'This is the thought.')
+       self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+       self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
+ if __name__ == '__main__':
+   unittest.main()
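
The hunk above adds a self-contained test for the shared REST-based Gemini client, pinning down both the request payload (`generationConfig`, `contents`) and the parsed response (text, thought, usage). The same mocking pattern works outside the test suite; a minimal sketch, reusing the `mock_requests_post` helper defined above (the prompt and sampling values mirror `test_call_model`):

```python
from unittest import mock

from langfun.core.llms import gemini

# Patch the session-level HTTP call, as test_call_model does, so no API key
# or network access is needed. mock_requests_post is the helper defined in
# the new test file above.
with mock.patch('requests.Session.post') as mock_post:
  mock_post.side_effect = mock_requests_post
  lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
  r = lm('hello', temperature=2.0, top_p=1.0, top_k=20,
         max_tokens=1024, stop='\n')
  print(r.text)
```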
@@ -1,4 +1,4 @@
- # Copyright 2024 The Langfun Authors
+ # Copyright 2025 The Langfun Authors
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,57 +11,21 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Gemini models exposed through Google Generative AI APIs."""
+ """Language models from Google GenAI."""

- import abc
- import functools
  import os
- from typing import Annotated, Any, Literal
+ from typing import Annotated, Literal

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import vertexai
+ from langfun.core.llms import gemini
  import pyglove as pg


- try:
-   import google.generativeai as genai  # pylint: disable=g-import-not-at-top
-   BlobDict = genai.types.BlobDict
-   GenerativeModel = genai.GenerativeModel
-   Completion = getattr(genai.types, 'Completion', Any)
-   ChatResponse = getattr(genai.types, 'ChatResponse', Any)
-   GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
-   GenerationConfig = genai.GenerationConfig
- except ImportError:
-   genai = None
-   BlobDict = Any
-   GenerativeModel = Any
-   Completion = Any
-   ChatResponse = Any
-   GenerationConfig = Any
-   GenerateContentResponse = Any
-
-
  @lf.use_init_args(['model'])
- class GenAI(lf.LanguageModel):
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+ class GenAI(gemini.Gemini):
    """Language models provided by Google GenAI."""

-   model: Annotated[
-       Literal[
-           'gemini-2.0-flash-thinking-exp-1219',
-           'gemini-2.0-flash-exp',
-           'gemini-exp-1206',
-           'gemini-exp-1114',
-           'gemini-1.5-pro-latest',
-           'gemini-1.5-flash-latest',
-           'gemini-pro',
-           'gemini-pro-vision',
-           'text-bison-001',
-           'chat-bison-001',
-       ],
-       'Model name.',
-   ]
-
    api_key: Annotated[
        str | None,
        (
@@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel):
        ),
    ] = None

-   supported_modalities: Annotated[
-       list[str],
-       'A list of MIME types for supported modalities'
-   ] = []
+   api_version: Annotated[
+       Literal['v1beta', 'v1alpha'],
+       'The API version to use.'
+   ] = 'v1beta'

-   # Set the default max concurrency to 8 workers.
-   max_concurrency = 8
-
-   def _on_bound(self):
-     super()._on_bound()
-     if genai is None:
-       raise RuntimeError(
-           'Please install "langfun[llm-google-genai]" to use '
-           'Google Generative AI models.'
-       )
-     self.__dict__.pop('_api_initialized', None)
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'GenAI({self.model})'

-   @functools.cached_property
-   def _api_initialized(self):
-     assert genai is not None
+   @property
+   def api_endpoint(self) -> str:
      api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
      if not api_key:
        raise ValueError(
@@ -100,306 +56,76 @@ class GenAI(lf.LanguageModel):
            'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
            'for more details.'
        )
-     genai.configure(api_key=api_key)
-     return True
-
-   @classmethod
-   def dir(cls) -> list[str]:
-     """Lists generative models."""
-     assert genai is not None
-     return [
-         m.name.lstrip('models/')
-         for m in genai.list_models()
-         if (
-             'generateContent' in m.supported_generation_methods
-             or 'generateText' in m.supported_generation_methods
-             or 'generateMessage' in m.supported_generation_methods
-         )
-     ]
-
-   @property
-   def model_id(self) -> str:
-     """Returns a string to identify the model."""
-     return self.model
-
-   @property
-   def resource_id(self) -> str:
-     """Returns a string to identify the resource for rate control."""
-     return self.model_id
-
-   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     """Creates generation config from langfun sampling options."""
-     return GenerationConfig(
-         candidate_count=options.n,
-         temperature=options.temperature,
-         top_p=options.top_p,
-         top_k=options.top_k,
-         max_output_tokens=options.max_tokens,
-         stop_sequences=options.stop,
+     return (
+         f'https://generativelanguage.googleapis.com/{self.api_version}'
+         f'/models/{self.model}:generateContent?'
+         f'key={api_key}'
      )

-   def _content_from_message(
-       self, prompt: lf.Message
-   ) -> list[str | BlobDict]:
-     """Gets Evergreen formatted content from langfun message."""
-     formatted = lf.UserMessage(prompt.text)
-     formatted.source = prompt
-
-     chunks = []
-     for lf_chunk in formatted.chunk():
-       if isinstance(lf_chunk, str):
-         chunks.append(lf_chunk)
-       elif isinstance(lf_chunk, lf_modalities.Mime):
-         try:
-           modalities = lf_chunk.make_compatible(
-               self.supported_modalities + ['text/plain']
-           )
-           if isinstance(modalities, lf_modalities.Mime):
-             modalities = [modalities]
-           for modality in modalities:
-             if modality.is_text:
-               chunk = modality.to_text()
-             else:
-               chunk = BlobDict(
-                   data=modality.to_bytes(),
-                   mime_type=modality.mime_type
-               )
-             chunks.append(chunk)
-         except lf.ModalityError as e:
-           raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-       else:
-         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-     return chunks
-
-   def _response_to_result(
-       self, response: GenerateContentResponse | pg.Dict
-   ) -> lf.LMSamplingResult:
-     """Parses generative response into message."""
-     samples = []
-     for candidate in response.candidates:
-       chunks = []
-       for part in candidate.content.parts:
-         # TODO(daiyip): support multi-modal parts when they are available via
-         # Gemini API.
-         if hasattr(part, 'text'):
-           chunks.append(part.text)
-       samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
-     return lf.LMSamplingResult(samples)
-
-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     assert self._api_initialized, 'Vertex AI API is not initialized.'
-     return self._parallel_execute_with_currency_control(
-         self._sample_single,
-         prompts,
-     )

-   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-     """Samples a single prompt."""
-     model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
-     input_content = self._content_from_message(prompt)
-     response = model.generate_content(
-         input_content,
-         generation_config=self._generation_config(self.sampling_options),
-     )
-     return self._response_to_result(response)
-
-
- class _LegacyGenerativeModel(pg.Object):
-   """Base for legacy GenAI generative model."""
-
-   model: str
-
-   def generate_content(
-       self,
-       input_content: list[str | BlobDict],
-       generation_config: GenerationConfig,
-   ) -> pg.Dict:
-     """Generate content."""
-     segments = []
-     for s in input_content:
-       if not isinstance(s, str):
-         raise ValueError(f'Unsupported modality: {s!r}')
-       segments.append(s)
-     return self.generate(' '.join(segments), generation_config)
-
-   @abc.abstractmethod
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
-     """Generate response based on prompt."""
-
-
- class _LegacyCompletionModel(_LegacyGenerativeModel):
-   """Legacy GenAI completion model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     completion: Completion = genai.generate_text(
-         model=f'models/{self.model}',
-         prompt=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-         max_output_tokens=generation_config.max_output_tokens,
-         stop_sequences=generation_config.stop_sequences,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
-             for c in completion.candidates
-         ]
-     )
-
-
- class _LegacyChatModel(_LegacyGenerativeModel):
-   """Legacy GenAI chat model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     response: ChatResponse = genai.chat(
-         model=f'models/{self.model}',
-         messages=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
-             for c in response.candidates
-         ]
-     )
-
-
- class _ModelHub:
-   """Google Generative AI model hub."""
-
-   def __init__(self):
-     self._model_cache = {}
-
-   def get(
-       self, model_name: str
-   ) -> GenerativeModel | _LegacyGenerativeModel:
-     """Gets a generative model by model id."""
-     assert genai is not None
-     model = self._model_cache.get(model_name, None)
-     if model is None:
-       model_info = genai.get_model(f'models/{model_name}')
-       if 'generateContent' in model_info.supported_generation_methods:
-         model = genai.GenerativeModel(model_name)
-       elif 'generateText' in model_info.supported_generation_methods:
-         model = _LegacyCompletionModel(model_name)
-       elif 'generateMessage' in model_info.supported_generation_methods:
-         model = _LegacyChatModel(model_name)
-       else:
-         raise ValueError(f'Unsupported model: {model_name!r}')
-       self._model_cache[model_name] = model
-     return model
-
-
- _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
-
-
- #
- # Public Gemini models.
- #
- class GeminiFlash2_0ThinkingExp(GenAI):  # pylint: disable=invalid-name
-   """Gemini 2.0 Flash Thinking Experimental model."""
+ class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""

+   api_version = 'v1alpha'
    model = 'gemini-2.0-flash-thinking-exp-1219'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )
+   timeout = None


  class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
-   """Gemini Experimental model launched on 12/06/2024."""
+   """Gemini Flash 2.0 model launched on 12/11/2024."""

    model = 'gemini-2.0-flash-exp'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 12/06/2024."""

    model = 'gemini-exp-1206'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 11/14/2024."""

    model = 'gemini-exp-1114'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
    """Gemini Pro latest model."""

    model = 'gemini-1.5-pro-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


- class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
-   """Gemini Flash latest model."""
+ class GeminiPro1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro 1.5 model stable version 002."""
+
+   model = 'gemini-1.5-pro-002'

-   model = 'gemini-1.5-flash-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )

+ class GeminiPro1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro 1.5 model stable version 001."""

- class GeminiPro(GenAI):
-   """Gemini Pro model."""
+   model = 'gemini-1.5-pro-001'

-   model = 'gemini-pro'
+
+ class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash latest model."""
+
+   model = 'gemini-1.5-flash-latest'


- class GeminiProVision(GenAI):
-   """Gemini Pro vision model."""
+ class GeminiFlash1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 002."""

-   model = 'gemini-pro-vision'
-   supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
+   model = 'gemini-1.5-flash-002'


- class Palm2(GenAI):
-   """PaLM2 model."""
+ class GeminiFlash1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 001."""

-   model = 'text-bison-001'
+   model = 'gemini-1.5-flash-001'


- class Palm2_IT(GenAI):  # pylint: disable=invalid-name
-   """PaLM2 instruction-tuned model."""
+ class GeminiPro1(GenAI):  # pylint: disable=invalid-name
+   """Gemini 1.0 Pro model."""

-   model = 'chat-bison-001'
+   model = 'gemini-1.0-pro'
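
Net effect of the second file: `GenAI` no longer wraps the `google.generativeai` SDK (the try/except import shim, the legacy PaLM completion/chat paths, and `_ModelHub` are all removed) and instead subclasses the shared `gemini.Gemini` REST client, with `api_endpoint` computed from `api_version`, `model`, and the API key; the PaLM2 and `gemini-pro` aliases give way to pinned 1.5 stable versions and a `gemini-1.0-pro` class. A usage sketch, assuming the diffed module is `langfun.core.llms.google_genai` (file paths are not shown in this diff) and that `GOOGLE_API_KEY` is set in the environment:

```python
# Assumption: the module path below; the diff itself does not name files.
from langfun.core.llms import google_genai

# api_key falls back to the GOOGLE_API_KEY environment variable, per the
# api_endpoint property above.
lm = google_genai.GeminiFlash2_0Exp()
print(lm.model_id)
# GenAI(gemini-2.0-flash-exp)
print(lm.api_endpoint)
# https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent?key=...
```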