langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. langfun/core/__init__.py +0 -4
  2. langfun/core/eval/matching.py +2 -2
  3. langfun/core/eval/scoring.py +6 -2
  4. langfun/core/eval/v2/checkpointing.py +106 -72
  5. langfun/core/eval/v2/checkpointing_test.py +108 -3
  6. langfun/core/eval/v2/eval_test_helper.py +56 -0
  7. langfun/core/eval/v2/evaluation.py +25 -4
  8. langfun/core/eval/v2/evaluation_test.py +11 -0
  9. langfun/core/eval/v2/example.py +11 -1
  10. langfun/core/eval/v2/example_test.py +16 -2
  11. langfun/core/eval/v2/experiment.py +83 -19
  12. langfun/core/eval/v2/experiment_test.py +121 -3
  13. langfun/core/eval/v2/reporting.py +67 -20
  14. langfun/core/eval/v2/reporting_test.py +119 -2
  15. langfun/core/eval/v2/runners.py +7 -4
  16. langfun/core/llms/__init__.py +23 -24
  17. langfun/core/llms/anthropic.py +12 -0
  18. langfun/core/llms/cache/in_memory.py +6 -0
  19. langfun/core/llms/cache/in_memory_test.py +5 -0
  20. langfun/core/llms/gemini.py +507 -0
  21. langfun/core/llms/gemini_test.py +195 -0
  22. langfun/core/llms/google_genai.py +46 -310
  23. langfun/core/llms/google_genai_test.py +9 -204
  24. langfun/core/llms/openai.py +23 -37
  25. langfun/core/llms/vertexai.py +28 -348
  26. langfun/core/llms/vertexai_test.py +6 -166
  27. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
  28. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
  29. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
  30. langfun/core/repr_utils.py +0 -204
  31. langfun/core/repr_utils_test.py +0 -90
  32. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
  33. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
@@ -13,14 +13,12 @@
13
13
  # limitations under the License.
14
14
  """Vertex AI generative models."""
15
15
 
16
- import base64
17
16
  import functools
18
17
  import os
19
18
  from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
- from langfun.core.llms import rest
21
+ from langfun.core.llms import gemini
24
22
  import pyglove as pg
25
23
 
26
24
  try:
@@ -38,108 +36,11 @@ except ImportError:
38
36
  Credentials = Any
39
37
 
40
38
 
41
- # https://cloud.google.com/vertex-ai/generative-ai/pricing
42
- # describes that the average number of characters per token is about 4.
43
- AVGERAGE_CHARS_PER_TOKEN = 4
44
-
45
-
46
- # Price in US dollars,
47
- # from https://cloud.google.com/vertex-ai/generative-ai/pricing
48
- # as of 2024-10-10.
49
- SUPPORTED_MODELS_AND_SETTINGS = {
50
- 'gemini-1.5-pro-001': pg.Dict(
51
- rpm=100,
52
- cost_per_1k_input_chars=0.0003125,
53
- cost_per_1k_output_chars=0.00125,
54
- ),
55
- 'gemini-1.5-pro-002': pg.Dict(
56
- rpm=100,
57
- cost_per_1k_input_chars=0.0003125,
58
- cost_per_1k_output_chars=0.00125,
59
- ),
60
- 'gemini-1.5-flash-002': pg.Dict(
61
- rpm=500,
62
- cost_per_1k_input_chars=0.00001875,
63
- cost_per_1k_output_chars=0.000075,
64
- ),
65
- 'gemini-1.5-flash-001': pg.Dict(
66
- rpm=500,
67
- cost_per_1k_input_chars=0.00001875,
68
- cost_per_1k_output_chars=0.000075,
69
- ),
70
- 'gemini-1.5-pro': pg.Dict(
71
- rpm=100,
72
- cost_per_1k_input_chars=0.0003125,
73
- cost_per_1k_output_chars=0.00125,
74
- ),
75
- 'gemini-1.5-flash': pg.Dict(
76
- rpm=500,
77
- cost_per_1k_input_chars=0.00001875,
78
- cost_per_1k_output_chars=0.000075,
79
- ),
80
- 'gemini-1.5-pro-preview-0514': pg.Dict(
81
- rpm=50,
82
- cost_per_1k_input_chars=0.0003125,
83
- cost_per_1k_output_chars=0.00125,
84
- ),
85
- 'gemini-1.5-pro-preview-0409': pg.Dict(
86
- rpm=50,
87
- cost_per_1k_input_chars=0.0003125,
88
- cost_per_1k_output_chars=0.00125,
89
- ),
90
- 'gemini-1.5-flash-preview-0514': pg.Dict(
91
- rpm=200,
92
- cost_per_1k_input_chars=0.00001875,
93
- cost_per_1k_output_chars=0.000075,
94
- ),
95
- 'gemini-1.0-pro': pg.Dict(
96
- rpm=300,
97
- cost_per_1k_input_chars=0.000125,
98
- cost_per_1k_output_chars=0.000375,
99
- ),
100
- 'gemini-1.0-pro-vision': pg.Dict(
101
- rpm=100,
102
- cost_per_1k_input_chars=0.000125,
103
- cost_per_1k_output_chars=0.000375,
104
- ),
105
- # TODO(sharatsharat): Update costs when published
106
- 'gemini-exp-1206': pg.Dict(
107
- rpm=20,
108
- cost_per_1k_input_chars=0.000,
109
- cost_per_1k_output_chars=0.000,
110
- ),
111
- # TODO(sharatsharat): Update costs when published
112
- 'gemini-2.0-flash-exp': pg.Dict(
113
- rpm=20,
114
- cost_per_1k_input_chars=0.000,
115
- cost_per_1k_output_chars=0.000,
116
- ),
117
- # TODO(chengrun): Set a more appropriate rpm for endpoint.
118
- 'vertexai-endpoint': pg.Dict(
119
- rpm=20,
120
- cost_per_1k_input_chars=0.0000125,
121
- cost_per_1k_output_chars=0.0000375,
122
- ),
123
- }
124
-
125
-
126
39
  @lf.use_init_args(['model'])
127
40
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
128
- class VertexAI(rest.REST):
41
+ class VertexAI(gemini.Gemini):
129
42
  """Language model served on VertexAI with REST API."""
130
43
 
131
- model: pg.typing.Annotated[
132
- pg.typing.Enum(
133
- pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
134
- ),
135
- (
136
- 'Vertex AI model name with REST API support. See '
137
- 'https://cloud.google.com/vertex-ai/generative-ai/docs/'
138
- 'model-reference/inference#supported-models'
139
- ' for details.'
140
- ),
141
- ]
142
-
143
44
  project: Annotated[
144
45
  str | None,
145
46
  (
@@ -164,11 +65,6 @@ class VertexAI(rest.REST):
164
65
  ),
165
66
  ] = None
166
67
 
167
- supported_modalities: Annotated[
168
- list[str],
169
- 'A list of MIME types for supported modalities'
170
- ] = []
171
-
172
68
  def _on_bound(self):
173
69
  super()._on_bound()
174
70
  if google_auth is None:
@@ -203,31 +99,9 @@ class VertexAI(rest.REST):
203
99
  self._credentials = credentials
204
100
 
205
101
  @property
206
- def max_concurrency(self) -> int:
207
- """Returns the maximum number of concurrent requests."""
208
- return self.rate_to_max_concurrency(
209
- requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
210
- tokens_per_min=0,
211
- )
212
-
213
- def estimate_cost(
214
- self,
215
- num_input_tokens: int,
216
- num_output_tokens: int
217
- ) -> float | None:
218
- """Estimate the cost based on usage."""
219
- cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
220
- 'cost_per_1k_input_chars', None
221
- )
222
- cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
223
- 'cost_per_1k_output_chars', None
224
- )
225
- if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
226
- return None
227
- return (
228
- cost_per_1k_input_chars * num_input_tokens
229
- + cost_per_1k_output_chars * num_output_tokens
230
- ) * AVGERAGE_CHARS_PER_TOKEN / 1000
102
+ def model_id(self) -> str:
103
+ """Returns a string to identify the model."""
104
+ return f'VertexAI({self.model})'
231
105
 
232
106
  @functools.cached_property
233
107
  def _session(self):
@@ -238,12 +112,6 @@ class VertexAI(rest.REST):
238
112
  s.headers.update(self.headers or {})
239
113
  return s
240
114
 
241
- @property
242
- def headers(self):
243
- return {
244
- 'Content-Type': 'application/json; charset=utf-8',
245
- }
246
-
247
115
  @property
248
116
  def api_endpoint(self) -> str:
249
117
  return (
@@ -252,257 +120,69 @@ class VertexAI(rest.REST):
252
120
  f'models/{self.model}:generateContent'
253
121
  )
254
122
 
255
- def request(
256
- self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
257
- ) -> dict[str, Any]:
258
- request = dict(
259
- generationConfig=self._generation_config(prompt, sampling_options)
260
- )
261
- request['contents'] = [self._content_from_message(prompt)]
262
- return request
263
-
264
- def _generation_config(
265
- self, prompt: lf.Message, options: lf.LMSamplingOptions
266
- ) -> dict[str, Any]:
267
- """Returns a dict as generation config for prompt and LMSamplingOptions."""
268
- config = dict(
269
- temperature=options.temperature,
270
- maxOutputTokens=options.max_tokens,
271
- candidateCount=options.n,
272
- topK=options.top_k,
273
- topP=options.top_p,
274
- stopSequences=options.stop,
275
- seed=options.random_seed,
276
- responseLogprobs=options.logprobs,
277
- logprobs=options.top_logprobs,
278
- )
279
123
 
280
- if json_schema := prompt.metadata.get('json_schema'):
281
- if not isinstance(json_schema, dict):
282
- raise ValueError(
283
- f'`json_schema` must be a dict, got {json_schema!r}.'
284
- )
285
- json_schema = pg.to_json(json_schema)
286
- config['responseSchema'] = json_schema
287
- config['responseMimeType'] = 'application/json'
288
- prompt.metadata.formatted_text = (
289
- prompt.text
290
- + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
291
- + pg.to_json_str(json_schema, json_indent=2)
292
- )
293
- return config
294
-
295
- def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
296
- """Gets generation content from langfun message."""
297
- parts = []
298
- for lf_chunk in prompt.chunk():
299
- if isinstance(lf_chunk, str):
300
- parts.append({'text': lf_chunk})
301
- elif isinstance(lf_chunk, lf_modalities.Mime):
302
- try:
303
- modalities = lf_chunk.make_compatible(
304
- self.supported_modalities + ['text/plain']
305
- )
306
- if isinstance(modalities, lf_modalities.Mime):
307
- modalities = [modalities]
308
- for modality in modalities:
309
- if modality.is_text:
310
- parts.append({'text': modality.to_text()})
311
- else:
312
- parts.append({
313
- 'inlineData': {
314
- 'data': base64.b64encode(modality.to_bytes()).decode(),
315
- 'mimeType': modality.mime_type,
316
- }
317
- })
318
- except lf.ModalityError as e:
319
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
320
- else:
321
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
322
- return dict(role='user', parts=parts)
323
-
324
- def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
325
- messages = [
326
- self._message_from_content_parts(candidate['content']['parts'])
327
- for candidate in json['candidates']
328
- ]
329
- usage = json['usageMetadata']
330
- input_tokens = usage['promptTokenCount']
331
- output_tokens = usage['candidatesTokenCount']
332
- return lf.LMSamplingResult(
333
- [lf.LMSample(message) for message in messages],
334
- usage=lf.LMSamplingUsage(
335
- prompt_tokens=input_tokens,
336
- completion_tokens=output_tokens,
337
- total_tokens=input_tokens + output_tokens,
338
- estimated_cost=self.estimate_cost(
339
- num_input_tokens=input_tokens,
340
- num_output_tokens=output_tokens,
341
- ),
342
- ),
343
- )
124
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
125
+ """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
344
126
 
345
- def _message_from_content_parts(
346
- self, parts: list[dict[str, Any]]
347
- ) -> lf.Message:
348
- """Converts Vertex AI's content parts protocol to message."""
349
- chunks = []
350
- for part in parts:
351
- if text_part := part.get('text'):
352
- chunks.append(text_part)
353
- else:
354
- raise ValueError(f'Unsupported part: {part}')
355
- return lf.AIMessage.from_chunks(chunks)
356
-
357
-
358
- IMAGE_TYPES = [
359
- 'image/png',
360
- 'image/jpeg',
361
- 'image/webp',
362
- 'image/heic',
363
- 'image/heif',
364
- ]
365
-
366
- AUDIO_TYPES = [
367
- 'audio/aac',
368
- 'audio/flac',
369
- 'audio/mp3',
370
- 'audio/m4a',
371
- 'audio/mpeg',
372
- 'audio/mpga',
373
- 'audio/mp4',
374
- 'audio/opus',
375
- 'audio/pcm',
376
- 'audio/wav',
377
- 'audio/webm',
378
- ]
379
-
380
- VIDEO_TYPES = [
381
- 'video/mov',
382
- 'video/mpeg',
383
- 'video/mpegps',
384
- 'video/mpg',
385
- 'video/mp4',
386
- 'video/webm',
387
- 'video/wmv',
388
- 'video/x-flv',
389
- 'video/3gpp',
390
- 'video/quicktime',
391
- ]
392
-
393
- DOCUMENT_TYPES = [
394
- 'application/pdf',
395
- 'text/plain',
396
- 'text/csv',
397
- 'text/html',
398
- 'text/xml',
399
- 'text/x-script.python',
400
- 'application/json',
401
- ]
402
-
403
-
404
- class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
405
- """Vertex AI Gemini 2.0 model."""
406
-
407
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
408
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
409
- )
410
-
411
-
412
- class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
127
+ api_version = 'v1alpha'
128
+ model = 'gemini-2.0-flash-thinking-exp-1219'
129
+
130
+
131
+ class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
413
132
  """Vertex AI Gemini 2.0 Flash model."""
414
133
 
415
134
  model = 'gemini-2.0-flash-exp'
416
135
 
417
136
 
418
- class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name
419
- """Vertex AI Gemini 1.5 model."""
137
+ class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
138
+ """Vertex AI Gemini Experimental model launched on 12/06/2024."""
420
139
 
421
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
422
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
423
- )
140
+ model = 'gemini-exp-1206'
424
141
 
425
142
 
426
- class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
427
- """Vertex AI Gemini 1.5 Pro model."""
143
+ class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
144
+ """Vertex AI Gemini Experimental model launched on 11/14/2024."""
428
145
 
429
- model = 'gemini-1.5-pro'
146
+ model = 'gemini-exp-1114'
430
147
 
431
148
 
432
- class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
149
+ class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
433
150
  """Vertex AI Gemini 1.5 Pro model."""
434
151
 
435
- model = 'gemini-1.5-pro-002'
152
+ model = 'gemini-1.5-pro-latest'
436
153
 
437
154
 
438
- class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
155
+ class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
439
156
  """Vertex AI Gemini 1.5 Pro model."""
440
157
 
441
- model = 'gemini-1.5-pro-001'
442
-
443
-
444
- class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
445
- """Vertex AI Gemini 1.5 Pro preview model."""
446
-
447
- model = 'gemini-1.5-pro-preview-0514'
158
+ model = 'gemini-1.5-pro-002'
448
159
 
449
160
 
450
- class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name
451
- """Vertex AI Gemini 1.5 Pro preview model."""
161
+ class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
162
+ """Vertex AI Gemini 1.5 Pro model."""
452
163
 
453
- model = 'gemini-1.5-pro-preview-0409'
164
+ model = 'gemini-1.5-pro-001'
454
165
 
455
166
 
456
- class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
167
+ class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
457
168
  """Vertex AI Gemini 1.5 Flash model."""
458
169
 
459
170
  model = 'gemini-1.5-flash'
460
171
 
461
172
 
462
- class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
173
+ class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
463
174
  """Vertex AI Gemini 1.5 Flash model."""
464
175
 
465
176
  model = 'gemini-1.5-flash-002'
466
177
 
467
178
 
468
- class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
179
+ class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
469
180
  """Vertex AI Gemini 1.5 Flash model."""
470
181
 
471
182
  model = 'gemini-1.5-flash-001'
472
183
 
473
184
 
474
- class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
475
- """Vertex AI Gemini 1.5 Flash preview model."""
476
-
477
- model = 'gemini-1.5-flash-preview-0514'
478
-
479
-
480
185
  class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
481
186
  """Vertex AI Gemini 1.0 Pro model."""
482
187
 
483
188
  model = 'gemini-1.0-pro'
484
-
485
-
486
- class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
487
- """Vertex AI Gemini 1.0 Pro Vision model."""
488
-
489
- model = 'gemini-1.0-pro-vision'
490
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
491
- IMAGE_TYPES + VIDEO_TYPES
492
- )
493
-
494
-
495
- class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name
496
- """Vertex AI Endpoint model."""
497
-
498
- model = 'vertexai-endpoint'
499
-
500
- endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
501
-
502
- @property
503
- def api_endpoint(self) -> str:
504
- return (
505
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
506
- f'{self.project}/locations/{self.location}/'
507
- f'endpoints/{self.endpoint}:generateContent'
508
- )
@@ -11,105 +11,18 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for Gemini models."""
14
+ """Tests for VertexAI models."""
15
15
 
16
- import base64
17
16
  import os
18
- from typing import Any
19
17
  import unittest
20
18
  from unittest import mock
21
19
 
22
- import langfun.core as lf
23
- from langfun.core import modalities as lf_modalities
24
20
  from langfun.core.llms import vertexai
25
- import pyglove as pg
26
- import requests
27
-
28
-
29
- example_image = (
30
- b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
31
- b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
32
- b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
33
- b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
34
- b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
35
- b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
36
- b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
37
- b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
38
- )
39
-
40
-
41
- def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
42
- del url, kwargs
43
- c = pg.Dict(json['generationConfig'])
44
- content = json['contents'][0]['parts'][0]['text']
45
- response = requests.Response()
46
- response.status_code = 200
47
- response._content = pg.to_json_str({
48
- 'candidates': [
49
- {
50
- 'content': {
51
- 'role': 'model',
52
- 'parts': [
53
- {
54
- 'text': (
55
- f'This is a response to {content} with '
56
- f'temperature={c.temperature}, '
57
- f'top_p={c.topP}, '
58
- f'top_k={c.topK}, '
59
- f'max_tokens={c.maxOutputTokens}, '
60
- f'stop={"".join(c.stopSequences)}.'
61
- )
62
- },
63
- ],
64
- },
65
- },
66
- ],
67
- 'usageMetadata': {
68
- 'promptTokenCount': 3,
69
- 'candidatesTokenCount': 4,
70
- }
71
- }).encode()
72
- return response
73
21
 
74
22
 
75
23
  class VertexAITest(unittest.TestCase):
76
24
  """Tests for Vertex model with REST API."""
77
25
 
78
- def test_content_from_message_text_only(self):
79
- text = 'This is a beautiful day'
80
- model = vertexai.VertexAIGeminiPro1_5_002()
81
- chunks = model._content_from_message(lf.UserMessage(text))
82
- self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
83
-
84
- def test_content_from_message_mm(self):
85
- image = lf_modalities.Image.from_bytes(example_image)
86
- message = lf.UserMessage(
87
- 'This is an <<[[image]]>>, what is it?', image=image
88
- )
89
-
90
- # Non-multimodal model.
91
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
92
- vertexai.VertexAIGeminiPro1()._content_from_message(message)
93
-
94
- model = vertexai.VertexAIGeminiPro1Vision()
95
- content = model._content_from_message(message)
96
- self.assertEqual(
97
- content,
98
- {
99
- 'role': 'user',
100
- 'parts': [
101
- {'text': 'This is an'},
102
- {
103
- 'inlineData': {
104
- 'data': base64.b64encode(example_image).decode(),
105
- 'mimeType': 'image/png',
106
- }
107
- },
108
- {'text': ', what is it?'},
109
- ],
110
- },
111
- )
112
-
113
26
  @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
114
27
  def test_project_and_location_check(self):
115
28
  with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
@@ -126,87 +39,14 @@ class VertexAITest(unittest.TestCase):
126
39
 
127
40
  os.environ['VERTEXAI_PROJECT'] = 'abc'
128
41
  os.environ['VERTEXAI_LOCATION'] = 'us-central1'
129
- self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
42
+ model = vertexai.VertexAIGeminiPro1()
43
+ self.assertTrue(model.model_id.startswith('VertexAI('))
44
+ self.assertIsNotNone(model.api_endpoint)
45
+ self.assertTrue(model._api_initialized)
46
+ self.assertIsNotNone(model._session)
130
47
  del os.environ['VERTEXAI_PROJECT']
131
48
  del os.environ['VERTEXAI_LOCATION']
132
49
 
133
- def test_generation_config(self):
134
- model = vertexai.VertexAIGeminiPro1()
135
- json_schema = {
136
- 'type': 'object',
137
- 'properties': {
138
- 'name': {'type': 'string'},
139
- },
140
- 'required': ['name'],
141
- 'title': 'Person',
142
- }
143
- actual = model._generation_config(
144
- lf.UserMessage('hi', json_schema=json_schema),
145
- lf.LMSamplingOptions(
146
- temperature=2.0,
147
- top_p=1.0,
148
- top_k=20,
149
- max_tokens=1024,
150
- stop=['\n'],
151
- ),
152
- )
153
- self.assertEqual(
154
- actual,
155
- dict(
156
- candidateCount=1,
157
- temperature=2.0,
158
- topP=1.0,
159
- topK=20,
160
- maxOutputTokens=1024,
161
- stopSequences=['\n'],
162
- responseLogprobs=False,
163
- logprobs=None,
164
- seed=None,
165
- responseMimeType='application/json',
166
- responseSchema={
167
- 'type': 'object',
168
- 'properties': {
169
- 'name': {'type': 'string'}
170
- },
171
- 'required': ['name'],
172
- 'title': 'Person',
173
- }
174
- ),
175
- )
176
- with self.assertRaisesRegex(
177
- ValueError, '`json_schema` must be a dict, got'
178
- ):
179
- model._generation_config(
180
- lf.UserMessage('hi', json_schema='not a dict'),
181
- lf.LMSamplingOptions(),
182
- )
183
-
184
- @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
185
- def test_call_model(self):
186
- with mock.patch('requests.Session.post') as mock_generate:
187
- mock_generate.side_effect = mock_requests_post
188
-
189
- lm = vertexai.VertexAIGeminiPro1_5_002(
190
- project='abc', location='us-central1'
191
- )
192
- r = lm(
193
- 'hello',
194
- temperature=2.0,
195
- top_p=1.0,
196
- top_k=20,
197
- max_tokens=1024,
198
- stop='\n',
199
- )
200
- self.assertEqual(
201
- r.text,
202
- (
203
- 'This is a response to hello with temperature=2.0, '
204
- 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
205
- ),
206
- )
207
- self.assertEqual(r.metadata.usage.prompt_tokens, 3)
208
- self.assertEqual(r.metadata.usage.completion_tokens, 4)
209
-
210
50
 
211
51
  if __name__ == '__main__':
212
52
  unittest.main()