google-genai 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/__init__.py CHANGED
@@ -17,4 +17,6 @@
17
17
 
18
18
  from .client import Client
19
19
 
20
+ __version__ = '0.1.0'
21
+
20
22
  __all__ = ['Client']
@@ -51,7 +51,7 @@ class HttpOptions(TypedDict):
51
51
  def _append_library_version_headers(headers: dict[str, str]) -> None:
52
52
  """Appends the telemetry header to the headers dict."""
53
53
  # TODO: Automate revisions to the SDK library version.
54
- library_label = 'google-genai-sdk/0.1.0'
54
+ library_label = f'google-genai-sdk/0.1.0'
55
55
  language_label = 'gl-python/' + sys.version.split()[0]
56
56
  version_header_value = f'{library_label} {language_label}'
57
57
  if (
@@ -295,47 +295,3 @@ def _get_required_fields(schema: types.Schema) -> list[str]:
295
295
  if not field_schema.nullable and field_schema.default is None
296
296
  ]
297
297
 
298
-
299
- def function_to_declaration(
300
- client, func: Callable
301
- ) -> types.FunctionDeclaration:
302
- """Converts a function to a FunctionDeclaration."""
303
- parameters_properties = {}
304
- for name, param in inspect.signature(func).parameters.items():
305
- if param.kind in (
306
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
307
- inspect.Parameter.KEYWORD_ONLY,
308
- inspect.Parameter.POSITIONAL_ONLY,
309
- ):
310
- schema = _parse_schema_from_parameter(client, param, func.__name__)
311
- parameters_properties[name] = schema
312
- declaration = types.FunctionDeclaration(
313
- name=func.__name__,
314
- description=func.__doc__,
315
- )
316
- if parameters_properties:
317
- declaration.parameters = types.Schema(
318
- type='OBJECT',
319
- properties=parameters_properties,
320
- )
321
- if client.vertexai:
322
- declaration.parameters.required = _get_required_fields(
323
- declaration.parameters
324
- )
325
- if not client.vertexai:
326
- return declaration
327
-
328
- return_annotation = inspect.signature(func).return_annotation
329
- if return_annotation is inspect._empty:
330
- return declaration
331
-
332
- declaration.response = _parse_schema_from_parameter(
333
- client,
334
- inspect.Parameter(
335
- 'return_value',
336
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
337
- annotation=return_annotation,
338
- ),
339
- func.__name__,
340
- )
341
- return declaration
@@ -293,3 +293,18 @@ def get_max_remote_calls_afc(
293
293
  ):
294
294
  return _DEFAULT_MAX_REMOTE_CALLS_AFC
295
295
  return int(config_model.automatic_function_calling.maximum_remote_calls)
296
+
297
+ def should_append_afc_history(
298
+ config: Optional[types.GenerateContentConfigOrDict] = None,
299
+ ) -> bool:
300
+ config_model = (
301
+ types.GenerateContentConfig(**config)
302
+ if config and isinstance(config, dict)
303
+ else config
304
+ )
305
+ if (
306
+ not config_model
307
+ or not config_model.automatic_function_calling
308
+ ):
309
+ return True
310
+ return not config_model.automatic_function_calling.ignore_call_history
@@ -27,7 +27,6 @@ import PIL.Image
27
27
 
28
28
  from . import _api_client
29
29
  from . import types
30
- from ._automatic_function_calling_util import function_to_declaration
31
30
 
32
31
 
33
32
  def _resource_name(
@@ -307,7 +306,9 @@ def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
307
306
  return None
308
307
  if inspect.isfunction(origin):
309
308
  return types.Tool(
310
- function_declarations=[function_to_declaration(client, origin)]
309
+ function_declarations=[
310
+ types.FunctionDeclaration.from_function(client, origin)
311
+ ]
311
312
  )
312
313
  else:
313
314
  return origin
google/genai/batches.py CHANGED
@@ -661,6 +661,24 @@ class Batches(_common.BaseModule):
661
661
  return return_value
662
662
 
663
663
  def get(self, *, name: str) -> types.BatchJob:
664
+ """Gets a batch job.
665
+
666
+ Args:
667
+ name (str): A fully-qualified BatchJob resource name or ID.
668
+ Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
669
+ when project and location are initialized in the client.
670
+
671
+ Returns:
672
+ A BatchJob object that contains details about the batch job.
673
+
674
+ Usage:
675
+
676
+ .. code-block:: python
677
+
678
+ batch_job = client.batches.get(name='123456789')
679
+ print(f"Batch job: {batch_job.name}, state {batch_job.state}")
680
+ """
681
+
664
682
  parameter_model = types._GetBatchJobParameters(
665
683
  name=name,
666
684
  )
@@ -767,6 +785,23 @@ class Batches(_common.BaseModule):
767
785
  return return_value
768
786
 
769
787
  def delete(self, *, name: str) -> types.DeleteResourceJob:
788
+ """Deletes a batch job.
789
+
790
+ Args:
791
+ name (str): A fully-qualified BatchJob resource name or ID.
792
+ Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
793
+ when project and location are initialized in the client.
794
+
795
+ Returns:
796
+ A DeleteResourceJob object that shows the status of the deletion.
797
+
798
+ Usage:
799
+
800
+ .. code-block:: python
801
+
802
+ client.batches.delete(name='123456789')
803
+ """
804
+
770
805
  parameter_model = types._DeleteBatchJobParameters(
771
806
  name=name,
772
807
  )
@@ -814,12 +849,51 @@ class Batches(_common.BaseModule):
814
849
  src: str,
815
850
  config: Optional[types.CreateBatchJobConfigOrDict] = None,
816
851
  ) -> types.BatchJob:
852
+ """Creates a batch job.
853
+
854
+ Args:
855
+ model (str): The model to use for the batch job.
856
+ src (str): The source of the batch job. Currently supports GCS URI(-s) or
857
+ Bigquery URI. Example: "gs://path/to/input/data" or
858
+ "bq://projectId.bqDatasetId.bqTableId".
859
+ config (CreateBatchJobConfig): Optional configuration for the batch job.
860
+
861
+ Returns:
862
+ A BatchJob object that contains details about the batch job.
863
+
864
+ Usage:
865
+
866
+ .. code-block:: python
867
+
868
+ batch_job = client.batches.create(
869
+ model="gemini-1.5-flash",
870
+ src="gs://path/to/input/data",
871
+ )
872
+ print(batch_job.state)
873
+ """
817
874
  config = _extra_utils.format_destination(src, config)
818
875
  return self._create(model=model, src=src, config=config)
819
876
 
820
877
  def list(
821
878
  self, *, config: Optional[types.ListBatchJobConfigOrDict] = None
822
879
  ) -> Pager[types.BatchJob]:
880
+ """Lists batch jobs.
881
+
882
+ Args:
883
+ config (ListBatchJobConfig): Optional configuration for the list request.
884
+
885
+ Returns:
886
+ A Pager object that contains one page of batch jobs. When iterating over
887
+ the pager, it automatically fetches the next page if there are more.
888
+
889
+ Usage:
890
+
891
+ .. code-block:: python
892
+
893
+ batch_jobs = client.batches.list(config={"page_size": 10})
894
+ for batch_job in batch_jobs:
895
+ print(f"Batch job: {batch_job.name}, state {batch_job.state}")
896
+ """
823
897
  return Pager(
824
898
  'batch_jobs',
825
899
  self._list,
@@ -874,6 +948,24 @@ class AsyncBatches(_common.BaseModule):
874
948
  return return_value
875
949
 
876
950
  async def get(self, *, name: str) -> types.BatchJob:
951
+ """Gets a batch job.
952
+
953
+ Args:
954
+ name (str): A fully-qualified BatchJob resource name or ID.
955
+ Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
956
+ when project and location are initialized in the client.
957
+
958
+ Returns:
959
+ A BatchJob object that contains details about the batch job.
960
+
961
+ Usage:
962
+
963
+ .. code-block:: python
964
+
965
+ batch_job = client.batches.get(name='123456789')
966
+ print(f"Batch job: {batch_job.name}, state {batch_job.state}")
967
+ """
968
+
877
969
  parameter_model = types._GetBatchJobParameters(
878
970
  name=name,
879
971
  )
@@ -980,6 +1072,23 @@ class AsyncBatches(_common.BaseModule):
980
1072
  return return_value
981
1073
 
982
1074
  async def delete(self, *, name: str) -> types.DeleteResourceJob:
1075
+ """Deletes a batch job.
1076
+
1077
+ Args:
1078
+ name (str): A fully-qualified BatchJob resource name or ID.
1079
+ Example: "projects/.../locations/.../batchPredictionJobs/456" or "456"
1080
+ when project and location are initialized in the client.
1081
+
1082
+ Returns:
1083
+ A DeleteResourceJob object that shows the status of the deletion.
1084
+
1085
+ Usage:
1086
+
1087
+ .. code-block:: python
1088
+
1089
+ client.batches.delete(name='123456789')
1090
+ """
1091
+
983
1092
  parameter_model = types._DeleteBatchJobParameters(
984
1093
  name=name,
985
1094
  )
@@ -1027,12 +1136,51 @@ class AsyncBatches(_common.BaseModule):
1027
1136
  src: str,
1028
1137
  config: Optional[types.CreateBatchJobConfigOrDict] = None,
1029
1138
  ) -> types.BatchJob:
1139
+ """Creates a batch job asynchronously.
1140
+
1141
+ Args:
1142
+ model (str): The model to use for the batch job.
1143
+ src (str): The source of the batch job. Currently supports GCS URI(-s) or
1144
+ Bigquery URI. Example: "gs://path/to/input/data" or
1145
+ "bq://projectId.bqDatasetId.bqTableId".
1146
+ config (CreateBatchJobConfig): Optional configuration for the batch job.
1147
+
1148
+ Returns:
1149
+ A BatchJob object that contains details about the batch job.
1150
+
1151
+ Usage:
1152
+
1153
+ .. code-block:: python
1154
+
1155
+ batch_job = await client.aio.batches.create(
1156
+ model="gemini-1.5-flash",
1157
+ src="gs://path/to/input/data",
1158
+ )
1159
+ """
1030
1160
  config = _extra_utils.format_destination(src, config)
1031
1161
  return await self._create(model=model, src=src, config=config)
1032
1162
 
1033
1163
  async def list(
1034
1164
  self, *, config: Optional[types.ListBatchJobConfigOrDict] = None
1035
1165
  ) -> AsyncPager[types.BatchJob]:
1166
+ """Lists batch jobs asynchronously.
1167
+
1168
+ Args:
1169
+ config (ListBatchJobConfig): Optional configuration for the list request.
1170
+
1171
+ Returns:
1172
+ A Pager object that contains one page of batch jobs. When iterating over
1173
+ the pager, it automatically fetches the next page if there are more.
1174
+
1175
+ Usage:
1176
+
1177
+ .. code-block:: python
1178
+
1179
+ batch_jobs = await client.aio.batches.list(config={'page_size': 5})
1180
+ print(f"current page: {batch_jobs.page}")
1181
+ await batch_jobs.next_page()
1182
+ print(f"next page: {batch_jobs.page}")
1183
+ """
1036
1184
  return AsyncPager(
1037
1185
  'batch_jobs',
1038
1186
  self._list,
google/genai/caches.py CHANGED
@@ -1249,6 +1249,7 @@ class Caches(_common.BaseModule):
1249
1249
  Usage:
1250
1250
 
1251
1251
  .. code-block:: python
1252
+
1252
1253
  contents = ... // Initialize the content to cache.
1253
1254
  response = await client.aio.caches.create(
1254
1255
  model= ... // The publisher model id
@@ -1310,6 +1311,7 @@ class Caches(_common.BaseModule):
1310
1311
  """Gets cached content configurations.
1311
1312
 
1312
1313
  .. code-block:: python
1314
+
1313
1315
  await client.aio.caches.get(name= ... ) // The server-generated resource
1314
1316
  name.
1315
1317
  """
@@ -1364,6 +1366,7 @@ class Caches(_common.BaseModule):
1364
1366
  Usage:
1365
1367
 
1366
1368
  .. code-block:: python
1369
+
1367
1370
  await client.aio.caches.delete(name= ... ) // The server-generated
1368
1371
  resource name.
1369
1372
  """
@@ -1420,6 +1423,7 @@ class Caches(_common.BaseModule):
1420
1423
  """Updates cached content configurations.
1421
1424
 
1422
1425
  .. code-block:: python
1426
+
1423
1427
  response = await client.aio.caches.update(
1424
1428
  name= ... // The server-generated resource name.
1425
1429
  config={
@@ -1473,6 +1477,7 @@ class Caches(_common.BaseModule):
1473
1477
  """Lists cached content configurations.
1474
1478
 
1475
1479
  .. code-block:: python
1480
+
1476
1481
  cached_contents = await client.aio.caches.list(config={'page_size': 2})
1477
1482
  async for cached_content in cached_contents:
1478
1483
  print(cached_content)
@@ -1548,6 +1553,7 @@ class AsyncCaches(_common.BaseModule):
1548
1553
  Usage:
1549
1554
 
1550
1555
  .. code-block:: python
1556
+
1551
1557
  contents = ... // Initialize the content to cache.
1552
1558
  response = await client.aio.caches.create(
1553
1559
  model= ... // The publisher model id
@@ -1609,6 +1615,7 @@ class AsyncCaches(_common.BaseModule):
1609
1615
  """Gets cached content configurations.
1610
1616
 
1611
1617
  .. code-block:: python
1618
+
1612
1619
  await client.aio.caches.get(name= ... ) // The server-generated resource
1613
1620
  name.
1614
1621
  """
@@ -1663,6 +1670,7 @@ class AsyncCaches(_common.BaseModule):
1663
1670
  Usage:
1664
1671
 
1665
1672
  .. code-block:: python
1673
+
1666
1674
  await client.aio.caches.delete(name= ... ) // The server-generated
1667
1675
  resource name.
1668
1676
  """
@@ -1719,6 +1727,7 @@ class AsyncCaches(_common.BaseModule):
1719
1727
  """Updates cached content configurations.
1720
1728
 
1721
1729
  .. code-block:: python
1730
+
1722
1731
  response = await client.aio.caches.update(
1723
1732
  name= ... // The server-generated resource name.
1724
1733
  config={
@@ -1772,6 +1781,7 @@ class AsyncCaches(_common.BaseModule):
1772
1781
  """Lists cached content configurations.
1773
1782
 
1774
1783
  .. code-block:: python
1784
+
1775
1785
  cached_contents = await client.aio.caches.list(config={'page_size': 2})
1776
1786
  async for cached_content in cached_contents:
1777
1787
  print(cached_content)
google/genai/chats.py CHANGED
@@ -65,7 +65,12 @@ class Chat(_BaseChat):
65
65
  config=self._config,
66
66
  )
67
67
  if response.candidates and response.candidates[0].content:
68
- self._curated_history.append(input_content)
68
+ if response.automatic_function_calling_history:
69
+ self._curated_history.extend(
70
+ response.automatic_function_calling_history
71
+ )
72
+ else:
73
+ self._curated_history.append(input_content)
69
74
  self._curated_history.append(response.candidates[0].content)
70
75
  return response
71
76
 
@@ -138,7 +143,12 @@ class AsyncChat(_BaseChat):
138
143
  config=self._config,
139
144
  )
140
145
  if response.candidates and response.candidates[0].content:
141
- self._curated_history.append(input_content)
146
+ if response.automatic_function_calling_history:
147
+ self._curated_history.extend(
148
+ response.automatic_function_calling_history
149
+ )
150
+ else:
151
+ self._curated_history.append(input_content)
142
152
  self._curated_history.append(response.candidates[0].content)
143
153
  return response
144
154
 
google/genai/live.py CHANGED
@@ -53,6 +53,12 @@ except ModuleNotFoundError:
53
53
  from websockets.client import connect
54
54
 
55
55
 
56
+ _FUNCTION_RESPONSE_REQUIRES_ID = (
57
+ 'FunctionResponse request must have an `id` field from the'
58
+ ' response of a ToolCall.FunctionalCalls in Google AI.'
59
+ )
60
+
61
+
56
62
  class AsyncSession:
57
63
  """AsyncSession."""
58
64
 
@@ -81,20 +87,23 @@ class AsyncSession:
81
87
  """Receive model responses from the server.
82
88
 
83
89
  The method will yield the model responses from the server. The returned
84
- responses will represent a complete model turn.
85
- when the returned message is fuction call, user must call `send` with the
86
- function response to continue the turn.
87
- Example usage:
88
- ```
89
- client = genai.Client(api_key=API_KEY)
90
-
91
- async with client.aio.live.connect(model='...') as session:
92
- await session.send(input='Hello world!', end_of_turn=True)
93
- async for message in session.receive():
94
- print(message)
95
- ```
90
+ responses will represent a complete model turn. When the returned message
91
+ is a function call, user must call `send` with the function response to
92
+ continue the turn.
93
+
96
94
  Yields:
97
- The model responses from the server.
95
+ The model responses from the server.
96
+
97
+ Example usage:
98
+
99
+ .. code-block:: python
100
+
101
+ client = genai.Client(api_key=API_KEY)
102
+
103
+ async with client.aio.live.connect(model='...') as session:
104
+ await session.send(input='Hello world!', end_of_turn=True)
105
+ async for message in session.receive():
106
+ print(message)
98
107
  """
99
108
  # TODO(b/365983264) Handle intermittent issues for the user.
100
109
  while result := await self._receive():
@@ -113,28 +122,27 @@ class AsyncSession:
113
122
  input stream to the model and the other task will be used to receive the
114
123
  responses from the model.
115
124
 
116
- Example usage:
117
- ```
118
- client = genai.Client(api_key=API_KEY)
119
- config = {'response_modalities': ['AUDIO']}
120
-
121
- async def audio_stream():
122
- stream = read_audio()
123
- for data in stream:
124
- yield data
125
-
126
- async with client.aio.live.connect(model='...') as session:
127
- for audio in session.start_stream(stream = audio_stream(),
128
- mime_type = 'audio/pcm'):
129
- play_audio_chunk(audio.data)
130
- ```
131
-
132
125
  Args:
133
- stream: An iterator that yields the model response.
134
- mime_type: The MIME type of the data in the stream.
126
+ stream: An iterator that yields the model response.
127
+ mime_type: The MIME type of the data in the stream.
135
128
 
136
129
  Yields:
137
- The audio bytes received from the model and server response messages.
130
+ The audio bytes received from the model and server response messages.
131
+
132
+ Example usage:
133
+
134
+ .. code-block:: python
135
+
136
+ client = genai.Client(api_key=API_KEY)
137
+ config = {'response_modalities': ['AUDIO']}
138
+ async def audio_stream():
139
+ stream = read_audio()
140
+ for data in stream:
141
+ yield data
142
+ async with client.aio.live.connect(model='...') as session:
143
+ for audio in session.start_stream(stream = audio_stream(),
144
+ mime_type = 'audio/pcm'):
145
+ play_audio_chunk(audio.data)
138
146
  """
139
147
  stop_event = asyncio.Event()
140
148
  # Start the send loop. When stream is complete stop_event is set.
@@ -340,6 +348,8 @@ class AsyncSession:
340
348
  input = [input]
341
349
  elif isinstance(input, dict) and 'name' in input and 'response' in input:
342
350
  # ToolResponse.FunctionResponse
351
+ if not (self._api_client.vertexai) and 'id' not in input:
352
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
343
353
  input = [input]
344
354
 
345
355
  if isinstance(input, Sequence) and any(
@@ -395,7 +405,29 @@ class AsyncSession:
395
405
  client_message = {'client_content': input.model_dump(exclude_none=True)}
396
406
  elif isinstance(input, types.LiveClientToolResponse):
397
407
  # ToolResponse.FunctionResponse
408
+ if not (self._api_client.vertexai) and not (input.function_responses[0].id):
409
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
398
410
  client_message = {'tool_response': input.model_dump(exclude_none=True)}
411
+ elif isinstance(input, types.FunctionResponse):
412
+ if not (self._api_client.vertexai) and not (input.id):
413
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
414
+ client_message = {
415
+ 'tool_response': {
416
+ 'function_responses': [input.model_dump(exclude_none=True)]
417
+ }
418
+ }
419
+ elif isinstance(input, Sequence) and isinstance(
420
+ input[0], types.FunctionResponse
421
+ ):
422
+ if not (self._api_client.vertexai) and not (input[0].id):
423
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
424
+ client_message = {
425
+ 'tool_response': {
426
+ 'function_responses': [
427
+ c.model_dump(exclude_none=True) for c in input
428
+ ]
429
+ }
430
+ }
399
431
  else:
400
432
  raise ValueError(
401
433
  f'Unsupported input type "{type(input)}" or input content "{input}"'
@@ -571,16 +603,16 @@ class AsyncLive(_common.BaseModule):
571
603
  ) -> AsyncSession:
572
604
  """Connect to the live server.
573
605
 
574
- Example usage:
575
- ```
576
- client = genai.Client(api_key=API_KEY)
577
- config = {}
578
-
579
- async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
580
- await session.send(input='Hello world!', end_of_turn=True)
581
- async for message in session:
582
- print(message)
583
- ```
606
+ Usage:
607
+
608
+ .. code-block:: python
609
+
610
+ client = genai.Client(api_key=API_KEY)
611
+ config = {}
612
+ async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
613
+ await session.send(input='Hello world!', end_of_turn=True)
614
+ async for message in session:
615
+ print(message)
584
616
  """
585
617
  base_url = self.api_client._websocket_base_url()
586
618
  if self.api_client.api_key: