langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, each released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -13,14 +13,12 @@
13
13
  # limitations under the License.
14
14
  """Vertex AI generative models."""
15
15
 
16
- import base64
17
16
  import functools
18
17
  import os
19
18
  from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
- from langfun.core.llms import rest
21
+ from langfun.core.llms import gemini
24
22
  import pyglove as pg
25
23
 
26
24
  try:
@@ -38,114 +36,11 @@ except ImportError:
38
36
  Credentials = Any
39
37
 
40
38
 
41
- # https://cloud.google.com/vertex-ai/generative-ai/pricing
42
- # describes that the average number of characters per token is about 4.
43
- AVGERAGE_CHARS_PER_TOKEN = 4
44
-
45
-
46
- # Price in US dollars,
47
- # from https://cloud.google.com/vertex-ai/generative-ai/pricing
48
- # as of 2024-10-10.
49
- SUPPORTED_MODELS_AND_SETTINGS = {
50
- 'gemini-1.5-pro-001': pg.Dict(
51
- rpm=100,
52
- cost_per_1k_input_chars=0.0003125,
53
- cost_per_1k_output_chars=0.00125,
54
- ),
55
- 'gemini-1.5-pro-002': pg.Dict(
56
- rpm=100,
57
- cost_per_1k_input_chars=0.0003125,
58
- cost_per_1k_output_chars=0.00125,
59
- ),
60
- 'gemini-1.5-flash-002': pg.Dict(
61
- rpm=500,
62
- cost_per_1k_input_chars=0.00001875,
63
- cost_per_1k_output_chars=0.000075,
64
- ),
65
- 'gemini-1.5-flash-001': pg.Dict(
66
- rpm=500,
67
- cost_per_1k_input_chars=0.00001875,
68
- cost_per_1k_output_chars=0.000075,
69
- ),
70
- 'gemini-1.5-pro': pg.Dict(
71
- rpm=100,
72
- cost_per_1k_input_chars=0.0003125,
73
- cost_per_1k_output_chars=0.00125,
74
- ),
75
- 'gemini-1.5-flash': pg.Dict(
76
- rpm=500,
77
- cost_per_1k_input_chars=0.00001875,
78
- cost_per_1k_output_chars=0.000075,
79
- ),
80
- 'gemini-1.5-pro-preview-0514': pg.Dict(
81
- rpm=50,
82
- cost_per_1k_input_chars=0.0003125,
83
- cost_per_1k_output_chars=0.00125,
84
- ),
85
- 'gemini-1.5-pro-preview-0409': pg.Dict(
86
- rpm=50,
87
- cost_per_1k_input_chars=0.0003125,
88
- cost_per_1k_output_chars=0.00125,
89
- ),
90
- 'gemini-1.5-flash-preview-0514': pg.Dict(
91
- rpm=200,
92
- cost_per_1k_input_chars=0.00001875,
93
- cost_per_1k_output_chars=0.000075,
94
- ),
95
- 'gemini-1.0-pro': pg.Dict(
96
- rpm=300,
97
- cost_per_1k_input_chars=0.000125,
98
- cost_per_1k_output_chars=0.000375,
99
- ),
100
- 'gemini-1.0-pro-vision': pg.Dict(
101
- rpm=100,
102
- cost_per_1k_input_chars=0.000125,
103
- cost_per_1k_output_chars=0.000375,
104
- ),
105
- # TODO(sharatsharat): Update costs when published
106
- 'gemini-exp-1206': pg.Dict(
107
- rpm=20,
108
- cost_per_1k_input_chars=0.000,
109
- cost_per_1k_output_chars=0.000,
110
- ),
111
- # TODO(sharatsharat): Update costs when published
112
- 'gemini-2.0-flash-exp': pg.Dict(
113
- rpm=10,
114
- cost_per_1k_input_chars=0.000,
115
- cost_per_1k_output_chars=0.000,
116
- ),
117
- # TODO(yifenglu): Update costs when published
118
- 'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
119
- rpm=10,
120
- cost_per_1k_input_chars=0.000,
121
- cost_per_1k_output_chars=0.000,
122
- ),
123
- # TODO(chengrun): Set a more appropriate rpm for endpoint.
124
- 'vertexai-endpoint': pg.Dict(
125
- rpm=20,
126
- cost_per_1k_input_chars=0.0000125,
127
- cost_per_1k_output_chars=0.0000375,
128
- ),
129
- }
130
-
131
-
132
39
  @lf.use_init_args(['model'])
133
40
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
134
- class VertexAI(rest.REST):
41
+ class VertexAI(gemini.Gemini):
135
42
  """Language model served on VertexAI with REST API."""
136
43
 
137
- model: pg.typing.Annotated[
138
- pg.typing.Enum(
139
- pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
140
- ),
141
- (
142
- 'Vertex AI model name with REST API support. See '
143
- 'https://cloud.google.com/vertex-ai/generative-ai/docs/'
144
- 'model-reference/inference#supported-models'
145
- ' for details.'
146
- ),
147
- ]
148
-
149
44
  project: Annotated[
150
45
  str | None,
151
46
  (
@@ -170,11 +65,6 @@ class VertexAI(rest.REST):
170
65
  ),
171
66
  ] = None
172
67
 
173
- supported_modalities: Annotated[
174
- list[str],
175
- 'A list of MIME types for supported modalities'
176
- ] = []
177
-
178
68
  def _on_bound(self):
179
69
  super()._on_bound()
180
70
  if google_auth is None:
@@ -209,31 +99,9 @@ class VertexAI(rest.REST):
209
99
  self._credentials = credentials
210
100
 
211
101
  @property
212
- def max_concurrency(self) -> int:
213
- """Returns the maximum number of concurrent requests."""
214
- return self.rate_to_max_concurrency(
215
- requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
216
- tokens_per_min=0,
217
- )
218
-
219
- def estimate_cost(
220
- self,
221
- num_input_tokens: int,
222
- num_output_tokens: int
223
- ) -> float | None:
224
- """Estimate the cost based on usage."""
225
- cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
226
- 'cost_per_1k_input_chars', None
227
- )
228
- cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
229
- 'cost_per_1k_output_chars', None
230
- )
231
- if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
232
- return None
233
- return (
234
- cost_per_1k_input_chars * num_input_tokens
235
- + cost_per_1k_output_chars * num_output_tokens
236
- ) * AVGERAGE_CHARS_PER_TOKEN / 1000
102
+ def model_id(self) -> str:
103
+ """Returns a string to identify the model."""
104
+ return f'VertexAI({self.model})'
237
105
 
238
106
  @functools.cached_property
239
107
  def _session(self):
@@ -244,12 +112,6 @@ class VertexAI(rest.REST):
244
112
  s.headers.update(self.headers or {})
245
113
  return s
246
114
 
247
- @property
248
- def headers(self):
249
- return {
250
- 'Content-Type': 'application/json; charset=utf-8',
251
- }
252
-
253
115
  @property
254
116
  def api_endpoint(self) -> str:
255
117
  return (
@@ -258,263 +120,70 @@ class VertexAI(rest.REST):
258
120
  f'models/{self.model}:generateContent'
259
121
  )
260
122
 
261
- def request(
262
- self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
263
- ) -> dict[str, Any]:
264
- request = dict(
265
- generationConfig=self._generation_config(prompt, sampling_options)
266
- )
267
- request['contents'] = [self._content_from_message(prompt)]
268
- return request
269
-
270
- def _generation_config(
271
- self, prompt: lf.Message, options: lf.LMSamplingOptions
272
- ) -> dict[str, Any]:
273
- """Returns a dict as generation config for prompt and LMSamplingOptions."""
274
- config = dict(
275
- temperature=options.temperature,
276
- maxOutputTokens=options.max_tokens,
277
- candidateCount=options.n,
278
- topK=options.top_k,
279
- topP=options.top_p,
280
- stopSequences=options.stop,
281
- seed=options.random_seed,
282
- responseLogprobs=options.logprobs,
283
- logprobs=options.top_logprobs,
284
- )
285
123
 
286
- if json_schema := prompt.metadata.get('json_schema'):
287
- if not isinstance(json_schema, dict):
288
- raise ValueError(
289
- f'`json_schema` must be a dict, got {json_schema!r}.'
290
- )
291
- json_schema = pg.to_json(json_schema)
292
- config['responseSchema'] = json_schema
293
- config['responseMimeType'] = 'application/json'
294
- prompt.metadata.formatted_text = (
295
- prompt.text
296
- + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
297
- + pg.to_json_str(json_schema, json_indent=2)
298
- )
299
- return config
300
-
301
- def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
302
- """Gets generation content from langfun message."""
303
- parts = []
304
- for lf_chunk in prompt.chunk():
305
- if isinstance(lf_chunk, str):
306
- parts.append({'text': lf_chunk})
307
- elif isinstance(lf_chunk, lf_modalities.Mime):
308
- try:
309
- modalities = lf_chunk.make_compatible(
310
- self.supported_modalities + ['text/plain']
311
- )
312
- if isinstance(modalities, lf_modalities.Mime):
313
- modalities = [modalities]
314
- for modality in modalities:
315
- if modality.is_text:
316
- parts.append({'text': modality.to_text()})
317
- else:
318
- parts.append({
319
- 'inlineData': {
320
- 'data': base64.b64encode(modality.to_bytes()).decode(),
321
- 'mimeType': modality.mime_type,
322
- }
323
- })
324
- except lf.ModalityError as e:
325
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
326
- else:
327
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
328
- return dict(role='user', parts=parts)
329
-
330
- def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
331
- messages = [
332
- self._message_from_content_parts(candidate['content']['parts'])
333
- for candidate in json['candidates']
334
- ]
335
- usage = json['usageMetadata']
336
- input_tokens = usage['promptTokenCount']
337
- output_tokens = usage['candidatesTokenCount']
338
- return lf.LMSamplingResult(
339
- [lf.LMSample(message) for message in messages],
340
- usage=lf.LMSamplingUsage(
341
- prompt_tokens=input_tokens,
342
- completion_tokens=output_tokens,
343
- total_tokens=input_tokens + output_tokens,
344
- estimated_cost=self.estimate_cost(
345
- num_input_tokens=input_tokens,
346
- num_output_tokens=output_tokens,
347
- ),
348
- ),
349
- )
124
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
125
+ """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
350
126
 
351
- def _message_from_content_parts(
352
- self, parts: list[dict[str, Any]]
353
- ) -> lf.Message:
354
- """Converts Vertex AI's content parts protocol to message."""
355
- chunks = []
356
- for part in parts:
357
- if text_part := part.get('text'):
358
- chunks.append(text_part)
359
- else:
360
- raise ValueError(f'Unsupported part: {part}')
361
- return lf.AIMessage.from_chunks(chunks)
362
-
363
-
364
- IMAGE_TYPES = [
365
- 'image/png',
366
- 'image/jpeg',
367
- 'image/webp',
368
- 'image/heic',
369
- 'image/heif',
370
- ]
371
-
372
- AUDIO_TYPES = [
373
- 'audio/aac',
374
- 'audio/flac',
375
- 'audio/mp3',
376
- 'audio/m4a',
377
- 'audio/mpeg',
378
- 'audio/mpga',
379
- 'audio/mp4',
380
- 'audio/opus',
381
- 'audio/pcm',
382
- 'audio/wav',
383
- 'audio/webm',
384
- ]
385
-
386
- VIDEO_TYPES = [
387
- 'video/mov',
388
- 'video/mpeg',
389
- 'video/mpegps',
390
- 'video/mpg',
391
- 'video/mp4',
392
- 'video/webm',
393
- 'video/wmv',
394
- 'video/x-flv',
395
- 'video/3gpp',
396
- 'video/quicktime',
397
- ]
398
-
399
- DOCUMENT_TYPES = [
400
- 'application/pdf',
401
- 'text/plain',
402
- 'text/csv',
403
- 'text/html',
404
- 'text/xml',
405
- 'text/x-script.python',
406
- 'application/json',
407
- ]
408
-
409
-
410
- class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
411
- """Vertex AI Gemini 2.0 model."""
412
-
413
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
414
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
415
- )
416
-
417
-
418
- class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
127
+ api_version = 'v1alpha'
128
+ model = 'gemini-2.0-flash-thinking-exp-1219'
129
+ timeout = None
130
+
131
+
132
+ class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
419
133
  """Vertex AI Gemini 2.0 Flash model."""
420
134
 
421
135
  model = 'gemini-2.0-flash-exp'
422
136
 
423
137
 
424
- class VertexAIGeminiFlash2_0ThinkingExp(VertexAIGemini2_0): # pylint: disable=invalid-name
425
- """Vertex AI Gemini 2.0 Flash model."""
138
+ class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
139
+ """Vertex AI Gemini Experimental model launched on 12/06/2024."""
426
140
 
427
- model = 'gemini-2.0-flash-thinking-exp-1219'
141
+ model = 'gemini-exp-1206'
428
142
 
429
143
 
430
- class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name
431
- """Vertex AI Gemini 1.5 model."""
144
+ class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
145
+ """Vertex AI Gemini Experimental model launched on 11/14/2024."""
432
146
 
433
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
434
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
435
- )
147
+ model = 'gemini-exp-1114'
436
148
 
437
149
 
438
- class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
150
+ class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
439
151
  """Vertex AI Gemini 1.5 Pro model."""
440
152
 
441
- model = 'gemini-1.5-pro'
153
+ model = 'gemini-1.5-pro-latest'
442
154
 
443
155
 
444
- class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
156
+ class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
445
157
  """Vertex AI Gemini 1.5 Pro model."""
446
158
 
447
159
  model = 'gemini-1.5-pro-002'
448
160
 
449
161
 
450
- class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
162
+ class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
451
163
  """Vertex AI Gemini 1.5 Pro model."""
452
164
 
453
165
  model = 'gemini-1.5-pro-001'
454
166
 
455
167
 
456
- class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
457
- """Vertex AI Gemini 1.5 Pro preview model."""
458
-
459
- model = 'gemini-1.5-pro-preview-0514'
460
-
461
-
462
- class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name
463
- """Vertex AI Gemini 1.5 Pro preview model."""
464
-
465
- model = 'gemini-1.5-pro-preview-0409'
466
-
467
-
468
- class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
168
+ class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
469
169
  """Vertex AI Gemini 1.5 Flash model."""
470
170
 
471
171
  model = 'gemini-1.5-flash'
472
172
 
473
173
 
474
- class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
174
+ class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
475
175
  """Vertex AI Gemini 1.5 Flash model."""
476
176
 
477
177
  model = 'gemini-1.5-flash-002'
478
178
 
479
179
 
480
- class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
180
+ class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
481
181
  """Vertex AI Gemini 1.5 Flash model."""
482
182
 
483
183
  model = 'gemini-1.5-flash-001'
484
184
 
485
185
 
486
- class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
487
- """Vertex AI Gemini 1.5 Flash preview model."""
488
-
489
- model = 'gemini-1.5-flash-preview-0514'
490
-
491
-
492
186
  class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
493
187
  """Vertex AI Gemini 1.0 Pro model."""
494
188
 
495
189
  model = 'gemini-1.0-pro'
496
-
497
-
498
- class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
499
- """Vertex AI Gemini 1.0 Pro Vision model."""
500
-
501
- model = 'gemini-1.0-pro-vision'
502
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
503
- IMAGE_TYPES + VIDEO_TYPES
504
- )
505
-
506
-
507
- class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name
508
- """Vertex AI Endpoint model."""
509
-
510
- model = 'vertexai-endpoint'
511
-
512
- endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
513
-
514
- @property
515
- def api_endpoint(self) -> str:
516
- return (
517
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
518
- f'{self.project}/locations/{self.location}/'
519
- f'endpoints/{self.endpoint}:generateContent'
520
- )
@@ -11,105 +11,18 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for Gemini models."""
14
+ """Tests for VertexAI models."""
15
15
 
16
- import base64
17
16
  import os
18
- from typing import Any
19
17
  import unittest
20
18
  from unittest import mock
21
19
 
22
- import langfun.core as lf
23
- from langfun.core import modalities as lf_modalities
24
20
  from langfun.core.llms import vertexai
25
- import pyglove as pg
26
- import requests
27
-
28
-
29
- example_image = (
30
- b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
31
- b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
32
- b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
33
- b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
34
- b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
35
- b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
36
- b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
37
- b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
38
- )
39
-
40
-
41
- def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
42
- del url, kwargs
43
- c = pg.Dict(json['generationConfig'])
44
- content = json['contents'][0]['parts'][0]['text']
45
- response = requests.Response()
46
- response.status_code = 200
47
- response._content = pg.to_json_str({
48
- 'candidates': [
49
- {
50
- 'content': {
51
- 'role': 'model',
52
- 'parts': [
53
- {
54
- 'text': (
55
- f'This is a response to {content} with '
56
- f'temperature={c.temperature}, '
57
- f'top_p={c.topP}, '
58
- f'top_k={c.topK}, '
59
- f'max_tokens={c.maxOutputTokens}, '
60
- f'stop={"".join(c.stopSequences)}.'
61
- )
62
- },
63
- ],
64
- },
65
- },
66
- ],
67
- 'usageMetadata': {
68
- 'promptTokenCount': 3,
69
- 'candidatesTokenCount': 4,
70
- }
71
- }).encode()
72
- return response
73
21
 
74
22
 
75
23
  class VertexAITest(unittest.TestCase):
76
24
  """Tests for Vertex model with REST API."""
77
25
 
78
- def test_content_from_message_text_only(self):
79
- text = 'This is a beautiful day'
80
- model = vertexai.VertexAIGeminiPro1_5_002()
81
- chunks = model._content_from_message(lf.UserMessage(text))
82
- self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
83
-
84
- def test_content_from_message_mm(self):
85
- image = lf_modalities.Image.from_bytes(example_image)
86
- message = lf.UserMessage(
87
- 'This is an <<[[image]]>>, what is it?', image=image
88
- )
89
-
90
- # Non-multimodal model.
91
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
92
- vertexai.VertexAIGeminiPro1()._content_from_message(message)
93
-
94
- model = vertexai.VertexAIGeminiPro1Vision()
95
- content = model._content_from_message(message)
96
- self.assertEqual(
97
- content,
98
- {
99
- 'role': 'user',
100
- 'parts': [
101
- {'text': 'This is an'},
102
- {
103
- 'inlineData': {
104
- 'data': base64.b64encode(example_image).decode(),
105
- 'mimeType': 'image/png',
106
- }
107
- },
108
- {'text': ', what is it?'},
109
- ],
110
- },
111
- )
112
-
113
26
  @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
114
27
  def test_project_and_location_check(self):
115
28
  with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
@@ -126,87 +39,14 @@ class VertexAITest(unittest.TestCase):
126
39
 
127
40
  os.environ['VERTEXAI_PROJECT'] = 'abc'
128
41
  os.environ['VERTEXAI_LOCATION'] = 'us-central1'
129
- self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
42
+ model = vertexai.VertexAIGeminiPro1()
43
+ self.assertTrue(model.model_id.startswith('VertexAI('))
44
+ self.assertIsNotNone(model.api_endpoint)
45
+ self.assertTrue(model._api_initialized)
46
+ self.assertIsNotNone(model._session)
130
47
  del os.environ['VERTEXAI_PROJECT']
131
48
  del os.environ['VERTEXAI_LOCATION']
132
49
 
133
- def test_generation_config(self):
134
- model = vertexai.VertexAIGeminiPro1()
135
- json_schema = {
136
- 'type': 'object',
137
- 'properties': {
138
- 'name': {'type': 'string'},
139
- },
140
- 'required': ['name'],
141
- 'title': 'Person',
142
- }
143
- actual = model._generation_config(
144
- lf.UserMessage('hi', json_schema=json_schema),
145
- lf.LMSamplingOptions(
146
- temperature=2.0,
147
- top_p=1.0,
148
- top_k=20,
149
- max_tokens=1024,
150
- stop=['\n'],
151
- ),
152
- )
153
- self.assertEqual(
154
- actual,
155
- dict(
156
- candidateCount=1,
157
- temperature=2.0,
158
- topP=1.0,
159
- topK=20,
160
- maxOutputTokens=1024,
161
- stopSequences=['\n'],
162
- responseLogprobs=False,
163
- logprobs=None,
164
- seed=None,
165
- responseMimeType='application/json',
166
- responseSchema={
167
- 'type': 'object',
168
- 'properties': {
169
- 'name': {'type': 'string'}
170
- },
171
- 'required': ['name'],
172
- 'title': 'Person',
173
- }
174
- ),
175
- )
176
- with self.assertRaisesRegex(
177
- ValueError, '`json_schema` must be a dict, got'
178
- ):
179
- model._generation_config(
180
- lf.UserMessage('hi', json_schema='not a dict'),
181
- lf.LMSamplingOptions(),
182
- )
183
-
184
- @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
185
- def test_call_model(self):
186
- with mock.patch('requests.Session.post') as mock_generate:
187
- mock_generate.side_effect = mock_requests_post
188
-
189
- lm = vertexai.VertexAIGeminiPro1_5_002(
190
- project='abc', location='us-central1'
191
- )
192
- r = lm(
193
- 'hello',
194
- temperature=2.0,
195
- top_p=1.0,
196
- top_k=20,
197
- max_tokens=1024,
198
- stop='\n',
199
- )
200
- self.assertEqual(
201
- r.text,
202
- (
203
- 'This is a response to hello with temperature=2.0, '
204
- 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
205
- ),
206
- )
207
- self.assertEqual(r.metadata.usage.prompt_tokens, 3)
208
- self.assertEqual(r.metadata.usage.completion_tokens, 4)
209
-
210
50
 
211
51
  if __name__ == '__main__':
212
52
  unittest.main()