google-genai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,478 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Replay API client."""
17
+
18
+ import copy
19
+ import inspect
20
+ import json
21
+ import os
22
+ import re
23
+ import datetime
24
+ from typing import Any, Literal, Optional, Union
25
+
26
+ import google.auth
27
+ from pydantic import BaseModel
28
+ from requests.exceptions import HTTPError
29
+
30
+ from . import errors
31
+ from ._api_client import ApiClient
32
+ from ._api_client import HttpOptions
33
+ from ._api_client import HttpRequest
34
+ from ._api_client import HttpResponse
35
+ from ._api_client import RequestJsonEncoder
36
+
37
+ def _redact_version_numbers(version_string: str) -> str:
38
+ """Redacts version numbers in the form x.y.z from a string."""
39
+ return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
40
+
41
+
42
+ def _redact_language_label(language_label: str) -> str:
43
+ """Removed because replay requests are used for all languages."""
44
+ return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)
45
+
46
+
47
+ def _redact_request_headers(headers):
48
+ """Redacts headers that should not be recorded."""
49
+ redacted_headers = {}
50
+ for header_name, header_value in headers.items():
51
+ if header_name.lower() == 'x-goog-api-key':
52
+ redacted_headers[header_name] = '{REDACTED}'
53
+ elif header_name.lower() == 'user-agent':
54
+ redacted_headers[header_name] = _redact_language_label(
55
+ _redact_version_numbers(header_value)
56
+ )
57
+ elif header_name.lower() == 'x-goog-api-client':
58
+ redacted_headers[header_name] = _redact_language_label(
59
+ _redact_version_numbers(header_value)
60
+ )
61
+ else:
62
+ redacted_headers[header_name] = header_value
63
+ return redacted_headers
64
+
65
+
66
+ def _redact_request_url(url: str) -> str:
67
+ # Redact all the url parts before the resource name, so the test can work
68
+ # against any project, location, version, or whether it's EasyGCP.
69
+ result = re.sub(
70
+ r'.*/projects/[^/]+/locations/[^/]+/',
71
+ '{VERTEX_URL_PREFIX}/',
72
+ url,
73
+ )
74
+ result = re.sub(
75
+ r'https://generativelanguage.googleapis.com/[^/]+',
76
+ '{MLDEV_URL_PREFIX}',
77
+ result,
78
+ )
79
+ return result
80
+
81
+
82
+ def _redact_project_location_path(path: str) -> str:
83
+ # Redact a field in the request that is known to vary based on project and
84
+ # location.
85
+ if 'projects/' in path and 'locations/' in path:
86
+ result = re.sub(
87
+ r'projects/[^/]+/locations/[^/]+/',
88
+ '{PROJECT_AND_LOCATION_PATH}/',
89
+ path,
90
+ )
91
+ return result
92
+ else:
93
+ return path
94
+
95
+
96
+ def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
97
+ for key, value in body.items():
98
+ if isinstance(value, str):
99
+ body[key] = _redact_project_location_path(value)
100
+
101
+
102
def redact_http_request(http_request: HttpRequest):
  """Redacts, in place, the parts of ``http_request`` that must not be recorded.

  Headers are masked or normalized, and the URL prefix is replaced with a
  placeholder. Project/location paths in the top-level body fields are
  rewritten.
  """
  clean_headers = _redact_request_headers(http_request.headers)
  http_request.headers = clean_headers
  clean_url = _redact_request_url(http_request.url)
  http_request.url = clean_url
  # The body is redacted in place; its return value is not needed.
  _redact_request_body(http_request.data)
106
+
107
+
108
def process_bytes_fields(data: dict[str, object]):
  """Converts bytes values to strings, recursively, modifying ``data`` in place.

  Every bytes value is UTF-8 decoded, at any depth of nested dicts and lists.
  Nested dicts are updated in place. Lists are replaced with new lists whose
  elements have been converted.

  Args:
    data: The dict to process. Any non-dict value is returned unchanged.

  Returns:
    The same ``data`` object after conversion.
  """
  if not isinstance(data, dict):
    return data
  for key, value in data.items():
    data[key] = _decode_bytes_recursively(value)
  return data


def _decode_bytes_recursively(value):
  """Returns ``value`` with every bytes object (at any depth) decoded to str."""
  if isinstance(value, bytes):
    return value.decode()
  if isinstance(value, dict):
    return process_bytes_fields(value)
  if isinstance(value, list):
    # Handles mixed lists and lists of lists. Previously, bytes were decoded
    # only when every list element was bytes.
    return [_decode_bytes_recursively(v) for v in value]
  return value
128
+
129
+
130
+ def _current_file_path_and_line():
131
+ """Prints the current file path and line number."""
132
+ frame = inspect.currentframe().f_back.f_back
133
+ filepath = inspect.getfile(frame)
134
+ lineno = frame.f_lineno
135
+ return f'File: {filepath}, Line: {lineno}'
136
+
137
+
138
def _debug_print(message: str):
  """Prints ``message`` tagged with the current pytest test and call site.

  The test name is read from ``PYTEST_CURRENT_TEST``. The reported call site
  is the code that called ``_debug_print``.
  """
  test_name = os.environ.get('PYTEST_CURRENT_TEST')
  # Must be called directly from this function: the helper looks two frames
  # up to find our caller.
  call_site = _current_file_path_and_line()
  print('DEBUG (test', test_name, ')', call_site, ':\n ', message)
147
+
148
+
149
class ReplayRequest(BaseModel):
  """Represents a single request in a replay.

  Requests are redacted before being stored: the API key is masked, and the
  URL prefix and project/location paths are replaced with placeholders.
  Requests from different environments can then be compared.
  """

  # NOTE: field order defines the key order in the serialized replay file.
  # HTTP method of the request.
  method: str
  # Request URL with its environment-specific prefix replaced.
  url: str
  # Request headers after redaction.
  headers: dict[str, str]
  # Request body; the recorder stores the whole JSON body as one segment.
  body_segments: list[dict[str, object]]
156
+
157
+
158
class ReplayResponse(BaseModel):
  """Represents a single response in a replay."""

  # HTTP status code of the recorded response.
  status_code: int = 200
  # Response headers, minus the non-deterministic ones removed below.
  headers: dict[str, str]
  # Response body as a list of parsed JSON segments.
  body_segments: list[dict[str, object]]
  # SDK response objects (model_dump output) recorded so replays can verify
  # that the SDK parses the response the same way.
  sdk_response_segments: list[dict[str, object]]

  def model_post_init(self, __context: Any) -> None:
    # Remove headers that are not deterministic so the replay files don't
    # change every time they are recorded. HTTP header names are
    # case-insensitive (e.g. 'Date' vs 'date'), so match them regardless of
    # the casing the server used.
    nondeterministic_headers = {'date', 'server-timing'}
    names_to_remove = [
        name
        for name in self.headers
        if name.lower() in nondeterministic_headers
    ]
    for name in names_to_remove:
      del self.headers[name]
171
+
172
+
173
class ReplayInteraction(BaseModel):
  """Represents a single interaction, request and response in a replay."""

  # The (redacted) request that was sent.
  request: ReplayRequest
  # The response that was received for that request.
  response: ReplayResponse
178
+
179
+
180
class ReplayFile(BaseModel):
  """Represents a recorded session, stored as one JSON replay file."""

  # Identifier of the form 'module/function/[vertex|mldev]' (at least three
  # '/'-separated parts); it also determines the replay file path.
  replay_id: str
  # Interactions in the order the requests were made; replay consumes them
  # sequentially.
  interactions: list[ReplayInteraction]
185
+
186
+
187
class ReplayApiClient(ApiClient):
  """For integration testing, sends recorded responses or records responses.

  Supported modes:
    * ``'replay'``: serve responses from the replay file, which must exist.
    * ``'record'``: call the API and record every interaction; ``close()``
      writes them to the replay file.
    * ``'auto'``: replay when the replay file exists. Otherwise call the API
      and record.
    * ``'api'``: call the API without reading or writing replay files.
  """

  def __init__(
      self,
      mode: Literal['record', 'replay', 'auto', 'api'],
      replay_id: str,
      replays_directory: Optional[str] = None,
      vertexai: bool = False,
      api_key: Optional[str] = None,
      credentials: Optional[google.auth.credentials.Credentials] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
      http_options: Optional[HttpOptions] = None,
  ):
    super().__init__(
        vertexai=vertexai,
        api_key=api_key,
        credentials=credentials,
        project=project,
        location=location,
        http_options=http_options,
    )
    self.replays_directory = replays_directory
    if not self.replays_directory:
      # None makes replay file paths relative to the working directory.
      self.replays_directory = os.environ.get(
          'GOOGLE_GENAI_REPLAYS_DIRECTORY', None
      )
    # The replay session is loaded lazily, on the first request or via
    # initialize_replay_session().
    self.replay_session = None
    self._mode = mode
    self._replay_id = replay_id

  def initialize_replay_session(self, replay_id: str):
    """Switches to ``replay_id`` and loads or starts its replay session."""
    self._replay_id = replay_id
    self._initialize_replay_session()

  def _get_replay_file_path(self):
    """Returns the JSON replay file path for the current replay id."""
    return self._generate_file_path_from_replay_id(
        self.replays_directory, self._replay_id
    )

  def _should_call_api(self):
    """Returns whether requests go to the real API instead of the replay."""
    return self._mode in ['record', 'api'] or (
        self._mode == 'auto'
        and not os.path.isfile(self._get_replay_file_path())
    )

  def _should_update_replay(self):
    """Returns whether interactions are recorded into the replay session."""
    return self._should_call_api() and self._mode != 'api'

  def _initialize_replay_session_if_not_loaded(self):
    if not self.replay_session:
      self._initialize_replay_session()

  def _initialize_replay_session(self):
    """Loads the replay file, or starts an empty session when recording.

    Raises:
      ValueError: In ``'replay'`` mode, if the replay file does not exist.
    """
    _debug_print('Test is using replay id: ' + self._replay_id)
    self._replay_index = 0
    self._sdk_response_index = 0
    replay_file_path = self._get_replay_file_path()
    # This should not be triggered from the constructor.
    replay_file_exists = os.path.isfile(replay_file_path)
    if self._mode == 'replay' and not replay_file_exists:
      raise ValueError(
          'Replay files do not exist for replay id: ' + self._replay_id
      )

    if self._mode in ['replay', 'auto'] and replay_file_exists:
      with open(replay_file_path, 'r', encoding='utf-8') as f:
        self.replay_session = ReplayFile.model_validate(json.loads(f.read()))

    if self._should_update_replay():
      self.replay_session = ReplayFile(
          replay_id=self._replay_id, interactions=[]
      )

  def _generate_file_path_from_replay_id(self, replay_directory, replay_id):
    """Maps ``a/b/c`` to ``<replay_directory>/a/b/c.json``.

    Raises:
      ValueError: If ``replay_id`` has fewer than three '/'-separated parts.
    """
    session_parts = replay_id.split('/')
    if len(session_parts) < 3:
      raise ValueError(
          f'{replay_id}: Session ID must be in the format of'
          ' module/function/[vertex|mldev]'
      )
    if replay_directory is None:
      path_parts = []
    else:
      path_parts = [replay_directory]
    path_parts.extend(session_parts)
    return os.path.join(*path_parts) + '.json'

  def close(self):
    """When recording, writes the session to its replay file, then clears it."""
    if not self._should_update_replay() or not self.replay_session:
      return
    replay_file_path = self._get_replay_file_path()
    os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
    with open(replay_file_path, 'w', encoding='utf-8') as f:
      f.write(
          json.dumps(
              self.replay_session.model_dump(), indent=2, cls=RequestJsonEncoder
          )
      )
    self.replay_session = None

  def _record_interaction(
      self,
      http_request: HttpRequest,
      http_response: Union[HttpResponse, errors.APIError],
  ):
    """Appends the (redacted) request and its response to the session.

    Does nothing unless the client is recording. Note that ``http_request``
    is redacted in place.
    """
    if not self._should_update_replay():
      return
    redact_http_request(http_request)
    request = ReplayRequest(
        method=http_request.method,
        url=http_request.url,
        headers=http_request.headers,
        body_segments=[http_request.data],
    )
    if isinstance(http_response, HttpResponse):
      response = ReplayResponse(
          headers=dict(http_response.headers),
          body_segments=list(http_response.segments()),
          status_code=http_response.status_code,
          sdk_response_segments=[],
      )
    else:
      response = ReplayResponse(
          headers=dict(http_response.response.headers),
          body_segments=[http_response._to_replay_record()],
          status_code=http_response.code,
          sdk_response_segments=[],
      )
    self.replay_session.interactions.append(
        ReplayInteraction(request=request, response=response)
    )

  def _match_request(
      self,
      http_request: HttpRequest,
      interaction: ReplayInteraction,
  ):
    """Asserts that ``http_request`` matches the recorded request."""
    assert http_request.url == interaction.request.url
    assert http_request.headers == interaction.request.headers, (
        'Request headers mismatch:\n'
        f'Actual: {http_request.headers}\n'
        f'Expected: {interaction.request.headers}'
    )
    assert http_request.method == interaction.request.method

    # Sanitize the request body, rewrite any fields that vary.
    request_data_copy = copy.deepcopy(http_request.data)
    # Both the request and recorded request must be redacted before comparing
    # so that the comparison is fair.
    _redact_request_body(request_data_copy)

    # Need to call dumps() and loads() to convert dict bytes values to strings.
    # Because the expected_request_body dict never contains bytes values.
    actual_request_body = [
        json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
    ]
    expected_request_body = interaction.request.body_segments
    assert actual_request_body == expected_request_body, (
        'Request body mismatch:\n'
        f'Actual: {actual_request_body}\n'
        f'Expected: {expected_request_body}'
    )

  def _build_response_from_replay(self, http_request: HttpRequest):
    """Returns the next recorded response after checking the request matches.

    Raises:
      IndexError: If every recorded interaction has already been replayed.
      AssertionError: If the request differs from the recorded one.
      errors.APIError: Whatever ``errors.APIError.raise_for_response`` raises
        for the recorded response (presumably error status codes).
    """
    redact_http_request(http_request)

    interactions = self.replay_session.interactions
    if self._replay_index >= len(interactions):
      raise IndexError(
          f'Replay {self._replay_id} has {len(interactions)} recorded'
          f' interaction(s), but request #{self._replay_index + 1} was made.'
      )
    interaction = interactions[self._replay_index]
    # Replay is on the right side of the assert so the diff makes more sense.
    self._match_request(http_request, interaction)
    self._replay_index += 1
    self._sdk_response_index = 0
    errors.APIError.raise_for_response(interaction.response)
    return HttpResponse(
        headers=interaction.response.headers,
        response_stream=[
            json.dumps(segment)
            for segment in interaction.response.body_segments
        ],
    )

  def _verify_response(self, response_model: BaseModel):
    """Records (when recording) or checks (when replaying) an SDK response.

    Only the first element of a list response is used.
    """
    if self._mode == 'api':
      return
    # When replaying, _build_response_from_replay has already advanced
    # _replay_index, so -1 selects the interaction just replayed. When
    # recording, _replay_index stays 0, so -1 selects the last recorded one.
    interaction = self.replay_session.interactions[self._replay_index - 1]
    if self._should_update_replay():
      if isinstance(response_model, list):
        response_model = response_model[0]
      interaction.response.sdk_response_segments.append(
          response_model.model_dump(exclude_none=True)
      )
      return

    if isinstance(response_model, list):
      response_model = response_model[0]
    _debug_print(
        f'response_model: {response_model.model_dump(exclude_none=True)}'
    )
    actual = json.dumps(
        response_model.model_dump(exclude_none=True),
        cls=ResponseJsonEncoder,
        sort_keys=True,
    )
    expected = json.dumps(
        interaction.response.sdk_response_segments[self._sdk_response_index],
        sort_keys=True,
    )
    assert (
        actual == expected
    ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
    self._sdk_response_index += 1

  def _request(
      self,
      http_request: HttpRequest,
      stream: bool = False,
  ) -> HttpResponse:
    """Sends the request to the API, or serves it from the replay session.

    When the API is called and the client is recording, the interaction is
    recorded, including API errors, which are re-raised.
    """
    self._initialize_replay_session_if_not_loaded()
    if self._should_call_api():
      _debug_print('api mode request: %s' % http_request)
      try:
        result = super()._request(http_request, stream)
      except errors.APIError as e:
        self._record_interaction(http_request, e)
        raise
      if stream:
        # The stream is consumed while recording, so return a new response
        # that rebuilds the same segments for the caller.
        result_segments = []
        for segment in result.segments():
          result_segments.append(json.dumps(segment))
        result = HttpResponse(result.headers, result_segments)
        self._record_interaction(http_request, result)
      else:
        self._record_interaction(http_request, result)
        _debug_print('api mode result: %s' % result.text)
      return result
    else:
      return self._build_response_from_replay(http_request)

  def upload_file(self, file_path: str, upload_url: str, upload_size: int):
    """Uploads a file via the API, or replays the recorded upload result.

    The recorded request contains only ``file_path`` (with an empty URL).
    When recording, failed uploads are recorded before the HTTPError is
    re-raised.
    """
    # Without this, an upload that is the first call would find
    # replay_session still None.
    self._initialize_replay_session_if_not_loaded()
    request = HttpRequest(
        method='POST', url='', data={'file_path': file_path}, headers={}
    )
    if self._should_call_api():
      try:
        result = super().upload_file(file_path, upload_url, upload_size)
      except HTTPError as e:
        error_response = HttpResponse(
            e.response.headers, [json.dumps({'reason': e.response.reason})]
        )
        error_response.status_code = e.response.status_code
        self._record_interaction(request, error_response)
        raise
      self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
      return result
    else:
      return self._build_response_from_replay(request).text
445
+
446
+
447
class ResponseJsonEncoder(json.JSONEncoder):
  """JSON encoder for SDK response objects in replay tests.

  Two encoders exist because requests and responses carry bytes differently:

  1. ``RequestJsonEncoder`` (used in production) serializes request payloads.
     There, bytes fields are encoded as base64 strings. They are therefore
     always valid text, and that encoder never has to deal with broken UTF-8.
  2. ``ResponseJsonEncoder`` (used in replay tests) serializes the objects
     returned by SDK methods. The SDK has already turned base64 payloads back
     into raw bytes, which may not be valid UTF-8. This encoder therefore
     decodes bytes directly, substituting replacement characters.
  """

  def default(self, o):
    if isinstance(o, datetime.datetime):
      iso = o.isoformat()
      # isoformat() renders UTC as "2024-11-15T23:27:45.624657+00:00",
      # but replay files want "2024-11-15T23:27:45.624657Z".
      if not iso.endswith('+00:00'):
        return iso
      return iso.replace('+00:00', 'Z')
    if isinstance(o, bytes):
      # Serialize the raw byte string, not base64. Otherwise the replay file
      # could not show whether the response was already decoded from base64.
      # Invalid UTF-8 becomes U+FFFD.
      return o.decode(encoding='utf-8', errors='replace')
    return super().default(o)
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ import asyncio
17
+ import time
18
+ from unittest.mock import MagicMock, patch
19
+ import pytest
20
+ from .api_client import ApiClient
21
+
22
+
23
@patch('genai.api_client.ApiClient._build_request')
@patch('genai.api_client.ApiClient._request')
def test_request_streamed_non_blocking(mock_request, mock_build_request):
  """request_streamed() must yield chunks as they arrive, not buffer them.

  The mocked response produces three chunks 100ms apart. If request_streamed
  buffered the whole response, the first chunk would only arrive after all
  three delays (~0.3s).
  """
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  def delayed_segments():
    chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
    for chunk in chunks:
      time.sleep(0.1)  # 100ms delay
      yield chunk

  mock_response = MagicMock()
  mock_response.segments.side_effect = delayed_segments
  mock_request.return_value = mock_response

  chunks = []
  arrival_times = []
  # perf_counter is monotonic, unlike time.time(), so wall-clock adjustments
  # cannot skew the measurement.
  start_time = time.perf_counter()
  for chunk in api_client.request_streamed(http_method, path, request_dict):
    arrival_times.append(time.perf_counter() - start_time)
    chunks.append(chunk)
  elapsed = time.perf_counter() - start_time

  mock_build_request.assert_called_once_with(
      http_method, path, request_dict, None
  )
  mock_request.assert_called_once_with(mock_http_request, stream=True)
  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
  # Non-blocking: the first chunk must not wait for the whole stream.
  assert arrival_times[0] < 0.25
  # Three sequential 100ms delays; small tolerance for timer granularity.
  assert elapsed >= 0.29
57
+
58
+
59
@patch('genai.api_client.ApiClient._build_request')
@patch('genai.api_client.ApiClient._async_request')
@pytest.mark.asyncio
async def test_async_request(mock_async_request, mock_build_request):
  """Concurrent async_request() calls must overlap rather than run serially.

  Each mocked request takes 100ms. Run sequentially, three would take at
  least ~0.3s; run concurrently, they take about 0.1s.
  """
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  class MockResponse:

    def __init__(self, text):
      self.text = text

  async def delayed_response(http_request, stream):
    await asyncio.sleep(0.1)  # 100ms delay
    return MockResponse('value')

  mock_async_request.side_effect = delayed_response

  async_coroutine1 = api_client.async_request(http_method, path, request_dict)
  async_coroutine2 = api_client.async_request(http_method, path, request_dict)
  async_coroutine3 = api_client.async_request(http_method, path, request_dict)

  # perf_counter is monotonic, unlike time.time().
  start_time = time.perf_counter()
  results = await asyncio.gather(
      async_coroutine1, async_coroutine2, async_coroutine3
  )
  elapsed = time.perf_counter() - start_time

  mock_build_request.assert_called_with(http_method, path, request_dict, None)
  assert mock_build_request.call_count == 3
  mock_async_request.assert_called_with(
      http_request=mock_http_request, stream=False
  )
  assert mock_async_request.call_count == 3
  assert results == ['value', 'value', 'value']
  # Lower bound: asyncio may fire timers up to one clock resolution early.
  # Upper bound: well below the ~0.3s a serial execution would need, with
  # headroom for slow CI machines.
  assert 0.09 <= elapsed < 0.25
100
+
101
+
102
@patch('genai.api_client.ApiClient._build_request')
@patch('genai.api_client.ApiClient._async_request')
@pytest.mark.asyncio
async def test_async_request_streamed_non_blocking(
    mock_async_request, mock_build_request
):
  """async_request_streamed() must yield chunks as they arrive.

  The mocked response produces three chunks 100ms apart. If the stream were
  buffered, the first chunk would only arrive after ~0.3s.
  """
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  class MockResponse:

    def __init__(self, segments):
      self._segments = segments

    # should mock async generator here but source code combines sync and
    # async streaming in one segment method.
    # TODO: fix the above
    def segments(self):
      for segment in self._segments:
        time.sleep(0.1)  # 100ms delay
        yield segment

  async def delayed_response(http_request, stream):
    return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'])

  mock_async_request.side_effect = delayed_response

  chunks = []
  arrival_times = []
  # perf_counter is monotonic, unlike time.time().
  start_time = time.perf_counter()
  async for chunk in api_client.async_request_streamed(
      http_method, path, request_dict
  ):
    arrival_times.append(time.perf_counter() - start_time)
    chunks.append(chunk)
  elapsed = time.perf_counter() - start_time

  mock_build_request.assert_called_once_with(
      http_method, path, request_dict, None
  )
  mock_async_request.assert_called_once_with(
      http_request=mock_http_request, stream=True
  )
  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
  # Non-blocking: the first chunk must not wait for the whole stream.
  assert arrival_times[0] < 0.25
  # Three sequential 100ms delays; small tolerance for timer granularity.
  assert elapsed >= 0.29