google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +160 -59
- google/genai/_api_module.py +6 -1
- google/genai/_automatic_function_calling_util.py +12 -12
- google/genai/_common.py +14 -2
- google/genai/_extra_utils.py +14 -8
- google/genai/_replay_api_client.py +35 -3
- google/genai/_test_api_client.py +8 -8
- google/genai/_transformers.py +169 -48
- google/genai/batches.py +176 -127
- google/genai/caches.py +315 -214
- google/genai/chats.py +179 -35
- google/genai/client.py +16 -6
- google/genai/errors.py +19 -5
- google/genai/files.py +161 -115
- google/genai/live.py +137 -105
- google/genai/models.py +1553 -734
- google/genai/operations.py +635 -0
- google/genai/pagers.py +5 -5
- google/genai/tunings.py +166 -103
- google/genai/types.py +590 -142
- google/genai/version.py +1 -1
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/METADATA +94 -12
- google_genai-1.4.0.dist-info/RECORD +27 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/WHEEL +1 -1
- google/genai/_operations.py +0 -365
- google_genai-1.2.0.dist-info/RECORD +0 -27
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/LICENSE +0 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
@@ -20,7 +20,7 @@ import base64
|
|
20
20
|
import contextlib
|
21
21
|
import json
|
22
22
|
import logging
|
23
|
-
from typing import AsyncIterator, Optional, Sequence, Union
|
23
|
+
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union
|
24
24
|
|
25
25
|
import google.auth
|
26
26
|
from websockets import ConnectionClosed
|
@@ -31,7 +31,7 @@ from . import _transformers as t
|
|
31
31
|
from . import client
|
32
32
|
from . import errors
|
33
33
|
from . import types
|
34
|
-
from ._api_client import
|
34
|
+
from ._api_client import BaseApiClient
|
35
35
|
from ._common import experimental_warning
|
36
36
|
from ._common import get_value_by_path as getv
|
37
37
|
from ._common import set_value_by_path as setv
|
@@ -49,12 +49,14 @@ from .models import _Tool_to_mldev
|
|
49
49
|
from .models import _Tool_to_vertex
|
50
50
|
|
51
51
|
try:
|
52
|
-
from websockets.asyncio.client import ClientConnection
|
53
|
-
from websockets.asyncio.client import connect
|
52
|
+
from websockets.asyncio.client import ClientConnection # type: ignore
|
53
|
+
from websockets.asyncio.client import connect # type: ignore
|
54
54
|
except ModuleNotFoundError:
|
55
|
-
|
56
|
-
from websockets.client import
|
55
|
+
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
56
|
+
from websockets.client import ClientConnection # type: ignore
|
57
|
+
from websockets.client import connect # type: ignore
|
57
58
|
|
59
|
+
logger = logging.getLogger('google_genai.live')
|
58
60
|
|
59
61
|
_FUNCTION_RESPONSE_REQUIRES_ID = (
|
60
62
|
'FunctionResponse request must have an `id` field from the'
|
@@ -65,22 +67,26 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
|
|
65
67
|
class AsyncSession:
|
66
68
|
"""AsyncSession. The live module is experimental."""
|
67
69
|
|
68
|
-
def __init__(
|
70
|
+
def __init__(
|
71
|
+
self, api_client: client.BaseApiClient, websocket: ClientConnection
|
72
|
+
):
|
69
73
|
self._api_client = api_client
|
70
74
|
self._ws = websocket
|
71
75
|
|
72
76
|
async def send(
|
73
77
|
self,
|
74
78
|
*,
|
75
|
-
input:
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
79
|
+
input: Optional[
|
80
|
+
Union[
|
81
|
+
types.ContentListUnion,
|
82
|
+
types.ContentListUnionDict,
|
83
|
+
types.LiveClientContentOrDict,
|
84
|
+
types.LiveClientRealtimeInputOrDict,
|
85
|
+
types.LiveClientToolResponseOrDict,
|
86
|
+
types.FunctionResponseOrDict,
|
87
|
+
Sequence[types.FunctionResponseOrDict],
|
88
|
+
]
|
89
|
+
] = None,
|
84
90
|
end_of_turn: Optional[bool] = False,
|
85
91
|
):
|
86
92
|
"""Send input to the model.
|
@@ -214,7 +220,7 @@ class AsyncSession:
|
|
214
220
|
response_dict = self._LiveServerMessage_from_mldev(response)
|
215
221
|
|
216
222
|
return types.LiveServerMessage._from_response(
|
217
|
-
response_dict, parameter_model
|
223
|
+
response=response_dict, kwargs=parameter_model
|
218
224
|
)
|
219
225
|
|
220
226
|
async def _send_loop(
|
@@ -234,8 +240,8 @@ class AsyncSession:
|
|
234
240
|
def _LiveServerContent_from_mldev(
|
235
241
|
self,
|
236
242
|
from_object: Union[dict, object],
|
237
|
-
) ->
|
238
|
-
to_object = {}
|
243
|
+
) -> Dict[str, Any]:
|
244
|
+
to_object: dict[str, Any] = {}
|
239
245
|
if getv(from_object, ['modelTurn']) is not None:
|
240
246
|
setv(
|
241
247
|
to_object,
|
@@ -254,8 +260,8 @@ class AsyncSession:
|
|
254
260
|
def _LiveToolCall_from_mldev(
|
255
261
|
self,
|
256
262
|
from_object: Union[dict, object],
|
257
|
-
) ->
|
258
|
-
to_object = {}
|
263
|
+
) -> Dict[str, Any]:
|
264
|
+
to_object: dict[str, Any] = {}
|
259
265
|
if getv(from_object, ['functionCalls']) is not None:
|
260
266
|
setv(
|
261
267
|
to_object,
|
@@ -267,8 +273,8 @@ class AsyncSession:
|
|
267
273
|
def _LiveToolCall_from_vertex(
|
268
274
|
self,
|
269
275
|
from_object: Union[dict, object],
|
270
|
-
) ->
|
271
|
-
to_object = {}
|
276
|
+
) -> Dict[str, Any]:
|
277
|
+
to_object: dict[str, Any] = {}
|
272
278
|
if getv(from_object, ['functionCalls']) is not None:
|
273
279
|
setv(
|
274
280
|
to_object,
|
@@ -280,8 +286,8 @@ class AsyncSession:
|
|
280
286
|
def _LiveServerMessage_from_mldev(
|
281
287
|
self,
|
282
288
|
from_object: Union[dict, object],
|
283
|
-
) ->
|
284
|
-
to_object = {}
|
289
|
+
) -> Dict[str, Any]:
|
290
|
+
to_object: dict[str, Any] = {}
|
285
291
|
if getv(from_object, ['serverContent']) is not None:
|
286
292
|
setv(
|
287
293
|
to_object,
|
@@ -307,8 +313,8 @@ class AsyncSession:
|
|
307
313
|
def _LiveServerContent_from_vertex(
|
308
314
|
self,
|
309
315
|
from_object: Union[dict, object],
|
310
|
-
) ->
|
311
|
-
to_object = {}
|
316
|
+
) -> Dict[str, Any]:
|
317
|
+
to_object: dict[str, Any] = {}
|
312
318
|
if getv(from_object, ['modelTurn']) is not None:
|
313
319
|
setv(
|
314
320
|
to_object,
|
@@ -327,8 +333,8 @@ class AsyncSession:
|
|
327
333
|
def _LiveServerMessage_from_vertex(
|
328
334
|
self,
|
329
335
|
from_object: Union[dict, object],
|
330
|
-
) ->
|
331
|
-
to_object = {}
|
336
|
+
) -> Dict[str, Any]:
|
337
|
+
to_object: dict[str, Any] = {}
|
332
338
|
if getv(from_object, ['serverContent']) is not None:
|
333
339
|
setv(
|
334
340
|
to_object,
|
@@ -354,18 +360,23 @@ class AsyncSession:
|
|
354
360
|
|
355
361
|
def _parse_client_message(
|
356
362
|
self,
|
357
|
-
input:
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
363
|
+
input: Optional[
|
364
|
+
Union[
|
365
|
+
types.ContentListUnion,
|
366
|
+
types.ContentListUnionDict,
|
367
|
+
types.LiveClientContentOrDict,
|
368
|
+
types.LiveClientRealtimeInputOrDict,
|
369
|
+
types.LiveClientToolResponseOrDict,
|
370
|
+
types.FunctionResponseOrDict,
|
371
|
+
Sequence[types.FunctionResponseOrDict],
|
372
|
+
]
|
373
|
+
] = None,
|
367
374
|
end_of_turn: Optional[bool] = False,
|
368
|
-
) ->
|
375
|
+
) -> Dict[str, Any]:
|
376
|
+
|
377
|
+
if not input:
|
378
|
+
logging.info('No input provided. Assume it is the end of turn.')
|
379
|
+
return {'client_content': {'turn_complete': True}}
|
369
380
|
if isinstance(input, str):
|
370
381
|
input = [input]
|
371
382
|
elif isinstance(input, dict) and 'data' in input:
|
@@ -374,7 +385,6 @@ class AsyncSession:
|
|
374
385
|
input['data'] = decoded_data
|
375
386
|
input = [input]
|
376
387
|
elif isinstance(input, types.Blob):
|
377
|
-
input.data = base64.b64encode(input.data).decode('utf-8')
|
378
388
|
input = [input]
|
379
389
|
elif isinstance(input, dict) and 'name' in input and 'response' in input:
|
380
390
|
# ToolResponse.FunctionResponse
|
@@ -392,7 +402,7 @@ class AsyncSession:
|
|
392
402
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
393
403
|
client_message = {'tool_response': {'function_responses': input}}
|
394
404
|
elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
|
395
|
-
to_object = {}
|
405
|
+
to_object: dict[str, Any] = {}
|
396
406
|
if self._api_client.vertexai:
|
397
407
|
contents = [
|
398
408
|
_Content_to_vertex(self._api_client, item, to_object)
|
@@ -411,7 +421,7 @@ class AsyncSession:
|
|
411
421
|
if any((isinstance(b, dict) and 'data' in b) for b in input):
|
412
422
|
pass
|
413
423
|
elif any(isinstance(b, types.Blob) for b in input):
|
414
|
-
input = [b.model_dump(exclude_none=True) for b in input]
|
424
|
+
input = [b.model_dump(exclude_none=True, mode='json') for b in input]
|
415
425
|
else:
|
416
426
|
raise ValueError(
|
417
427
|
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
@@ -419,11 +429,21 @@ class AsyncSession:
|
|
419
429
|
|
420
430
|
client_message = {'realtime_input': {'media_chunks': input}}
|
421
431
|
|
422
|
-
elif isinstance(input, dict)
|
423
|
-
|
424
|
-
|
432
|
+
elif isinstance(input, dict):
|
433
|
+
if 'content' in input or 'turns' in input:
|
434
|
+
# TODO(b/365983264) Add validation checks for content_update input_dict.
|
435
|
+
client_message = {'client_content': input}
|
436
|
+
elif 'media_chunks' in input:
|
437
|
+
client_message = {'realtime_input': input}
|
438
|
+
elif 'function_responses' in input:
|
439
|
+
client_message = {'tool_response': input}
|
440
|
+
else:
|
441
|
+
raise ValueError(
|
442
|
+
f'Unsupported input type "{type(input)}" or input content "{input}"')
|
425
443
|
elif isinstance(input, types.LiveClientRealtimeInput):
|
426
|
-
client_message = {
|
444
|
+
client_message = {
|
445
|
+
'realtime_input': input.model_dump(exclude_none=True, mode='json')
|
446
|
+
}
|
427
447
|
if isinstance(
|
428
448
|
client_message['realtime_input']['media_chunks'][0]['data'], bytes
|
429
449
|
):
|
@@ -436,20 +456,26 @@ class AsyncSession:
|
|
436
456
|
]
|
437
457
|
|
438
458
|
elif isinstance(input, types.LiveClientContent):
|
439
|
-
client_message = {
|
459
|
+
client_message = {
|
460
|
+
'client_content': input.model_dump(exclude_none=True, mode='json')
|
461
|
+
}
|
440
462
|
elif isinstance(input, types.LiveClientToolResponse):
|
441
463
|
# ToolResponse.FunctionResponse
|
442
464
|
if not (self._api_client.vertexai) and not (
|
443
465
|
input.function_responses[0].id
|
444
466
|
):
|
445
467
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
446
|
-
client_message = {
|
468
|
+
client_message = {
|
469
|
+
'tool_response': input.model_dump(exclude_none=True, mode='json')
|
470
|
+
}
|
447
471
|
elif isinstance(input, types.FunctionResponse):
|
448
472
|
if not (self._api_client.vertexai) and not (input.id):
|
449
473
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
450
474
|
client_message = {
|
451
475
|
'tool_response': {
|
452
|
-
'function_responses': [
|
476
|
+
'function_responses': [
|
477
|
+
input.model_dump(exclude_none=True, mode='json')
|
478
|
+
]
|
453
479
|
}
|
454
480
|
}
|
455
481
|
elif isinstance(input, Sequence) and isinstance(
|
@@ -460,7 +486,7 @@ class AsyncSession:
|
|
460
486
|
client_message = {
|
461
487
|
'tool_response': {
|
462
488
|
'function_responses': [
|
463
|
-
c.model_dump(exclude_none=True) for c in input
|
489
|
+
c.model_dump(exclude_none=True, mode='json') for c in input
|
464
490
|
]
|
465
491
|
}
|
466
492
|
}
|
@@ -480,39 +506,35 @@ class AsyncLive(_api_module.BaseModule):
|
|
480
506
|
"""AsyncLive. The live module is experimental."""
|
481
507
|
|
482
508
|
def _LiveSetup_to_mldev(
|
483
|
-
self, model: str, config: Optional[types.
|
509
|
+
self, model: str, config: Optional[types.LiveConnectConfig] = None
|
484
510
|
):
|
485
|
-
if isinstance(config, types.LiveConnectConfig):
|
486
|
-
from_object = config.model_dump(exclude_none=True)
|
487
|
-
else:
|
488
|
-
from_object = config
|
489
511
|
|
490
|
-
to_object = {}
|
491
|
-
if getv(
|
512
|
+
to_object: dict[str, Any] = {}
|
513
|
+
if getv(config, ['generation_config']) is not None:
|
492
514
|
setv(
|
493
515
|
to_object,
|
494
516
|
['generationConfig'],
|
495
517
|
_GenerateContentConfig_to_mldev(
|
496
518
|
self._api_client,
|
497
|
-
getv(
|
519
|
+
getv(config, ['generation_config']),
|
498
520
|
to_object,
|
499
521
|
),
|
500
522
|
)
|
501
|
-
if getv(
|
523
|
+
if getv(config, ['response_modalities']) is not None:
|
502
524
|
if getv(to_object, ['generationConfig']) is not None:
|
503
|
-
to_object['generationConfig']['responseModalities'] =
|
504
|
-
'response_modalities'
|
505
|
-
|
525
|
+
to_object['generationConfig']['responseModalities'] = getv(
|
526
|
+
config, ['response_modalities']
|
527
|
+
)
|
506
528
|
else:
|
507
529
|
to_object['generationConfig'] = {
|
508
|
-
'responseModalities':
|
530
|
+
'responseModalities': getv(config, ['response_modalities'])
|
509
531
|
}
|
510
|
-
if getv(
|
532
|
+
if getv(config, ['speech_config']) is not None:
|
511
533
|
if getv(to_object, ['generationConfig']) is not None:
|
512
534
|
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
|
513
535
|
self._api_client,
|
514
536
|
t.t_speech_config(
|
515
|
-
self._api_client, getv(
|
537
|
+
self._api_client, getv(config, ['speech_config'])
|
516
538
|
),
|
517
539
|
to_object,
|
518
540
|
)
|
@@ -521,31 +543,33 @@ class AsyncLive(_api_module.BaseModule):
|
|
521
543
|
'speechConfig': _SpeechConfig_to_mldev(
|
522
544
|
self._api_client,
|
523
545
|
t.t_speech_config(
|
524
|
-
self._api_client, getv(
|
546
|
+
self._api_client, getv(config, ['speech_config'])
|
525
547
|
),
|
526
548
|
to_object,
|
527
549
|
)
|
528
550
|
}
|
529
551
|
|
530
|
-
if getv(
|
552
|
+
if getv(config, ['system_instruction']) is not None:
|
531
553
|
setv(
|
532
554
|
to_object,
|
533
555
|
['systemInstruction'],
|
534
556
|
_Content_to_mldev(
|
535
557
|
self._api_client,
|
536
558
|
t.t_content(
|
537
|
-
self._api_client, getv(
|
559
|
+
self._api_client, getv(config, ['system_instruction'])
|
538
560
|
),
|
539
561
|
to_object,
|
540
562
|
),
|
541
563
|
)
|
542
|
-
if getv(
|
564
|
+
if getv(config, ['tools']) is not None:
|
543
565
|
setv(
|
544
566
|
to_object,
|
545
567
|
['tools'],
|
546
568
|
[
|
547
|
-
_Tool_to_mldev(
|
548
|
-
|
569
|
+
_Tool_to_mldev(
|
570
|
+
self._api_client, t.t_tool(self._api_client, item), to_object
|
571
|
+
)
|
572
|
+
for item in t.t_tools(self._api_client, getv(config, ['tools']))
|
549
573
|
],
|
550
574
|
)
|
551
575
|
|
@@ -554,33 +578,29 @@ class AsyncLive(_api_module.BaseModule):
|
|
554
578
|
return return_value
|
555
579
|
|
556
580
|
def _LiveSetup_to_vertex(
|
557
|
-
self, model: str, config: Optional[types.
|
581
|
+
self, model: str, config: Optional[types.LiveConnectConfig] = None
|
558
582
|
):
|
559
|
-
if isinstance(config, types.LiveConnectConfig):
|
560
|
-
from_object = config.model_dump(exclude_none=True)
|
561
|
-
else:
|
562
|
-
from_object = config
|
563
583
|
|
564
|
-
to_object = {}
|
584
|
+
to_object: dict[str, Any] = {}
|
565
585
|
|
566
|
-
if getv(
|
586
|
+
if getv(config, ['generation_config']) is not None:
|
567
587
|
setv(
|
568
588
|
to_object,
|
569
589
|
['generationConfig'],
|
570
590
|
_GenerateContentConfig_to_vertex(
|
571
591
|
self._api_client,
|
572
|
-
getv(
|
592
|
+
getv(config, ['generation_config']),
|
573
593
|
to_object,
|
574
594
|
),
|
575
595
|
)
|
576
|
-
if getv(
|
596
|
+
if getv(config, ['response_modalities']) is not None:
|
577
597
|
if getv(to_object, ['generationConfig']) is not None:
|
578
|
-
to_object['generationConfig']['responseModalities'] =
|
579
|
-
'response_modalities'
|
580
|
-
|
598
|
+
to_object['generationConfig']['responseModalities'] = getv(
|
599
|
+
config, ['response_modalities']
|
600
|
+
)
|
581
601
|
else:
|
582
602
|
to_object['generationConfig'] = {
|
583
|
-
'responseModalities':
|
603
|
+
'responseModalities': getv(config, ['response_modalities'])
|
584
604
|
}
|
585
605
|
else:
|
586
606
|
# Set default to AUDIO to align with MLDev API.
|
@@ -590,12 +610,12 @@ class AsyncLive(_api_module.BaseModule):
|
|
590
610
|
to_object.update(
|
591
611
|
{'generationConfig': {'responseModalities': ['AUDIO']}}
|
592
612
|
)
|
593
|
-
if getv(
|
613
|
+
if getv(config, ['speech_config']) is not None:
|
594
614
|
if getv(to_object, ['generationConfig']) is not None:
|
595
615
|
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
|
596
616
|
self._api_client,
|
597
617
|
t.t_speech_config(
|
598
|
-
self._api_client, getv(
|
618
|
+
self._api_client, getv(config, ['speech_config'])
|
599
619
|
),
|
600
620
|
to_object,
|
601
621
|
)
|
@@ -604,30 +624,32 @@ class AsyncLive(_api_module.BaseModule):
|
|
604
624
|
'speechConfig': _SpeechConfig_to_vertex(
|
605
625
|
self._api_client,
|
606
626
|
t.t_speech_config(
|
607
|
-
self._api_client, getv(
|
627
|
+
self._api_client, getv(config, ['speech_config'])
|
608
628
|
),
|
609
629
|
to_object,
|
610
630
|
)
|
611
631
|
}
|
612
|
-
if getv(
|
632
|
+
if getv(config, ['system_instruction']) is not None:
|
613
633
|
setv(
|
614
634
|
to_object,
|
615
635
|
['systemInstruction'],
|
616
636
|
_Content_to_vertex(
|
617
637
|
self._api_client,
|
618
638
|
t.t_content(
|
619
|
-
self._api_client, getv(
|
639
|
+
self._api_client, getv(config, ['system_instruction'])
|
620
640
|
),
|
621
641
|
to_object,
|
622
642
|
),
|
623
643
|
)
|
624
|
-
if getv(
|
644
|
+
if getv(config, ['tools']) is not None:
|
625
645
|
setv(
|
626
646
|
to_object,
|
627
647
|
['tools'],
|
628
648
|
[
|
629
|
-
_Tool_to_vertex(
|
630
|
-
|
649
|
+
_Tool_to_vertex(
|
650
|
+
self._api_client, t.t_tool(self._api_client, item), to_object
|
651
|
+
)
|
652
|
+
for item in t.t_tools(self._api_client, getv(config, ['tools']))
|
631
653
|
],
|
632
654
|
)
|
633
655
|
|
@@ -661,16 +683,24 @@ class AsyncLive(_api_module.BaseModule):
|
|
661
683
|
print(message)
|
662
684
|
"""
|
663
685
|
base_url = self._api_client._websocket_base_url()
|
686
|
+
transformed_model = t.t_model(self._api_client, model)
|
687
|
+
# Ensure the config is a LiveConnectConfig.
|
688
|
+
parameter_model = types.LiveConnectConfig(**config) if isinstance(
|
689
|
+
config, dict
|
690
|
+
) else config
|
691
|
+
|
664
692
|
if self._api_client.api_key:
|
665
693
|
api_key = self._api_client.api_key
|
666
694
|
version = self._api_client._http_options['api_version']
|
667
695
|
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
|
668
696
|
headers = self._api_client._http_options['headers']
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
697
|
+
request_dict = _common.convert_to_dict(
|
698
|
+
self._LiveSetup_to_mldev(
|
699
|
+
model=transformed_model,
|
700
|
+
config=parameter_model,
|
701
|
+
)
|
673
702
|
)
|
703
|
+
request = json.dumps(request_dict)
|
674
704
|
else:
|
675
705
|
# Get bearer token through Application Default Credentials.
|
676
706
|
creds, _ = google.auth.default(
|
@@ -682,26 +712,28 @@ class AsyncLive(_api_module.BaseModule):
|
|
682
712
|
auth_req = google.auth.transport.requests.Request()
|
683
713
|
creds.refresh(auth_req)
|
684
714
|
bearer_token = creds.token
|
685
|
-
headers =
|
686
|
-
|
715
|
+
headers = self._api_client._http_options['headers']
|
716
|
+
headers.update({
|
687
717
|
'Authorization': 'Bearer {}'.format(bearer_token),
|
688
|
-
}
|
718
|
+
})
|
689
719
|
version = self._api_client._http_options['api_version']
|
690
720
|
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
691
721
|
location = self._api_client.location
|
692
722
|
project = self._api_client.project
|
693
|
-
transformed_model = t.t_model(self._api_client, model)
|
694
723
|
if transformed_model.startswith('publishers/'):
|
695
724
|
transformed_model = (
|
696
725
|
f'projects/{project}/locations/{location}/' + transformed_model
|
697
726
|
)
|
698
|
-
|
699
|
-
|
700
|
-
|
727
|
+
request_dict = _common.convert_to_dict(
|
728
|
+
self._LiveSetup_to_vertex(
|
729
|
+
model=transformed_model,
|
730
|
+
config=parameter_model,
|
731
|
+
)
|
701
732
|
)
|
733
|
+
request = json.dumps(request_dict)
|
702
734
|
|
703
735
|
async with connect(uri, additional_headers=headers) as ws:
|
704
736
|
await ws.send(request)
|
705
|
-
|
737
|
+
logger.info(await ws.recv(decode=False))
|
706
738
|
|
707
739
|
yield AsyncSession(api_client=self._api_client, websocket=ws)
|