google-genai 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +2 -1
- google/genai/_api_client.py +161 -52
- google/genai/_automatic_function_calling_util.py +14 -14
- google/genai/_common.py +14 -29
- google/genai/_replay_api_client.py +13 -54
- google/genai/_transformers.py +38 -0
- google/genai/batches.py +80 -78
- google/genai/caches.py +112 -98
- google/genai/chats.py +7 -10
- google/genai/client.py +6 -3
- google/genai/files.py +91 -90
- google/genai/live.py +65 -34
- google/genai/models.py +374 -297
- google/genai/tunings.py +87 -85
- google/genai/types.py +167 -82
- google/genai/version.py +16 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/METADATA +57 -17
- google_genai-0.5.0.dist-info/RECORD +25 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/WHEEL +1 -1
- google_genai-0.3.0.dist-info/RECORD +0 -24
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/LICENSE +0 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
@@ -68,6 +68,7 @@ class AsyncSession:
|
|
68
68
|
|
69
69
|
async def send(
|
70
70
|
self,
|
71
|
+
*,
|
71
72
|
input: Union[
|
72
73
|
types.ContentListUnion,
|
73
74
|
types.ContentListUnionDict,
|
@@ -80,6 +81,25 @@ class AsyncSession:
|
|
80
81
|
],
|
81
82
|
end_of_turn: Optional[bool] = False,
|
82
83
|
):
|
84
|
+
"""Send input to the model.
|
85
|
+
|
86
|
+
The method will send the input request to the server.
|
87
|
+
|
88
|
+
Args:
|
89
|
+
input: The input request to the model.
|
90
|
+
end_of_turn: Whether the input is the last message in a turn.
|
91
|
+
|
92
|
+
Example usage:
|
93
|
+
|
94
|
+
.. code-block:: python
|
95
|
+
|
96
|
+
client = genai.Client(api_key=API_KEY)
|
97
|
+
|
98
|
+
async with client.aio.live.connect(model='...') as session:
|
99
|
+
await session.send(input='Hello world!', end_of_turn=True)
|
100
|
+
async for message in session.receive():
|
101
|
+
print(message)
|
102
|
+
"""
|
83
103
|
client_message = self._parse_client_message(input, end_of_turn)
|
84
104
|
await self._ws.send(json.dumps(client_message))
|
85
105
|
|
@@ -113,7 +133,7 @@ class AsyncSession:
|
|
113
133
|
yield result
|
114
134
|
|
115
135
|
async def start_stream(
|
116
|
-
self, stream: AsyncIterator[bytes], mime_type: str
|
136
|
+
self, *, stream: AsyncIterator[bytes], mime_type: str
|
117
137
|
) -> AsyncIterator[types.LiveServerMessage]:
|
118
138
|
"""start a live session from a data stream.
|
119
139
|
|
@@ -199,7 +219,7 @@ class AsyncSession:
|
|
199
219
|
):
|
200
220
|
async for data in data_stream:
|
201
221
|
input = {'data': data, 'mimeType': mime_type}
|
202
|
-
await self.send(input)
|
222
|
+
await self.send(input=input)
|
203
223
|
# Give a chance for the receive loop to process responses.
|
204
224
|
await asyncio.sleep(10**-12)
|
205
225
|
# Give a chance for the receiver to process the last response.
|
@@ -221,6 +241,8 @@ class AsyncSession:
|
|
221
241
|
)
|
222
242
|
if getv(from_object, ['turnComplete']) is not None:
|
223
243
|
setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
|
244
|
+
if getv(from_object, ['interrupted']) is not None:
|
245
|
+
setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
|
224
246
|
return to_object
|
225
247
|
|
226
248
|
def _LiveToolCall_from_mldev(
|
@@ -292,6 +314,8 @@ class AsyncSession:
|
|
292
314
|
)
|
293
315
|
if getv(from_object, ['turnComplete']) is not None:
|
294
316
|
setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
|
317
|
+
if getv(from_object, ['interrupted']) is not None:
|
318
|
+
setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
|
295
319
|
return to_object
|
296
320
|
|
297
321
|
def _LiveServerMessage_from_vertex(
|
@@ -338,7 +362,7 @@ class AsyncSession:
|
|
338
362
|
) -> dict:
|
339
363
|
if isinstance(input, str):
|
340
364
|
input = [input]
|
341
|
-
elif
|
365
|
+
elif isinstance(input, dict) and 'data' in input:
|
342
366
|
if isinstance(input['data'], bytes):
|
343
367
|
decoded_data = base64.b64encode(input['data']).decode('utf-8')
|
344
368
|
input['data'] = decoded_data
|
@@ -405,7 +429,9 @@ class AsyncSession:
|
|
405
429
|
client_message = {'client_content': input.model_dump(exclude_none=True)}
|
406
430
|
elif isinstance(input, types.LiveClientToolResponse):
|
407
431
|
# ToolResponse.FunctionResponse
|
408
|
-
if not (self._api_client.vertexai) and not (
|
432
|
+
if not (self._api_client.vertexai) and not (
|
433
|
+
input.function_responses[0].id
|
434
|
+
):
|
409
435
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
410
436
|
client_message = {'tool_response': input.model_dump(exclude_none=True)}
|
411
437
|
elif isinstance(input, types.FunctionResponse):
|
@@ -457,7 +483,7 @@ class AsyncLive(_common.BaseModule):
|
|
457
483
|
to_object,
|
458
484
|
['generationConfig'],
|
459
485
|
_GenerateContentConfig_to_mldev(
|
460
|
-
self.
|
486
|
+
self._api_client,
|
461
487
|
getv(from_object, ['generation_config']),
|
462
488
|
to_object,
|
463
489
|
),
|
@@ -474,17 +500,18 @@ class AsyncLive(_common.BaseModule):
|
|
474
500
|
if getv(from_object, ['speech_config']) is not None:
|
475
501
|
if getv(to_object, ['generationConfig']) is not None:
|
476
502
|
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
|
477
|
-
self.
|
503
|
+
self._api_client,
|
478
504
|
t.t_speech_config(
|
479
|
-
self.
|
505
|
+
self._api_client, getv(from_object, ['speech_config'])
|
506
|
+
),
|
480
507
|
to_object,
|
481
508
|
)
|
482
509
|
else:
|
483
510
|
to_object['generationConfig'] = {
|
484
511
|
'speechConfig': _SpeechConfig_to_mldev(
|
485
|
-
self.
|
512
|
+
self._api_client,
|
486
513
|
t.t_speech_config(
|
487
|
-
self.
|
514
|
+
self._api_client, getv(from_object, ['speech_config'])
|
488
515
|
),
|
489
516
|
to_object,
|
490
517
|
)
|
@@ -495,9 +522,9 @@ class AsyncLive(_common.BaseModule):
|
|
495
522
|
to_object,
|
496
523
|
['systemInstruction'],
|
497
524
|
_Content_to_mldev(
|
498
|
-
self.
|
525
|
+
self._api_client,
|
499
526
|
t.t_content(
|
500
|
-
self.
|
527
|
+
self._api_client, getv(from_object, ['system_instruction'])
|
501
528
|
),
|
502
529
|
to_object,
|
503
530
|
),
|
@@ -507,7 +534,7 @@ class AsyncLive(_common.BaseModule):
|
|
507
534
|
to_object,
|
508
535
|
['tools'],
|
509
536
|
[
|
510
|
-
_Tool_to_mldev(self.
|
537
|
+
_Tool_to_mldev(self._api_client, item, to_object)
|
511
538
|
for item in getv(from_object, ['tools'])
|
512
539
|
],
|
513
540
|
)
|
@@ -531,7 +558,7 @@ class AsyncLive(_common.BaseModule):
|
|
531
558
|
to_object,
|
532
559
|
['generationConfig'],
|
533
560
|
_GenerateContentConfig_to_vertex(
|
534
|
-
self.
|
561
|
+
self._api_client,
|
535
562
|
getv(from_object, ['generation_config']),
|
536
563
|
to_object,
|
537
564
|
),
|
@@ -556,17 +583,18 @@ class AsyncLive(_common.BaseModule):
|
|
556
583
|
if getv(from_object, ['speech_config']) is not None:
|
557
584
|
if getv(to_object, ['generationConfig']) is not None:
|
558
585
|
to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
|
559
|
-
self.
|
586
|
+
self._api_client,
|
560
587
|
t.t_speech_config(
|
561
|
-
self.
|
588
|
+
self._api_client, getv(from_object, ['speech_config'])
|
589
|
+
),
|
562
590
|
to_object,
|
563
591
|
)
|
564
592
|
else:
|
565
593
|
to_object['generationConfig'] = {
|
566
594
|
'speechConfig': _SpeechConfig_to_vertex(
|
567
|
-
self.
|
595
|
+
self._api_client,
|
568
596
|
t.t_speech_config(
|
569
|
-
self.
|
597
|
+
self._api_client, getv(from_object, ['speech_config'])
|
570
598
|
),
|
571
599
|
to_object,
|
572
600
|
)
|
@@ -576,9 +604,9 @@ class AsyncLive(_common.BaseModule):
|
|
576
604
|
to_object,
|
577
605
|
['systemInstruction'],
|
578
606
|
_Content_to_vertex(
|
579
|
-
self.
|
607
|
+
self._api_client,
|
580
608
|
t.t_content(
|
581
|
-
self.
|
609
|
+
self._api_client, getv(from_object, ['system_instruction'])
|
582
610
|
),
|
583
611
|
to_object,
|
584
612
|
),
|
@@ -588,7 +616,7 @@ class AsyncLive(_common.BaseModule):
|
|
588
616
|
to_object,
|
589
617
|
['tools'],
|
590
618
|
[
|
591
|
-
_Tool_to_vertex(self.
|
619
|
+
_Tool_to_vertex(self._api_client, item, to_object)
|
592
620
|
for item in getv(from_object, ['tools'])
|
593
621
|
],
|
594
622
|
)
|
@@ -599,7 +627,10 @@ class AsyncLive(_common.BaseModule):
|
|
599
627
|
|
600
628
|
@contextlib.asynccontextmanager
|
601
629
|
async def connect(
|
602
|
-
self,
|
630
|
+
self,
|
631
|
+
*,
|
632
|
+
model: str,
|
633
|
+
config: Optional[types.LiveConnectConfigOrDict] = None,
|
603
634
|
) -> AsyncSession:
|
604
635
|
"""Connect to the live server.
|
605
636
|
|
@@ -609,19 +640,19 @@ class AsyncLive(_common.BaseModule):
|
|
609
640
|
|
610
641
|
client = genai.Client(api_key=API_KEY)
|
611
642
|
config = {}
|
612
|
-
async with client.aio.live.connect(model='
|
643
|
+
async with client.aio.live.connect(model='...', config=config) as session:
|
613
644
|
await session.send(input='Hello world!', end_of_turn=True)
|
614
|
-
async for message in session:
|
645
|
+
async for message in session.receive():
|
615
646
|
print(message)
|
616
647
|
"""
|
617
|
-
base_url = self.
|
618
|
-
if self.
|
619
|
-
api_key = self.
|
620
|
-
version = self.
|
648
|
+
base_url = self._api_client._websocket_base_url()
|
649
|
+
if self._api_client.api_key:
|
650
|
+
api_key = self._api_client.api_key
|
651
|
+
version = self._api_client._http_options['api_version']
|
621
652
|
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
|
622
|
-
headers = self.
|
653
|
+
headers = self._api_client._http_options['headers']
|
623
654
|
|
624
|
-
transformed_model = t.t_model(self.
|
655
|
+
transformed_model = t.t_model(self._api_client, model)
|
625
656
|
request = json.dumps(
|
626
657
|
self._LiveSetup_to_mldev(model=transformed_model, config=config)
|
627
658
|
)
|
@@ -640,11 +671,11 @@ class AsyncLive(_common.BaseModule):
|
|
640
671
|
'Content-Type': 'application/json',
|
641
672
|
'Authorization': 'Bearer {}'.format(bearer_token),
|
642
673
|
}
|
643
|
-
version = self.
|
674
|
+
version = self._api_client._http_options['api_version']
|
644
675
|
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
645
|
-
location = self.
|
646
|
-
project = self.
|
647
|
-
transformed_model = t.t_model(self.
|
676
|
+
location = self._api_client.location
|
677
|
+
project = self._api_client.project
|
678
|
+
transformed_model = t.t_model(self._api_client, model)
|
648
679
|
if transformed_model.startswith('publishers/'):
|
649
680
|
transformed_model = (
|
650
681
|
f'projects/{project}/locations/{location}/' + transformed_model
|
@@ -658,4 +689,4 @@ class AsyncLive(_common.BaseModule):
|
|
658
689
|
await ws.send(request)
|
659
690
|
logging.info(await ws.recv(decode=False))
|
660
691
|
|
661
|
-
yield AsyncSession(api_client=self.
|
692
|
+
yield AsyncSession(api_client=self._api_client, websocket=ws)
|