google-genai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/caches.py CHANGED
@@ -17,6 +17,7 @@
 
 from typing import Optional, Union
 from urllib.parse import urlencode
+from . import _api_module
 from . import _common
 from . import _transformers as t
 from . import types
@@ -33,7 +34,7 @@ def _Part_to_mldev(
 ) -> dict:
   to_object = {}
   if getv(from_object, ['video_metadata']) is not None:
-    raise ValueError('video_metadata parameter is not supported in Google AI.')
+    raise ValueError('video_metadata parameter is not supported in Gemini API.')
 
   if getv(from_object, ['thought']) is not None:
     setv(to_object, ['thought'], getv(from_object, ['thought']))
@@ -165,51 +166,51 @@ def _Schema_to_mldev(
 ) -> dict:
   to_object = {}
   if getv(from_object, ['min_items']) is not None:
-    raise ValueError('min_items parameter is not supported in Google AI.')
+    raise ValueError('min_items parameter is not supported in Gemini API.')
 
   if getv(from_object, ['example']) is not None:
-    raise ValueError('example parameter is not supported in Google AI.')
+    raise ValueError('example parameter is not supported in Gemini API.')
 
   if getv(from_object, ['property_ordering']) is not None:
     raise ValueError(
-        'property_ordering parameter is not supported in Google AI.'
+        'property_ordering parameter is not supported in Gemini API.'
     )
 
   if getv(from_object, ['pattern']) is not None:
-    raise ValueError('pattern parameter is not supported in Google AI.')
+    raise ValueError('pattern parameter is not supported in Gemini API.')
 
   if getv(from_object, ['minimum']) is not None:
-    raise ValueError('minimum parameter is not supported in Google AI.')
+    raise ValueError('minimum parameter is not supported in Gemini API.')
 
   if getv(from_object, ['default']) is not None:
-    raise ValueError('default parameter is not supported in Google AI.')
+    raise ValueError('default parameter is not supported in Gemini API.')
 
   if getv(from_object, ['any_of']) is not None:
-    raise ValueError('any_of parameter is not supported in Google AI.')
+    raise ValueError('any_of parameter is not supported in Gemini API.')
 
   if getv(from_object, ['max_length']) is not None:
-    raise ValueError('max_length parameter is not supported in Google AI.')
+    raise ValueError('max_length parameter is not supported in Gemini API.')
 
   if getv(from_object, ['title']) is not None:
-    raise ValueError('title parameter is not supported in Google AI.')
+    raise ValueError('title parameter is not supported in Gemini API.')
 
   if getv(from_object, ['min_length']) is not None:
-    raise ValueError('min_length parameter is not supported in Google AI.')
+    raise ValueError('min_length parameter is not supported in Gemini API.')
 
   if getv(from_object, ['min_properties']) is not None:
-    raise ValueError('min_properties parameter is not supported in Google AI.')
+    raise ValueError('min_properties parameter is not supported in Gemini API.')
 
   if getv(from_object, ['max_items']) is not None:
-    raise ValueError('max_items parameter is not supported in Google AI.')
+    raise ValueError('max_items parameter is not supported in Gemini API.')
 
   if getv(from_object, ['maximum']) is not None:
-    raise ValueError('maximum parameter is not supported in Google AI.')
+    raise ValueError('maximum parameter is not supported in Gemini API.')
 
   if getv(from_object, ['nullable']) is not None:
-    raise ValueError('nullable parameter is not supported in Google AI.')
+    raise ValueError('nullable parameter is not supported in Gemini API.')
 
   if getv(from_object, ['max_properties']) is not None:
-    raise ValueError('max_properties parameter is not supported in Google AI.')
+    raise ValueError('max_properties parameter is not supported in Gemini API.')
 
   if getv(from_object, ['type']) is not None:
     setv(to_object, ['type'], getv(from_object, ['type']))
@@ -321,7 +322,7 @@ def _FunctionDeclaration_to_mldev(
 ) -> dict:
   to_object = {}
   if getv(from_object, ['response']) is not None:
-    raise ValueError('response parameter is not supported in Google AI.')
+    raise ValueError('response parameter is not supported in Gemini API.')
 
   if getv(from_object, ['description']) is not None:
     setv(to_object, ['description'], getv(from_object, ['description']))
@@ -477,7 +478,7 @@ def _Tool_to_mldev(
     )
 
   if getv(from_object, ['retrieval']) is not None:
-    raise ValueError('retrieval parameter is not supported in Google AI.')
+    raise ValueError('retrieval parameter is not supported in Gemini API.')
 
   if getv(from_object, ['google_search']) is not None:
     setv(
@@ -634,8 +635,6 @@ def _CreateCachedContentConfig_to_mldev(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
 
   if getv(from_object, ['ttl']) is not None:
     setv(parent_object, ['ttl'], getv(from_object, ['ttl']))
@@ -697,8 +696,6 @@ def _CreateCachedContentConfig_to_vertex(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
 
   if getv(from_object, ['ttl']) is not None:
     setv(parent_object, ['ttl'], getv(from_object, ['ttl']))
@@ -804,30 +801,6 @@ def _CreateCachedContentParameters_to_vertex(
   return to_object
 
 
-def _GetCachedContentConfig_to_mldev(
-    api_client: ApiClient,
-    from_object: Union[dict, object],
-    parent_object: dict = None,
-) -> dict:
-  to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
-
-  return to_object
-
-
-def _GetCachedContentConfig_to_vertex(
-    api_client: ApiClient,
-    from_object: Union[dict, object],
-    parent_object: dict = None,
-) -> dict:
-  to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
-
-  return to_object
-
-
 def _GetCachedContentParameters_to_mldev(
     api_client: ApiClient,
     from_object: Union[dict, object],
@@ -842,13 +815,7 @@ def _GetCachedContentParameters_to_mldev(
     )
 
   if getv(from_object, ['config']) is not None:
-    setv(
-        to_object,
-        ['config'],
-        _GetCachedContentConfig_to_mldev(
-            api_client, getv(from_object, ['config']), to_object
-        ),
-    )
+    setv(to_object, ['config'], getv(from_object, ['config']))
 
   return to_object
 
@@ -867,37 +834,7 @@ def _GetCachedContentParameters_to_vertex(
     )
 
   if getv(from_object, ['config']) is not None:
-    setv(
-        to_object,
-        ['config'],
-        _GetCachedContentConfig_to_vertex(
-            api_client, getv(from_object, ['config']), to_object
-        ),
-    )
-
-  return to_object
-
-
-def _DeleteCachedContentConfig_to_mldev(
-    api_client: ApiClient,
-    from_object: Union[dict, object],
-    parent_object: dict = None,
-) -> dict:
-  to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
-
-  return to_object
-
-
-def _DeleteCachedContentConfig_to_vertex(
-    api_client: ApiClient,
-    from_object: Union[dict, object],
-    parent_object: dict = None,
-) -> dict:
-  to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
+    setv(to_object, ['config'], getv(from_object, ['config']))
 
   return to_object
 
@@ -916,13 +853,7 @@ def _DeleteCachedContentParameters_to_mldev(
     )
 
   if getv(from_object, ['config']) is not None:
-    setv(
-        to_object,
-        ['config'],
-        _DeleteCachedContentConfig_to_mldev(
-            api_client, getv(from_object, ['config']), to_object
-        ),
-    )
+    setv(to_object, ['config'], getv(from_object, ['config']))
 
   return to_object
 
@@ -941,13 +872,7 @@ def _DeleteCachedContentParameters_to_vertex(
     )
 
   if getv(from_object, ['config']) is not None:
-    setv(
-        to_object,
-        ['config'],
-        _DeleteCachedContentConfig_to_vertex(
-            api_client, getv(from_object, ['config']), to_object
-        ),
-    )
+    setv(to_object, ['config'], getv(from_object, ['config']))
 
   return to_object
 
@@ -958,8 +883,6 @@ def _UpdateCachedContentConfig_to_mldev(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
 
   if getv(from_object, ['ttl']) is not None:
     setv(parent_object, ['ttl'], getv(from_object, ['ttl']))
@@ -976,8 +899,6 @@ def _UpdateCachedContentConfig_to_vertex(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
-  if getv(from_object, ['http_options']) is not None:
-    setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
 
   if getv(from_object, ['ttl']) is not None:
     setv(parent_object, ['ttl'], getv(from_object, ['ttl']))
@@ -1044,6 +965,7 @@ def _ListCachedContentsConfig_to_mldev(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
+
   if getv(from_object, ['page_size']) is not None:
     setv(
         parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -1065,6 +987,7 @@ def _ListCachedContentsConfig_to_vertex(
     parent_object: dict = None,
 ) -> dict:
   to_object = {}
+
   if getv(from_object, ['page_size']) is not None:
     setv(
         parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -1240,7 +1163,7 @@ def _ListCachedContentsResponse_from_vertex(
   return to_object
 
 
-class Caches(_common.BaseModule):
+class Caches(_api_module.BaseModule):
 
   def create(
       self,
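Note on the change above: Caches (and AsyncCaches below) now extend _api_module.BaseModule rather than _common.BaseModule, which is why the new from . import _api_module import was added at the top of the file. As a rough sketch only, assuming the new module simply carries the shared ApiClient that these classes use as self._api_client throughout this file (the shipped google/genai/_api_module.py may differ):

    # Hypothetical sketch of _api_module.BaseModule; not the shipped source.
    from . import _api_client

    class BaseModule:
      """Stores the ApiClient shared by API modules such as Caches."""

      def __init__(self, api_client: '_api_client.ApiClient'):
        # Every module method issues its requests through this client.
        self._api_client = api_client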
@@ -1288,8 +1211,14 @@ class Caches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1305,7 +1234,7 @@ class Caches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
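Note on the two hunks above (the same pair of changes repeats in every create/get/delete/update/list method of Caches and AsyncCaches below): per-request HTTP options are no longer popped from the transformed request dict under the camelCase 'httpOptions' key; they are now read from the caller's original config argument, whether it is a plain dict or a typed config object exposing an http_options attribute, and the response constructors are called with explicit keywords (response=..., kwargs=...). A small self-contained sketch of the new http_options resolution; the helper name and example values are illustrative only, not part of the library:

    # Hypothetical helper mirroring the resolution logic added in 0.7.0.
    def resolve_http_options(config):
      if isinstance(config, dict):
        return config.get('http_options', None)
      if hasattr(config, 'http_options'):
        return config.http_options
      return None

    # Dict configs use the snake_case key; absent config yields None.
    assert resolve_http_options({'ttl': '3600s', 'http_options': {'timeout': 30}}) == {'timeout': 30}
    assert resolve_http_options(None) is None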
@@ -1343,8 +1272,14 @@ class Caches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1360,7 +1295,7 @@ class Caches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1400,8 +1335,14 @@ class Caches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1419,7 +1360,7 @@ class Caches(_common.BaseModule):
     )
 
     return_value = types.DeleteCachedContentResponse._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1461,8 +1402,14 @@ class Caches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1478,7 +1425,7 @@ class Caches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1513,8 +1460,14 @@ class Caches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1532,7 +1485,7 @@ class Caches(_common.BaseModule):
     )
 
     return_value = types.ListCachedContentsResponse._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1548,7 +1501,7 @@ class Caches(_common.BaseModule):
     )
 
 
-class AsyncCaches(_common.BaseModule):
+class AsyncCaches(_api_module.BaseModule):
 
   async def create(
       self,
@@ -1596,8 +1549,14 @@ class AsyncCaches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1613,7 +1572,7 @@ class AsyncCaches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1651,8 +1610,14 @@ class AsyncCaches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1668,7 +1633,7 @@ class AsyncCaches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1708,8 +1673,14 @@ class AsyncCaches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1727,7 +1698,7 @@ class AsyncCaches(_common.BaseModule):
     )
 
     return_value = types.DeleteCachedContentResponse._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1769,8 +1740,14 @@ class AsyncCaches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1786,7 +1763,7 @@ class AsyncCaches(_common.BaseModule):
     response_dict = _CachedContent_from_mldev(self._api_client, response_dict)
 
     return_value = types.CachedContent._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
    )
     self._api_client._verify_response(return_value)
     return return_value
@@ -1821,8 +1798,14 @@ class AsyncCaches(_common.BaseModule):
     if query_params:
       path = f'{path}?{urlencode(query_params)}'
     # TODO: remove the hack that pops config.
-    config = request_dict.pop('config', None)
-    http_options = config.pop('httpOptions', None) if config else None
+    request_dict.pop('config', None)
+
+    http_options = None
+    if isinstance(config, dict):
+      http_options = config.get('http_options', None)
+    elif hasattr(config, 'http_options'):
+      http_options = config.http_options
+
     request_dict = _common.convert_to_dict(request_dict)
     request_dict = _common.encode_unserializable_types(request_dict)
 
@@ -1840,7 +1823,7 @@ class AsyncCaches(_common.BaseModule):
     )
 
     return_value = types.ListCachedContentsResponse._from_response(
-        response_dict, parameter_model
+        response=response_dict, kwargs=parameter_model
     )
     self._api_client._verify_response(return_value)
     return return_value
google/genai/chats.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 
-from typing import Optional
+from typing import AsyncIterator, Awaitable, Optional
 from typing import Union
 
 from . import _transformers as t
@@ -200,7 +200,7 @@ class AsyncChat(_BaseChat):
 
   async def send_message_stream(
      self, message: Union[list[PartUnionDict], PartUnionDict]
-  ):
+  ) -> Awaitable[AsyncIterator[GenerateContentResponse]]:
    """Sends the conversation history with the additional message and yields the model's response in chunks.
 
    Args:
@@ -213,26 +213,30 @@ class AsyncChat(_BaseChat):
 
     .. code-block:: python
       chat = client.aio.chats.create(model='gemini-1.5-flash')
-      async for chunk in chat.send_message_stream('tell me a story'):
+      async for chunk in await chat.send_message_stream('tell me a story'):
        print(chunk.text)
     """
 
    input_content = t.t_content(self._modules._api_client, message)
-    output_contents = []
-    finish_reason = None
-    async for chunk in self._modules.generate_content_stream(
-        model=self._model,
-        contents=self._curated_history + [input_content],
-        config=self._config,
-    ):
-      if _validate_response(chunk):
-        output_contents.append(chunk.candidates[0].content)
-      if chunk.candidates and chunk.candidates[0].finish_reason:
-        finish_reason = chunk.candidates[0].finish_reason
-      yield chunk
-    if output_contents and finish_reason:
-      self._curated_history.append(input_content)
-      self._curated_history.extend(output_contents)
+
+    async def async_generator():
+      output_contents = []
+      finish_reason = None
+      async for chunk in await self._modules.generate_content_stream(
+          model=self._model,
+          contents=self._curated_history + [input_content],
+          config=self._config,
+      ):
+        if _validate_response(chunk):
+          output_contents.append(chunk.candidates[0].content)
+        if chunk.candidates and chunk.candidates[0].finish_reason:
+          finish_reason = chunk.candidates[0].finish_reason
+        yield chunk
+
+      if output_contents and finish_reason:
+        self._curated_history.append(input_content)
+        self._curated_history.extend(output_contents)
+    return async_generator()
 
 
 class AsyncChats:
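Note on the hunks above: AsyncChat.send_message_stream is now an async method that returns an async generator (annotated Awaitable[AsyncIterator[GenerateContentResponse]]), so the call must be awaited before iterating, as the updated docstring shows. A short caller-side usage sketch under that assumption; the client construction, model id, and prompt are placeholders rather than anything prescribed by this diff:

    # Hypothetical caller-side usage of the 0.7.0 async streaming chat.
    import asyncio
    from google import genai

    async def main():
      client = genai.Client()  # assumes credentials are configured in the environment
      chat = client.aio.chats.create(model='gemini-1.5-flash')
      # Await the call first, then iterate over the resolved async generator.
      async for chunk in await chat.send_message_stream('tell me a story'):
        print(chunk.text)

    asyncio.run(main())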