google-genai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/chats.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ from typing import Optional
17
+ from typing import Union
18
+
19
+ from . import _transformers as t
20
+ from .models import AsyncModels, Models
21
+ from .types import Content, ContentDict, GenerateContentConfigOrDict, GenerateContentResponse, PartUnionDict
22
+
23
class _BaseChat:
    """Base chat session.

    Holds the state shared by the synchronous and asynchronous chat sessions:
    the models module used to issue requests, the model name, the optional
    generation config and the conversation history.
    """

    def __init__(
        self,
        *,
        modules: Union[Models, AsyncModels],
        model: str,
        config: Optional[GenerateContentConfigOrDict] = None,
        history: list[Content],
    ):
        """Stores the session parameters.

        Args:
            modules: The models module used to send generate-content requests.
            model: The model to use for the chat.
            config: The configuration to use for the generate content request,
                or None when no configuration was supplied.
            history: The conversation history to start from. The list is
                stored as-is (not copied).
        """
        self._modules = modules
        self._model = model
        self._config = config
        # Kept by reference: later appends are visible through the same list
        # object that was passed in.
        self._curated_history = history
38
+
39
+
40
class Chat(_BaseChat):
    """Chat session."""

    def send_message(
        self, message: Union[list[PartUnionDict], PartUnionDict]
    ) -> GenerateContentResponse:
        """Sends the conversation history with the additional message and returns the model's response.

        The message and the model's reply are added to the session history
        only when the response contains a first candidate with content.

        Args:
            message: The message to send to the model.

        Returns:
            The model's response.

        Usage:

        .. code-block:: python

            chat = client.chats.create(model='gemini-1.5-flash')
            response = chat.send_message('tell me a story')
        """
        input_content = t.t_content(self._modules.api_client, message)
        response = self._modules.generate_content(
            model=self._model,
            # A new list is built so the stored history is untouched if the
            # request fails.
            contents=self._curated_history + [input_content],
            config=self._config,
        )
        if response.candidates and response.candidates[0].content:
            self._curated_history.append(input_content)
            self._curated_history.append(response.candidates[0].content)
        return response

    def _send_message_stream(self, message: Union[list[ContentDict], str]):
        """Streams the model's response to ``message`` chunk by chunk.

        The history is only updated once the stream has been fully consumed
        and at least one chunk carried content, matching ``send_message``. A
        request that raises, or a stream that is abandoned midway, leaves the
        history unchanged.

        Args:
            message: The message to send to the model.

        Yields:
            Each chunk produced by ``generate_content_stream``.
        """
        input_contents = list(t.t_contents(self._modules.api_client, message))
        output_contents = []
        for chunk in self._modules.generate_content_stream(
            model=self._model,
            contents=self._curated_history + input_contents,
            config=self._config,
        ):
            if chunk.candidates and chunk.candidates[0].content:
                output_contents.append(chunk.candidates[0].content)
            yield chunk
        if output_contents:
            # Each content-bearing chunk is stored as its own history entry.
            self._curated_history.extend(input_contents)
            self._curated_history.extend(output_contents)
80
+
81
+
82
class Chats:
    """A util class to create chat sessions."""

    def __init__(self, modules: Models):
        """Stores the models module that created chat sessions will use.

        Args:
            modules: The synchronous models module used to send requests.
        """
        self._modules = modules

    def create(
        self,
        *,
        model: str,
        config: Optional[GenerateContentConfigOrDict] = None,
        history: Optional[list[Content]] = None,
    ) -> Chat:
        """Creates a new chat session.

        Args:
            model: The model to use for the chat.
            config: The configuration to use for the generate content request.
            history: The history to use for the chat. A non-empty list is
                passed to the session as-is; None or an empty list results in
                a fresh empty history.

        Returns:
            A new chat session.
        """
        return Chat(
            modules=self._modules,
            model=model,
            config=config,
            history=history if history else [],
        )
111
+
112
+
113
class AsyncChat(_BaseChat):
    """Async chat session."""

    async def send_message(
        self, message: Union[list[PartUnionDict], PartUnionDict]
    ) -> GenerateContentResponse:
        """Sends the conversation history with the additional message and returns the model's response.

        The message and the model's reply are added to the session history
        only when the response contains a first candidate with content.

        Args:
            message: The message to send to the model.

        Returns:
            The model's response.

        Usage:

        .. code-block:: python

            chat = client.aio.chats.create(model='gemini-1.5-flash')
            response = await chat.send_message('tell me a story')
        """
        input_content = t.t_content(self._modules.api_client, message)
        response = await self._modules.generate_content(
            model=self._model,
            # A new list is built so the stored history is untouched if the
            # request fails.
            contents=self._curated_history + [input_content],
            config=self._config,
        )
        if response.candidates and response.candidates[0].content:
            self._curated_history.append(input_content)
            self._curated_history.append(response.candidates[0].content)
        return response

    async def _send_message_stream(
        self, message: Union[list[ContentDict], str]
    ):
        """Streams the model's response to ``message`` chunk by chunk.

        The history is only updated once the stream has been fully consumed
        and at least one chunk carried content, matching ``send_message``. A
        request that raises, or a stream that is abandoned midway, leaves the
        history unchanged.

        Args:
            message: The message to send to the model.

        Yields:
            Each chunk produced by ``generate_content_stream``.
        """
        input_contents = list(t.t_contents(self._modules.api_client, message))
        output_contents = []
        async for chunk in self._modules.generate_content_stream(
            model=self._model,
            contents=self._curated_history + input_contents,
            config=self._config,
        ):
            if chunk.candidates and chunk.candidates[0].content:
                output_contents.append(chunk.candidates[0].content)
            yield chunk
        if output_contents:
            # Each content-bearing chunk is stored as its own history entry.
            self._curated_history.extend(input_contents)
            self._curated_history.extend(output_contents)
153
+
154
+
155
class AsyncChats:
    """A util class to create async chat sessions."""

    def __init__(self, modules: AsyncModels):
        """Stores the models module that created chat sessions will use.

        Args:
            modules: The asynchronous models module used to send requests.
        """
        self._modules = modules

    def create(
        self,
        *,
        model: str,
        config: Optional[GenerateContentConfigOrDict] = None,
        history: Optional[list[Content]] = None,
    ) -> AsyncChat:
        """Creates a new async chat session.

        Args:
            model: The model to use for the chat.
            config: The configuration to use for the generate content request.
            history: The history to use for the chat. A non-empty list is
                passed to the session as-is; None or an empty list results in
                a fresh empty history.

        Returns:
            A new async chat session.
        """
        return AsyncChat(
            modules=self._modules,
            model=model,
            config=config,
            history=history if history else [],
        )
google/genai/client.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ import os
17
+ from typing import Optional
18
+
19
+ import google.auth
20
+ import pydantic
21
+
22
+ from ._api_client import ApiClient, HttpOptions
23
+ from ._replay_api_client import ReplayApiClient
24
+ from .batches import AsyncBatches, Batches
25
+ from .caches import AsyncCaches, Caches
26
+ from .chats import AsyncChats, Chats
27
+ from .files import AsyncFiles, Files
28
+ from .live import AsyncLive
29
+ from .models import AsyncModels, Models
30
+ from .tunings import AsyncTunings, Tunings
31
+
32
+
33
class AsyncClient:
    """Client for making asynchronous (non-blocking) requests."""

    def __init__(self, api_client: ApiClient):
        """Builds the async API modules on top of a shared API client.

        Args:
            api_client: The API client every async module sends requests with.
        """
        self._api_client = api_client
        self._models = AsyncModels(api_client)
        self._tunings = AsyncTunings(api_client)
        self._caches = AsyncCaches(api_client)
        self._batches = AsyncBatches(api_client)
        self._files = AsyncFiles(api_client)
        self._live = AsyncLive(api_client)

    @property
    def chats(self) -> AsyncChats:
        """A new chat-session factory bound to this client's models module."""
        return AsyncChats(modules=self.models)

    @property
    def models(self) -> AsyncModels:
        """The async models module."""
        return self._models

    @property
    def tunings(self) -> AsyncTunings:
        """The async tunings module."""
        return self._tunings

    @property
    def caches(self) -> AsyncCaches:
        """The async caches module."""
        return self._caches

    @property
    def batches(self) -> AsyncBatches:
        """The async batches module."""
        return self._batches

    @property
    def files(self) -> AsyncFiles:
        """The async files module."""
        return self._files

    @property
    def live(self) -> AsyncLive:
        """The async live module."""
        return self._live
73
+
74
+
75
class DebugConfig(pydantic.BaseModel):
    """Configuration options that change client network behavior when testing.

    Each field that is not given explicitly takes its default from the
    corresponding environment variable, or None when that variable is unset.
    """

    # Defaults to the GOOGLE_GENAI_CLIENT_MODE environment variable.
    client_mode: Optional[str] = pydantic.Field(
        default_factory=lambda: os.environ.get('GOOGLE_GENAI_CLIENT_MODE')
    )

    # Defaults to the GOOGLE_GENAI_REPLAYS_DIRECTORY environment variable.
    replays_directory: Optional[str] = pydantic.Field(
        default_factory=lambda: os.environ.get('GOOGLE_GENAI_REPLAYS_DIRECTORY')
    )

    # Defaults to the GOOGLE_GENAI_REPLAY_ID environment variable.
    replay_id: Optional[str] = pydantic.Field(
        default_factory=lambda: os.environ.get('GOOGLE_GENAI_REPLAY_ID')
    )
89
+
90
+
91
class Client:
    """Client for making synchronous requests.

    Use this client to make a request to the Gemini Developer API or Vertex AI
    API and then wait for the response.

    Attributes:
        api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_
            to use for authentication. Applies to the Gemini Developer API
            only.
        vertexai: Indicates whether the client should use the Vertex AI API
            endpoints. Defaults to False (uses Gemini Developer API
            endpoints). Applies to the Vertex AI API only.
        credentials: The credentials to use for authentication when calling
            the Vertex AI APIs. Credentials can be obtained from environment
            variables and default credentials. For more information, see
            `Set up Application Default Credentials
            <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
            Applies to the Vertex AI API only.
        project: The `Google Cloud project ID
            <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_
            to use for quota. Can be obtained from environment variables (for
            example, ``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API
            only.
        location: The `location
            <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
            to send API requests to (for example, ``us-central1``). Can be
            obtained from environment variables. Applies to the Vertex AI API
            only.
        debug_config: Config settings that control network behavior of the
            client. This is typically used when running test code.
        http_options: Http options to use for the client. Response_payload
            can't be set when passing to the client constructor.

    Usage for the Gemini Developer API:

    .. code-block:: python

        from google import genai

        client = genai.Client(api_key='my-api-key')

    Usage for the Vertex AI API:

    .. code-block:: python

        from google import genai

        client = genai.Client(
            vertexai=True, project='my-project-id', location='us-central1'
        )
    """

    def __init__(
        self,
        *,
        vertexai: Optional[bool] = None,
        api_key: Optional[str] = None,
        credentials: Optional[google.auth.credentials.Credentials] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        debug_config: Optional[DebugConfig] = None,
        http_options: Optional[HttpOptions] = None,
    ):
        """Initializes the client.

        Args:
            vertexai (bool):
                Indicates whether the client should use the Vertex AI API
                endpoints. Defaults to False (uses Gemini Developer API
                endpoints). Applies to the Vertex AI API only.
            api_key (str):
                The `API key
                <https://ai.google.dev/gemini-api/docs/api-key>`_ to use for
                authentication. Applies to the Gemini Developer API only.
            credentials (google.auth.credentials.Credentials):
                The credentials to use for authentication when calling the
                Vertex AI APIs. Credentials can be obtained from environment
                variables and default credentials. For more information, see
                `Set up Application Default Credentials
                <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
                Applies to the Vertex AI API only.
            project (str):
                The `Google Cloud project ID
                <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_
                to use for quota. Can be obtained from environment variables
                (for example, ``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex
                AI API only.
            location (str):
                The `location
                <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
                to send API requests to (for example, ``us-central1``). Can be
                obtained from environment variables. Applies to the Vertex AI
                API only.
            debug_config (DebugConfig):
                Config settings that control network behavior of the client.
                This is typically used when running test code.
            http_options (HttpOptions):
                Http options to use for the client. ``response_payload`` must
                not be set.

        Raises:
            ValueError: If ``response_payload`` is present in ``http_options``.
        """
        # Reject response_payload up front: it behaves unpredictably when
        # multiple coroutines run through client.aio.
        if http_options:
            if 'response_payload' in http_options:
                raise ValueError(
                    'Setting response_payload in http_options is not supported.'
                )

        self._debug_config = debug_config if debug_config else DebugConfig()

        self._api_client = self._get_api_client(
            vertexai=vertexai,
            api_key=api_key,
            credentials=credentials,
            project=project,
            location=location,
            debug_config=self._debug_config,
            http_options=http_options,
        )

        # The async client and every sync module share the same API client.
        shared_client = self._api_client
        self._aio = AsyncClient(shared_client)
        self._models = Models(shared_client)
        self._tunings = Tunings(shared_client)
        self._caches = Caches(shared_client)
        self._batches = Batches(shared_client)
        self._files = Files(shared_client)

    @staticmethod
    def _get_api_client(
        vertexai: Optional[bool] = None,
        api_key: Optional[str] = None,
        credentials: Optional[google.auth.credentials.Credentials] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        debug_config: Optional[DebugConfig] = None,
        http_options: Optional[HttpOptions] = None,
    ):
        """Builds the API client used by this client.

        Returns a ``ReplayApiClient`` when ``debug_config.client_mode`` is
        ``'record'``, ``'replay'`` or ``'auto'``; otherwise a plain
        ``ApiClient``. All other arguments are forwarded unchanged.
        """
        # Arguments accepted by both client implementations.
        connection_kwargs = {
            'vertexai': vertexai,
            'api_key': api_key,
            'credentials': credentials,
            'project': project,
            'location': location,
            'http_options': http_options,
        }

        replay_modes = ('record', 'replay', 'auto')
        if debug_config and debug_config.client_mode in replay_modes:
            return ReplayApiClient(
                mode=debug_config.client_mode,
                replay_id=debug_config.replay_id,
                replays_directory=debug_config.replays_directory,
                **connection_kwargs,
            )

        return ApiClient(**connection_kwargs)

    @property
    def aio(self) -> AsyncClient:
        """The asynchronous counterpart of this client."""
        return self._aio

    @property
    def chats(self) -> Chats:
        """A new chat-session factory bound to this client's models module."""
        return Chats(modules=self.models)

    @property
    def models(self) -> Models:
        """The models module."""
        return self._models

    @property
    def tunings(self) -> Tunings:
        """The tunings module."""
        return self._tunings

    @property
    def caches(self) -> Caches:
        """The caches module."""
        return self._caches

    @property
    def batches(self) -> Batches:
        """The batches module."""
        return self._batches

    @property
    def files(self) -> Files:
        """The files module."""
        return self._files

    @property
    def vertexai(self) -> bool:
        """Returns whether the client is using the Vertex AI API."""
        return self._api_client.vertexai or False
google/genai/errors.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Error classes for the GenAI SDK."""
17
+
18
+ from typing import Any, Optional, TYPE_CHECKING, Union
19
+
20
+ import requests
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from .replay_api_client import ReplayResponse
25
+
26
+
27
class APIError(Exception):
    """General errors raised by the GenAI API.

    ``message``, ``status`` and ``details`` are taken from the ``error``
    object of the response body when one is present.
    """
    # HTTP status code of the failed response.
    code: int
    # The response object the error was built from.
    response: Union[requests.Response, 'ReplayResponse']

    message: str = ''
    status: str = 'UNKNOWN'
    details: Optional[Any] = None

    def __init__(
        self, code: int, response: Union[requests.Response, 'ReplayResponse']
    ):
        """Builds the error from a status code and the failed response.

        Args:
            code: The HTTP status code of the response.
            response: The failed response. For a ``requests.Response`` the
                body is parsed as JSON; any other object is read through its
                ``body_segments`` attribute.
        """
        self.code = code
        self.response = response

        if isinstance(response, requests.Response):
            try:
                payload = response.json()
            except requests.exceptions.JSONDecodeError:
                payload = None
            # The body may be valid JSON without being an error object (for
            # example a list, or {"error": "some string"}). Calling .get() on
            # such a value would raise AttributeError and hide the HTTP error.
            raw_error = (
                payload.get('error', {}) if isinstance(payload, dict) else None
            )
            if not isinstance(raw_error, dict):
                raw_error = {
                    'message': response.text,
                    'status': response.reason,
                }
        else:
            raw_error = response.body_segments[0].get('error', {})

        self.message = raw_error.get('message', '')
        self.status = raw_error.get('status', 'UNKNOWN')
        self.details = raw_error.get('details', None)

        super().__init__(f'{self.code} {self.status}. {self.message}')

    def _to_replay_record(self):
        """Returns a dictionary representation of the error for replay recording.

        details is not included since it may expose internal information in the
        replay file.
        """
        return {
            'error': {
                'code': self.code,
                'message': self.message,
                'status': self.status,
            }
        }

    @classmethod
    def raise_for_response(
        cls, response: Union[requests.Response, 'ReplayResponse']
    ):
        """Raises an error with detailed error message if the response has an error status.

        Returns without raising only when the status code is exactly 200.

        Args:
            response: The response to check.

        Raises:
            ClientError: If the status code is in the 4xx range.
            ServerError: If the status code is in the 5xx range.
            APIError: For any other status code that is not 200 (raised as
                the class this method was called on).
        """
        if response.status_code == 200:
            return

        status_code = response.status_code
        if 400 <= status_code < 500:
            raise ClientError(status_code, response)
        elif 500 <= status_code < 600:
            raise ServerError(status_code, response)
        else:
            raise cls(status_code, response)
85
+
86
+
87
class ClientError(APIError):
    """Client error raised by the GenAI API.

    ``APIError.raise_for_response`` raises this for 4xx status codes.
    """
90
+
91
+
92
class ServerError(APIError):
    """Server error raised by the GenAI API.

    ``APIError.raise_for_response`` raises this for 5xx status codes.
    """
95
+
96
+
97
class UnkownFunctionCallArgumentError(ValueError):
    """Raised when the function call argument cannot be converted to the parameter annotation."""


# The class name above is misspelled ("Unkown"). It is kept unchanged so
# existing callers keep working; this alias exposes the correct spelling.
UnknownFunctionCallArgumentError = UnkownFunctionCallArgumentError
101
+
102
+
103
class UnsupportedFunctionError(ValueError):
    """Error raised for a function that is not supported."""
105
+
106
+
107
class FunctionInvocationError(ValueError):
    """Error raised when a function cannot be invoked with the given arguments."""