google-genai 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +2 -0
- google/genai/_api_client.py +1 -1
- google/genai/_automatic_function_calling_util.py +0 -44
- google/genai/_extra_utils.py +15 -0
- google/genai/_transformers.py +3 -2
- google/genai/batches.py +148 -0
- google/genai/caches.py +10 -0
- google/genai/chats.py +12 -2
- google/genai/live.py +74 -42
- google/genai/models.py +98 -11
- google/genai/tunings.py +241 -4
- google/genai/types.py +379 -85
- {google_genai-0.0.1.dist-info → google_genai-0.1.0.dist-info}/METADATA +52 -44
- google_genai-0.1.0.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/RECORD +0 -24
- {google_genai-0.0.1.dist-info → google_genai-0.1.0.dist-info}/LICENSE +0 -0
- {google_genai-0.0.1.dist-info → google_genai-0.1.0.dist-info}/WHEEL +0 -0
- {google_genai-0.0.1.dist-info → google_genai-0.1.0.dist-info}/top_level.txt +0 -0
google/genai/__init__.py
CHANGED
google/genai/_api_client.py
CHANGED
@@ -51,7 +51,7 @@ class HttpOptions(TypedDict):
|
|
51
51
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
52
52
|
"""Appends the telemetry header to the headers dict."""
|
53
53
|
# TODO: Automate revisions to the SDK library version.
|
54
|
-
library_label = 'google-genai-sdk/0.1.0'
|
54
|
+
library_label = f'google-genai-sdk/0.1.0'
|
55
55
|
language_label = 'gl-python/' + sys.version.split()[0]
|
56
56
|
version_header_value = f'{library_label} {language_label}'
|
57
57
|
if (
|
@@ -295,47 +295,3 @@ def _get_required_fields(schema: types.Schema) -> list[str]:
|
|
295
295
|
if not field_schema.nullable and field_schema.default is None
|
296
296
|
]
|
297
297
|
|
298
|
-
|
299
|
-
def function_to_declaration(
|
300
|
-
client, func: Callable
|
301
|
-
) -> types.FunctionDeclaration:
|
302
|
-
"""Converts a function to a FunctionDeclaration."""
|
303
|
-
parameters_properties = {}
|
304
|
-
for name, param in inspect.signature(func).parameters.items():
|
305
|
-
if param.kind in (
|
306
|
-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
307
|
-
inspect.Parameter.KEYWORD_ONLY,
|
308
|
-
inspect.Parameter.POSITIONAL_ONLY,
|
309
|
-
):
|
310
|
-
schema = _parse_schema_from_parameter(client, param, func.__name__)
|
311
|
-
parameters_properties[name] = schema
|
312
|
-
declaration = types.FunctionDeclaration(
|
313
|
-
name=func.__name__,
|
314
|
-
description=func.__doc__,
|
315
|
-
)
|
316
|
-
if parameters_properties:
|
317
|
-
declaration.parameters = types.Schema(
|
318
|
-
type='OBJECT',
|
319
|
-
properties=parameters_properties,
|
320
|
-
)
|
321
|
-
if client.vertexai:
|
322
|
-
declaration.parameters.required = _get_required_fields(
|
323
|
-
declaration.parameters
|
324
|
-
)
|
325
|
-
if not client.vertexai:
|
326
|
-
return declaration
|
327
|
-
|
328
|
-
return_annotation = inspect.signature(func).return_annotation
|
329
|
-
if return_annotation is inspect._empty:
|
330
|
-
return declaration
|
331
|
-
|
332
|
-
declaration.response = _parse_schema_from_parameter(
|
333
|
-
client,
|
334
|
-
inspect.Parameter(
|
335
|
-
'return_value',
|
336
|
-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
337
|
-
annotation=return_annotation,
|
338
|
-
),
|
339
|
-
func.__name__,
|
340
|
-
)
|
341
|
-
return declaration
|
google/genai/_extra_utils.py
CHANGED
@@ -293,3 +293,18 @@ def get_max_remote_calls_afc(
|
|
293
293
|
):
|
294
294
|
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
295
295
|
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
296
|
+
|
297
|
+
def should_append_afc_history(
|
298
|
+
config: Optional[types.GenerateContentConfigOrDict] = None,
|
299
|
+
) -> bool:
|
300
|
+
config_model = (
|
301
|
+
types.GenerateContentConfig(**config)
|
302
|
+
if config and isinstance(config, dict)
|
303
|
+
else config
|
304
|
+
)
|
305
|
+
if (
|
306
|
+
not config_model
|
307
|
+
or not config_model.automatic_function_calling
|
308
|
+
):
|
309
|
+
return True
|
310
|
+
return not config_model.automatic_function_calling.ignore_call_history
|
google/genai/_transformers.py
CHANGED
@@ -27,7 +27,6 @@ import PIL.Image
|
|
27
27
|
|
28
28
|
from . import _api_client
|
29
29
|
from . import types
|
30
|
-
from ._automatic_function_calling_util import function_to_declaration
|
31
30
|
|
32
31
|
|
33
32
|
def _resource_name(
|
@@ -307,7 +306,9 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
|
|
307
306
|
return None
|
308
307
|
if inspect.isfunction(origin):
|
309
308
|
return types.Tool(
|
310
|
-
function_declarations=[
|
309
|
+
function_declarations=[
|
310
|
+
types.FunctionDeclaration.from_function(client, origin)
|
311
|
+
]
|
311
312
|
)
|
312
313
|
else:
|
313
314
|
return origin
|
google/genai/batches.py
CHANGED
@@ -661,6 +661,24 @@ class Batches(_common.BaseModule):
|
|
661
661
|
return return_value
|
662
662
|
|
663
663
|
def get(self, *, name: str) -> types.BatchJob:
|
664
|
+
"""Gets a batch job.
|
665
|
+
|
666
|
+
Args:
|
667
|
+
name (str): A fully-qualified BatchJob resource name or ID.
|
668
|
+
Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
|
669
|
+
when project and location are initialized in the client.
|
670
|
+
|
671
|
+
Returns:
|
672
|
+
A BatchJob object that contains details about the batch job.
|
673
|
+
|
674
|
+
Usage:
|
675
|
+
|
676
|
+
.. code-block:: python
|
677
|
+
|
678
|
+
batch_job = client.batches.get(name='123456789')
|
679
|
+
print(f"Batch job: {batch_job.name}, state {batch_job.state}")
|
680
|
+
"""
|
681
|
+
|
664
682
|
parameter_model = types._GetBatchJobParameters(
|
665
683
|
name=name,
|
666
684
|
)
|
@@ -767,6 +785,23 @@ class Batches(_common.BaseModule):
|
|
767
785
|
return return_value
|
768
786
|
|
769
787
|
def delete(self, *, name: str) -> types.DeleteResourceJob:
|
788
|
+
"""Deletes a batch job.
|
789
|
+
|
790
|
+
Args:
|
791
|
+
name (str): A fully-qualified BatchJob resource name or ID.
|
792
|
+
Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
|
793
|
+
when project and location are initialized in the client.
|
794
|
+
|
795
|
+
Returns:
|
796
|
+
A DeleteResourceJob object that shows the status of the deletion.
|
797
|
+
|
798
|
+
Usage:
|
799
|
+
|
800
|
+
.. code-block:: python
|
801
|
+
|
802
|
+
client.batches.delete(name='123456789')
|
803
|
+
"""
|
804
|
+
|
770
805
|
parameter_model = types._DeleteBatchJobParameters(
|
771
806
|
name=name,
|
772
807
|
)
|
@@ -814,12 +849,51 @@ class Batches(_common.BaseModule):
|
|
814
849
|
src: str,
|
815
850
|
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
816
851
|
) -> types.BatchJob:
|
852
|
+
"""Creates a batch job.
|
853
|
+
|
854
|
+
Args:
|
855
|
+
model (str): The model to use for the batch job.
|
856
|
+
src (str): The source of the batch job. Currently supports GCS URI(-s) or
|
857
|
+
Bigquery URI. Example: "gs://path/to/input/data" or
|
858
|
+
"bq://projectId.bqDatasetId.bqTableId".
|
859
|
+
config (CreateBatchJobConfig): Optional configuration for the batch job.
|
860
|
+
|
861
|
+
Returns:
|
862
|
+
A BatchJob object that contains details about the batch job.
|
863
|
+
|
864
|
+
Usage:
|
865
|
+
|
866
|
+
.. code-block:: python
|
867
|
+
|
868
|
+
batch_job = client.batches.create(
|
869
|
+
model="gemini-1.5-flash",
|
870
|
+
src="gs://path/to/input/data",
|
871
|
+
)
|
872
|
+
print(batch_job.state)
|
873
|
+
"""
|
817
874
|
config = _extra_utils.format_destination(src, config)
|
818
875
|
return self._create(model=model, src=src, config=config)
|
819
876
|
|
820
877
|
def list(
|
821
878
|
self, *, config: Optional[types.ListBatchJobConfigOrDict] = None
|
822
879
|
) -> Pager[types.BatchJob]:
|
880
|
+
"""Lists batch jobs.
|
881
|
+
|
882
|
+
Args:
|
883
|
+
config (ListBatchJobConfig): Optional configuration for the list request.
|
884
|
+
|
885
|
+
Returns:
|
886
|
+
A Pager object that contains one page of batch jobs. When iterating over
|
887
|
+
the pager, it automatically fetches the next page if there are more.
|
888
|
+
|
889
|
+
Usage:
|
890
|
+
|
891
|
+
.. code-block:: python
|
892
|
+
|
893
|
+
batch_jobs = client.batches.list(config={"page_size": 10})
|
894
|
+
for batch_job in batch_jobs:
|
895
|
+
print(f"Batch job: {batch_job.name}, state {batch_job.state}")
|
896
|
+
"""
|
823
897
|
return Pager(
|
824
898
|
'batch_jobs',
|
825
899
|
self._list,
|
@@ -874,6 +948,24 @@ class AsyncBatches(_common.BaseModule):
|
|
874
948
|
return return_value
|
875
949
|
|
876
950
|
async def get(self, *, name: str) -> types.BatchJob:
|
951
|
+
"""Gets a batch job.
|
952
|
+
|
953
|
+
Args:
|
954
|
+
name (str): A fully-qualified BatchJob resource name or ID.
|
955
|
+
Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
|
956
|
+
when project and location are initialized in the client.
|
957
|
+
|
958
|
+
Returns:
|
959
|
+
A BatchJob object that contains details about the batch job.
|
960
|
+
|
961
|
+
Usage:
|
962
|
+
|
963
|
+
.. code-block:: python
|
964
|
+
|
965
|
+
batch_job = client.batches.get(name='123456789')
|
966
|
+
print(f"Batch job: {batch_job.name}, state {batch_job.state}")
|
967
|
+
"""
|
968
|
+
|
877
969
|
parameter_model = types._GetBatchJobParameters(
|
878
970
|
name=name,
|
879
971
|
)
|
@@ -980,6 +1072,23 @@ class AsyncBatches(_common.BaseModule):
|
|
980
1072
|
return return_value
|
981
1073
|
|
982
1074
|
async def delete(self, *, name: str) -> types.DeleteResourceJob:
|
1075
|
+
"""Deletes a batch job.
|
1076
|
+
|
1077
|
+
Args:
|
1078
|
+
name (str): A fully-qualified BatchJob resource name or ID.
|
1079
|
+
Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
|
1080
|
+
when project and location are initialized in the client.
|
1081
|
+
|
1082
|
+
Returns:
|
1083
|
+
A DeleteResourceJob object that shows the status of the deletion.
|
1084
|
+
|
1085
|
+
Usage:
|
1086
|
+
|
1087
|
+
.. code-block:: python
|
1088
|
+
|
1089
|
+
client.batches.delete(name='123456789')
|
1090
|
+
"""
|
1091
|
+
|
983
1092
|
parameter_model = types._DeleteBatchJobParameters(
|
984
1093
|
name=name,
|
985
1094
|
)
|
@@ -1027,12 +1136,51 @@ class AsyncBatches(_common.BaseModule):
|
|
1027
1136
|
src: str,
|
1028
1137
|
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
1029
1138
|
) -> types.BatchJob:
|
1139
|
+
"""Creates a batch job asynchronously.
|
1140
|
+
|
1141
|
+
Args:
|
1142
|
+
model (str): The model to use for the batch job.
|
1143
|
+
src (str): The source of the batch job. Currently supports GCS URI(-s) or
|
1144
|
+
Bigquery URI. Example: "gs://path/to/input/data" or
|
1145
|
+
"bq://projectId.bqDatasetId.bqTableId".
|
1146
|
+
config (CreateBatchJobConfig): Optional configuration for the batch job.
|
1147
|
+
|
1148
|
+
Returns:
|
1149
|
+
A BatchJob object that contains details about the batch job.
|
1150
|
+
|
1151
|
+
Usage:
|
1152
|
+
|
1153
|
+
.. code-block:: python
|
1154
|
+
|
1155
|
+
batch_job = await client.aio.batches.create(
|
1156
|
+
model="gemini-1.5-flash",
|
1157
|
+
src="gs://path/to/input/data",
|
1158
|
+
)
|
1159
|
+
"""
|
1030
1160
|
config = _extra_utils.format_destination(src, config)
|
1031
1161
|
return await self._create(model=model, src=src, config=config)
|
1032
1162
|
|
1033
1163
|
async def list(
|
1034
1164
|
self, *, config: Optional[types.ListBatchJobConfigOrDict] = None
|
1035
1165
|
) -> AsyncPager[types.BatchJob]:
|
1166
|
+
"""Lists batch jobs asynchronously.
|
1167
|
+
|
1168
|
+
Args:
|
1169
|
+
config (ListBatchJobConfig): Optional configuration for the list request.
|
1170
|
+
|
1171
|
+
Returns:
|
1172
|
+
A Pager object that contains one page of batch jobs. When iterating over
|
1173
|
+
the pager, it automatically fetches the next page if there are more.
|
1174
|
+
|
1175
|
+
Usage:
|
1176
|
+
|
1177
|
+
.. code-block:: python
|
1178
|
+
|
1179
|
+
batch_jobs = await client.aio.batches.list(config={'page_size': 5})
|
1180
|
+
print(f"current page: {batch_jobs.page}")
|
1181
|
+
await batch_jobs_pager.next_page()
|
1182
|
+
print(f"next page: {batch_jobs_pager.page}")
|
1183
|
+
"""
|
1036
1184
|
return AsyncPager(
|
1037
1185
|
'batch_jobs',
|
1038
1186
|
self._list,
|
google/genai/caches.py
CHANGED
@@ -1249,6 +1249,7 @@ class Caches(_common.BaseModule):
|
|
1249
1249
|
Usage:
|
1250
1250
|
|
1251
1251
|
.. code-block:: python
|
1252
|
+
|
1252
1253
|
contents = ... // Initialize the content to cache.
|
1253
1254
|
response = await client.aio.caches.create(
|
1254
1255
|
model= ... // The publisher model id
|
@@ -1310,6 +1311,7 @@ class Caches(_common.BaseModule):
|
|
1310
1311
|
"""Gets cached content configurations.
|
1311
1312
|
|
1312
1313
|
.. code-block:: python
|
1314
|
+
|
1313
1315
|
await client.aio.caches.get(name= ... ) // The server-generated resource
|
1314
1316
|
name.
|
1315
1317
|
"""
|
@@ -1364,6 +1366,7 @@ class Caches(_common.BaseModule):
|
|
1364
1366
|
Usage:
|
1365
1367
|
|
1366
1368
|
.. code-block:: python
|
1369
|
+
|
1367
1370
|
await client.aio.caches.delete(name= ... ) // The server-generated
|
1368
1371
|
resource name.
|
1369
1372
|
"""
|
@@ -1420,6 +1423,7 @@ class Caches(_common.BaseModule):
|
|
1420
1423
|
"""Updates cached content configurations.
|
1421
1424
|
|
1422
1425
|
.. code-block:: python
|
1426
|
+
|
1423
1427
|
response = await client.aio.caches.update(
|
1424
1428
|
name= ... // The server-generated resource name.
|
1425
1429
|
config={
|
@@ -1473,6 +1477,7 @@ class Caches(_common.BaseModule):
|
|
1473
1477
|
"""Lists cached content configurations.
|
1474
1478
|
|
1475
1479
|
.. code-block:: python
|
1480
|
+
|
1476
1481
|
cached_contents = await client.aio.caches.list(config={'page_size': 2})
|
1477
1482
|
async for cached_content in cached_contents:
|
1478
1483
|
print(cached_content)
|
@@ -1548,6 +1553,7 @@ class AsyncCaches(_common.BaseModule):
|
|
1548
1553
|
Usage:
|
1549
1554
|
|
1550
1555
|
.. code-block:: python
|
1556
|
+
|
1551
1557
|
contents = ... // Initialize the content to cache.
|
1552
1558
|
response = await client.aio.caches.create(
|
1553
1559
|
model= ... // The publisher model id
|
@@ -1609,6 +1615,7 @@ class AsyncCaches(_common.BaseModule):
|
|
1609
1615
|
"""Gets cached content configurations.
|
1610
1616
|
|
1611
1617
|
.. code-block:: python
|
1618
|
+
|
1612
1619
|
await client.aio.caches.get(name= ... ) // The server-generated resource
|
1613
1620
|
name.
|
1614
1621
|
"""
|
@@ -1663,6 +1670,7 @@ class AsyncCaches(_common.BaseModule):
|
|
1663
1670
|
Usage:
|
1664
1671
|
|
1665
1672
|
.. code-block:: python
|
1673
|
+
|
1666
1674
|
await client.aio.caches.delete(name= ... ) // The server-generated
|
1667
1675
|
resource name.
|
1668
1676
|
"""
|
@@ -1719,6 +1727,7 @@ class AsyncCaches(_common.BaseModule):
|
|
1719
1727
|
"""Updates cached content configurations.
|
1720
1728
|
|
1721
1729
|
.. code-block:: python
|
1730
|
+
|
1722
1731
|
response = await client.aio.caches.update(
|
1723
1732
|
name= ... // The server-generated resource name.
|
1724
1733
|
config={
|
@@ -1772,6 +1781,7 @@ class AsyncCaches(_common.BaseModule):
|
|
1772
1781
|
"""Lists cached content configurations.
|
1773
1782
|
|
1774
1783
|
.. code-block:: python
|
1784
|
+
|
1775
1785
|
cached_contents = await client.aio.caches.list(config={'page_size': 2})
|
1776
1786
|
async for cached_content in cached_contents:
|
1777
1787
|
print(cached_content)
|
google/genai/chats.py
CHANGED
@@ -65,7 +65,12 @@ class Chat(_BaseChat):
|
|
65
65
|
config=self._config,
|
66
66
|
)
|
67
67
|
if response.candidates and response.candidates[0].content:
|
68
|
-
|
68
|
+
if response.automatic_function_calling_history:
|
69
|
+
self._curated_history.extend(
|
70
|
+
response.automatic_function_calling_history
|
71
|
+
)
|
72
|
+
else:
|
73
|
+
self._curated_history.append(input_content)
|
69
74
|
self._curated_history.append(response.candidates[0].content)
|
70
75
|
return response
|
71
76
|
|
@@ -138,7 +143,12 @@ class AsyncChat(_BaseChat):
|
|
138
143
|
config=self._config,
|
139
144
|
)
|
140
145
|
if response.candidates and response.candidates[0].content:
|
141
|
-
|
146
|
+
if response.automatic_function_calling_history:
|
147
|
+
self._curated_history.extend(
|
148
|
+
response.automatic_function_calling_history
|
149
|
+
)
|
150
|
+
else:
|
151
|
+
self._curated_history.append(input_content)
|
142
152
|
self._curated_history.append(response.candidates[0].content)
|
143
153
|
return response
|
144
154
|
|
google/genai/live.py
CHANGED
@@ -53,6 +53,12 @@ except ModuleNotFoundError:
|
|
53
53
|
from websockets.client import connect
|
54
54
|
|
55
55
|
|
56
|
+
_FUNCTION_RESPONSE_REQUIRES_ID = (
|
57
|
+
'FunctionResponse request must have an `id` field from the'
|
58
|
+
' response of a ToolCall.FunctionalCalls in Google AI.'
|
59
|
+
)
|
60
|
+
|
61
|
+
|
56
62
|
class AsyncSession:
|
57
63
|
"""AsyncSession."""
|
58
64
|
|
@@ -81,20 +87,23 @@ class AsyncSession:
|
|
81
87
|
"""Receive model responses from the server.
|
82
88
|
|
83
89
|
The method will yield the model responses from the server. The returned
|
84
|
-
responses will represent a complete model turn.
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
```
|
89
|
-
client = genai.Client(api_key=API_KEY)
|
90
|
-
|
91
|
-
async with client.aio.live.connect(model='...') as session:
|
92
|
-
await session.send(input='Hello world!', end_of_turn=True)
|
93
|
-
async for message in session.receive():
|
94
|
-
print(message)
|
95
|
-
```
|
90
|
+
responses will represent a complete model turn. When the returned message
|
91
|
+
is fuction call, user must call `send` with the function response to
|
92
|
+
continue the turn.
|
93
|
+
|
96
94
|
Yields:
|
97
|
-
|
95
|
+
The model responses from the server.
|
96
|
+
|
97
|
+
Example usage:
|
98
|
+
|
99
|
+
.. code-block:: python
|
100
|
+
|
101
|
+
client = genai.Client(api_key=API_KEY)
|
102
|
+
|
103
|
+
async with client.aio.live.connect(model='...') as session:
|
104
|
+
await session.send(input='Hello world!', end_of_turn=True)
|
105
|
+
async for message in session.receive():
|
106
|
+
print(message)
|
98
107
|
"""
|
99
108
|
# TODO(b/365983264) Handle intermittent issues for the user.
|
100
109
|
while result := await self._receive():
|
@@ -113,28 +122,27 @@ class AsyncSession:
|
|
113
122
|
input stream to the model and the other task will be used to receive the
|
114
123
|
responses from the model.
|
115
124
|
|
116
|
-
Example usage:
|
117
|
-
```
|
118
|
-
client = genai.Client(api_key=API_KEY)
|
119
|
-
config = {'response_modalities': ['AUDIO']}
|
120
|
-
|
121
|
-
async def audio_stream():
|
122
|
-
stream = read_audio()
|
123
|
-
for data in stream:
|
124
|
-
yield data
|
125
|
-
|
126
|
-
async with client.aio.live.connect(model='...') as session:
|
127
|
-
for audio in session.start_stream(stream = audio_stream(),
|
128
|
-
mime_type = 'audio/pcm'):
|
129
|
-
play_audio_chunk(audio.data)
|
130
|
-
```
|
131
|
-
|
132
125
|
Args:
|
133
|
-
|
134
|
-
|
126
|
+
stream: An iterator that yields the model response.
|
127
|
+
mime_type: The MIME type of the data in the stream.
|
135
128
|
|
136
129
|
Yields:
|
137
|
-
|
130
|
+
The audio bytes received from the model and server response messages.
|
131
|
+
|
132
|
+
Example usage:
|
133
|
+
|
134
|
+
.. code-block:: python
|
135
|
+
|
136
|
+
client = genai.Client(api_key=API_KEY)
|
137
|
+
config = {'response_modalities': ['AUDIO']}
|
138
|
+
async def audio_stream():
|
139
|
+
stream = read_audio()
|
140
|
+
for data in stream:
|
141
|
+
yield data
|
142
|
+
async with client.aio.live.connect(model='...') as session:
|
143
|
+
for audio in session.start_stream(stream = audio_stream(),
|
144
|
+
mime_type = 'audio/pcm'):
|
145
|
+
play_audio_chunk(audio.data)
|
138
146
|
"""
|
139
147
|
stop_event = asyncio.Event()
|
140
148
|
# Start the send loop. When stream is complete stop_event is set.
|
@@ -340,6 +348,8 @@ class AsyncSession:
|
|
340
348
|
input = [input]
|
341
349
|
elif isinstance(input, dict) and 'name' in input and 'response' in input:
|
342
350
|
# ToolResponse.FunctionResponse
|
351
|
+
if not (self._api_client.vertexai) and 'id' not in input:
|
352
|
+
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
343
353
|
input = [input]
|
344
354
|
|
345
355
|
if isinstance(input, Sequence) and any(
|
@@ -395,7 +405,29 @@ class AsyncSession:
|
|
395
405
|
client_message = {'client_content': input.model_dump(exclude_none=True)}
|
396
406
|
elif isinstance(input, types.LiveClientToolResponse):
|
397
407
|
# ToolResponse.FunctionResponse
|
408
|
+
if not (self._api_client.vertexai) and not (input.function_responses[0].id):
|
409
|
+
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
398
410
|
client_message = {'tool_response': input.model_dump(exclude_none=True)}
|
411
|
+
elif isinstance(input, types.FunctionResponse):
|
412
|
+
if not (self._api_client.vertexai) and not (input.id):
|
413
|
+
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
414
|
+
client_message = {
|
415
|
+
'tool_response': {
|
416
|
+
'function_responses': [input.model_dump(exclude_none=True)]
|
417
|
+
}
|
418
|
+
}
|
419
|
+
elif isinstance(input, Sequence) and isinstance(
|
420
|
+
input[0], types.FunctionResponse
|
421
|
+
):
|
422
|
+
if not (self._api_client.vertexai) and not (input[0].id):
|
423
|
+
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
424
|
+
client_message = {
|
425
|
+
'tool_response': {
|
426
|
+
'function_responses': [
|
427
|
+
c.model_dump(exclude_none=True) for c in input
|
428
|
+
]
|
429
|
+
}
|
430
|
+
}
|
399
431
|
else:
|
400
432
|
raise ValueError(
|
401
433
|
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
@@ -571,16 +603,16 @@ class AsyncLive(_common.BaseModule):
|
|
571
603
|
) -> AsyncSession:
|
572
604
|
"""Connect to the live server.
|
573
605
|
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
606
|
+
Usage:
|
607
|
+
|
608
|
+
.. code-block:: python
|
609
|
+
|
610
|
+
client = genai.Client(api_key=API_KEY)
|
611
|
+
config = {}
|
612
|
+
async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
|
613
|
+
await session.send(input='Hello world!', end_of_turn=True)
|
614
|
+
async for message in session:
|
615
|
+
print(message)
|
584
616
|
"""
|
585
617
|
base_url = self.api_client._websocket_base_url()
|
586
618
|
if self.api_client.api_key:
|