google-genai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +20 -0
- google/genai/_api_client.py +467 -0
- google/genai/_automatic_function_calling_util.py +341 -0
- google/genai/_common.py +256 -0
- google/genai/_extra_utils.py +295 -0
- google/genai/_replay_api_client.py +478 -0
- google/genai/_test_api_client.py +149 -0
- google/genai/_transformers.py +438 -0
- google/genai/batches.py +1041 -0
- google/genai/caches.py +1830 -0
- google/genai/chats.py +184 -0
- google/genai/client.py +277 -0
- google/genai/errors.py +110 -0
- google/genai/files.py +1211 -0
- google/genai/live.py +629 -0
- google/genai/models.py +5307 -0
- google/genai/pagers.py +245 -0
- google/genai/tunings.py +1366 -0
- google/genai/types.py +7639 -0
- google_genai-0.0.1.dist-info/LICENSE +202 -0
- google_genai-0.0.1.dist-info/METADATA +763 -0
- google_genai-0.0.1.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/WHEEL +5 -0
- google_genai-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,478 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Replay API client."""
|
17
|
+
|
18
|
+
import copy
|
19
|
+
import inspect
|
20
|
+
import json
|
21
|
+
import os
|
22
|
+
import re
|
23
|
+
import datetime
|
24
|
+
from typing import Any, Literal, Optional, Union
|
25
|
+
|
26
|
+
import google.auth
|
27
|
+
from pydantic import BaseModel
|
28
|
+
from requests.exceptions import HTTPError
|
29
|
+
|
30
|
+
from . import errors
|
31
|
+
from ._api_client import ApiClient
|
32
|
+
from ._api_client import HttpOptions
|
33
|
+
from ._api_client import HttpRequest
|
34
|
+
from ._api_client import HttpResponse
|
35
|
+
from ._api_client import RequestJsonEncoder
|
36
|
+
|
37
|
+
def _redact_version_numbers(version_string: str) -> str:
|
38
|
+
"""Redacts version numbers in the form x.y.z from a string."""
|
39
|
+
return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
|
40
|
+
|
41
|
+
|
42
|
+
def _redact_language_label(language_label: str) -> str:
|
43
|
+
"""Removed because replay requests are used for all languages."""
|
44
|
+
return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)
|
45
|
+
|
46
|
+
|
47
|
+
def _redact_request_headers(headers):
|
48
|
+
"""Redacts headers that should not be recorded."""
|
49
|
+
redacted_headers = {}
|
50
|
+
for header_name, header_value in headers.items():
|
51
|
+
if header_name.lower() == 'x-goog-api-key':
|
52
|
+
redacted_headers[header_name] = '{REDACTED}'
|
53
|
+
elif header_name.lower() == 'user-agent':
|
54
|
+
redacted_headers[header_name] = _redact_language_label(
|
55
|
+
_redact_version_numbers(header_value)
|
56
|
+
)
|
57
|
+
elif header_name.lower() == 'x-goog-api-client':
|
58
|
+
redacted_headers[header_name] = _redact_language_label(
|
59
|
+
_redact_version_numbers(header_value)
|
60
|
+
)
|
61
|
+
else:
|
62
|
+
redacted_headers[header_name] = header_value
|
63
|
+
return redacted_headers
|
64
|
+
|
65
|
+
|
66
|
+
def _redact_request_url(url: str) -> str:
|
67
|
+
# Redact all the url parts before the resource name, so the test can work
|
68
|
+
# against any project, location, version, or whether it's EasyGCP.
|
69
|
+
result = re.sub(
|
70
|
+
r'.*/projects/[^/]+/locations/[^/]+/',
|
71
|
+
'{VERTEX_URL_PREFIX}/',
|
72
|
+
url,
|
73
|
+
)
|
74
|
+
result = re.sub(
|
75
|
+
r'https://generativelanguage.googleapis.com/[^/]+',
|
76
|
+
'{MLDEV_URL_PREFIX}',
|
77
|
+
result,
|
78
|
+
)
|
79
|
+
return result
|
80
|
+
|
81
|
+
|
82
|
+
def _redact_project_location_path(path: str) -> str:
|
83
|
+
# Redact a field in the request that is known to vary based on project and
|
84
|
+
# location.
|
85
|
+
if 'projects/' in path and 'locations/' in path:
|
86
|
+
result = re.sub(
|
87
|
+
r'projects/[^/]+/locations/[^/]+/',
|
88
|
+
'{PROJECT_AND_LOCATION_PATH}/',
|
89
|
+
path,
|
90
|
+
)
|
91
|
+
return result
|
92
|
+
else:
|
93
|
+
return path
|
94
|
+
|
95
|
+
|
96
|
+
def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
|
97
|
+
for key, value in body.items():
|
98
|
+
if isinstance(value, str):
|
99
|
+
body[key] = _redact_project_location_path(value)
|
100
|
+
|
101
|
+
|
102
|
+
def redact_http_request(http_request: HttpRequest):
  """Scrubs `http_request` in place so it can be recorded or compared.

  Headers and URL are replaced with redacted copies; the body dict is
  redacted in place.
  """
  scrubbed_headers = _redact_request_headers(http_request.headers)
  scrubbed_url = _redact_request_url(http_request.url)
  http_request.headers = scrubbed_headers
  http_request.url = scrubbed_url
  _redact_request_body(http_request.data)
|
106
|
+
|
107
|
+
|
108
|
+
def process_bytes_fields(data: dict[str, object]):
  """Converts bytes values in `data` to str, recursing into dicts and lists.

  The conversion happens IN PLACE: `data` and any nested dicts are mutated,
  and `data` itself is returned for convenience. Bytes are decoded with the
  default strict UTF-8 codec, so non-UTF-8 payloads raise UnicodeDecodeError.

  Args:
    data: The dict to convert. Non-dict values are returned unchanged.

  Returns:
    `data`, after conversion.
  """
  if not isinstance(data, dict):
    return data
  for key, value in data.items():
    # Only existing keys are reassigned, so mutating while iterating is safe.
    data[key] = _convert_bytes_value(value)
  return data


def _convert_bytes_value(value):
  """Returns `value` with bytes decoded to str; dicts are converted in place."""
  if isinstance(value, bytes):
    return value.decode()
  if isinstance(value, dict):
    return process_bytes_fields(value)
  if isinstance(value, list):
    # Convert each element on its own so mixed lists (bytes next to dicts or
    # other values) and nested lists are handled too.
    return [_convert_bytes_value(item) for item in value]
  return value
|
128
|
+
|
129
|
+
|
130
|
+
def _current_file_path_and_line():
|
131
|
+
"""Prints the current file path and line number."""
|
132
|
+
frame = inspect.currentframe().f_back.f_back
|
133
|
+
filepath = inspect.getfile(frame)
|
134
|
+
lineno = frame.f_lineno
|
135
|
+
return f'File: {filepath}, Line: {lineno}'
|
136
|
+
|
137
|
+
|
138
|
+
def _debug_print(message: str):
  """Prints `message` tagged with the running pytest test and call site.

  The call-site lookup must stay directly inside this function, because
  `_current_file_path_and_line` walks a fixed number of frames up the stack
  to find the code that called `_debug_print`.
  """
  current_test = os.environ.get('PYTEST_CURRENT_TEST')
  call_site = _current_file_path_and_line()
  print('DEBUG (test', current_test, ')', call_site, ':\n ', message)
|
147
|
+
|
148
|
+
|
149
|
+
class ReplayRequest(BaseModel):
  """Represents a single request in a replay.

  Requests are stored after `redact_http_request` has run, so API keys,
  version numbers and environment-specific URL prefixes are replaced with
  placeholders.
  """

  # HTTP method, e.g. 'POST'.
  method: str
  # Redacted request URL.
  url: str
  # Redacted request headers.
  headers: dict[str, str]
  # The request body; `ReplayApiClient._record_interaction` stores it as a
  # one-element list.
  body_segments: list[dict[str, object]]
|
156
|
+
|
157
|
+
|
158
|
+
class ReplayResponse(BaseModel):
  """Represents a single response in a replay.

  Headers whose values change on every call ('Date', 'Server-Timing') are
  dropped on construction, so re-recording the same interaction yields an
  identical replay file.
  """

  # HTTP status code; for a recorded APIError this is the error's code.
  status_code: int = 200
  # Response headers, minus the non-deterministic ones removed below.
  headers: dict[str, str]
  # Parsed body segments taken from the response's `segments()`.
  body_segments: list[dict[str, object]]
  # SDK response objects (model_dump output) appended by
  # `ReplayApiClient._verify_response` while recording.
  sdk_response_segments: list[dict[str, object]]

  def model_post_init(self, __context: Any) -> None:
    # Remove headers that are not deterministic so the replay files don't
    # change every time they are recorded. HTTP header names are
    # case-insensitive, so match them regardless of how the server cased them.
    volatile = {'date', 'server-timing'}
    for name in [n for n in self.headers if n.lower() in volatile]:
      del self.headers[name]
|
171
|
+
|
172
|
+
|
173
|
+
class ReplayInteraction(BaseModel):
  """Represents a single interaction, request and response in a replay."""

  # The recorded (redacted) request; live requests are compared against it
  # during replay.
  request: ReplayRequest
  # The recorded response that is served back during replay.
  response: ReplayResponse
|
178
|
+
|
179
|
+
|
180
|
+
class ReplayFile(BaseModel):
  """Represents a recorded session, serialized as one JSON replay file."""

  # Identifier in the form module/function/[vertex|mldev]; it also determines
  # the replay file path.
  replay_id: str
  # Interactions in the order the requests were made; replay consumes them
  # sequentially.
  interactions: list[ReplayInteraction]
|
185
|
+
|
186
|
+
|
187
|
+
class ReplayApiClient(ApiClient):
  """For integration testing, sends recorded responses or records responses.

  The `mode` passed to the constructor controls the behavior:

    * 'record': always call the real API and record every interaction.
    * 'replay': never call the API; serve responses from the replay file.
      A ValueError is raised when the session starts if that file is missing.
    * 'auto': replay when the replay file exists, otherwise call the API and
      record, like 'record'.
    * 'api': always call the real API and never record.

  Recorded sessions are written to disk by `close()`.
  """

  def __init__(
      self,
      mode: Literal['record', 'replay', 'auto', 'api'],
      replay_id: str,
      replays_directory: Optional[str] = None,
      vertexai: bool = False,
      api_key: Optional[str] = None,
      credentials: Optional[google.auth.credentials.Credentials] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
      http_options: Optional[HttpOptions] = None,
  ):
    super().__init__(
        vertexai=vertexai,
        api_key=api_key,
        credentials=credentials,
        project=project,
        location=location,
        http_options=http_options,
    )
    # Fall back to the environment when no directory is passed explicitly.
    self.replays_directory = replays_directory
    if not self.replays_directory:
      self.replays_directory = os.environ.get(
          'GOOGLE_GENAI_REPLAYS_DIRECTORY', None
      )
    # The replay session is loaded lazily on the first request, not here.
    self.replay_session = None
    self._mode = mode
    self._replay_id = replay_id

  def initialize_replay_session(self, replay_id: str):
    """Switches to `replay_id` and (re)initializes its replay session."""
    self._replay_id = replay_id
    self._initialize_replay_session()

  def _get_replay_file_path(self):
    return self._generate_file_path_from_replay_id(
        self.replays_directory, self._replay_id
    )

  def _should_call_api(self):
    """Whether requests go to the real API instead of the replay file."""
    return self._mode in ['record', 'api'] or (
        self._mode == 'auto'
        and not os.path.isfile(self._get_replay_file_path())
    )

  def _should_update_replay(self):
    """Whether API interactions should be recorded into the session."""
    return self._should_call_api() and self._mode != 'api'

  def _initialize_replay_session_if_not_loaded(self):
    if not self.replay_session:
      self._initialize_replay_session()

  def _initialize_replay_session(self):
    """Resets counters and loads or creates the replay session.

    Raises:
      ValueError: In 'replay' mode when the replay file does not exist.
    """
    _debug_print('Test is using replay id: ' + self._replay_id)
    self._replay_index = 0
    self._sdk_response_index = 0
    replay_file_path = self._get_replay_file_path()
    # This should not be triggered from the constructor.
    replay_file_exists = os.path.isfile(replay_file_path)
    if self._mode == 'replay' and not replay_file_exists:
      raise ValueError(
          'Replay files do not exist for replay id: ' + self._replay_id
      )

    if self._mode in ['replay', 'auto'] and replay_file_exists:
      with open(replay_file_path, 'r', encoding='utf-8') as f:
        self.replay_session = ReplayFile.model_validate(json.loads(f.read()))

    if self._should_update_replay():
      self.replay_session = ReplayFile(
          replay_id=self._replay_id, interactions=[]
      )

  def _generate_file_path_from_replay_id(self, replay_directory, replay_id):
    """Maps 'module/function/variant' to '<dir>/module/function/variant.json'.

    Raises:
      ValueError: If `replay_id` has fewer than three '/'-separated parts.
    """
    session_parts = replay_id.split('/')
    if len(session_parts) < 3:
      raise ValueError(
          f'{replay_id}: Session ID must be in the format of'
          ' module/function/[vertex|mldev]'
      )
    if replay_directory is None:
      path_parts = []
    else:
      path_parts = [replay_directory]
    path_parts.extend(session_parts)
    return os.path.join(*path_parts) + '.json'

  def close(self):
    """Writes the recorded session to its replay file, then clears it.

    Does nothing when not recording or when no session is loaded.
    """
    if not self._should_update_replay() or not self.replay_session:
      return
    replay_file_path = self._get_replay_file_path()
    os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
    with open(replay_file_path, 'w', encoding='utf-8') as f:
      f.write(
          json.dumps(
              self.replay_session.model_dump(), indent=2, cls=RequestJsonEncoder
          )
      )
    self.replay_session = None

  def _record_interaction(
      self,
      http_request: HttpRequest,
      http_response: Union[HttpResponse, errors.APIError],
  ):
    """Appends one request/response pair to the session when recording.

    Note that `http_request` is redacted in place before it is stored.

    Args:
      http_request: The request that was sent.
      http_response: The response received, or the APIError that was raised
        instead of a response.
    """
    if not self._should_update_replay():
      return
    redact_http_request(http_request)
    request = ReplayRequest(
        method=http_request.method,
        url=http_request.url,
        headers=http_request.headers,
        body_segments=[http_request.data],
    )
    if isinstance(http_response, HttpResponse):
      response = ReplayResponse(
          headers=dict(http_response.headers),
          body_segments=list(http_response.segments()),
          status_code=http_response.status_code,
          sdk_response_segments=[],
      )
    else:
      response = ReplayResponse(
          headers=dict(http_response.response.headers),
          body_segments=[http_response._to_replay_record()],
          status_code=http_response.code,
          sdk_response_segments=[],
      )
    self.replay_session.interactions.append(
        ReplayInteraction(request=request, response=response)
    )

  def _match_request(
      self,
      http_request: HttpRequest,
      interaction: ReplayInteraction,
  ):
    """Asserts that a redacted live request equals the recorded request."""
    assert http_request.url == interaction.request.url, (
        'Request URL mismatch:\n'
        f'Actual: {http_request.url}\n'
        f'Expected: {interaction.request.url}'
    )
    assert http_request.headers == interaction.request.headers, (
        'Request headers mismatch:\n'
        f'Actual: {http_request.headers}\n'
        f'Expected: {interaction.request.headers}'
    )
    assert http_request.method == interaction.request.method, (
        'Request method mismatch:\n'
        f'Actual: {http_request.method}\n'
        f'Expected: {interaction.request.method}'
    )

    # Sanitize the request body, rewrite any fields that vary.
    request_data_copy = copy.deepcopy(http_request.data)
    # Both the request and recorded request must be redacted before comparing
    # so that the comparison is fair.
    _redact_request_body(request_data_copy)

    # Need to call dumps() and loads() to convert dict bytes values to strings.
    # Because the expected_request_body dict never contains bytes values.
    actual_request_body = [
        json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
    ]
    expected_request_body = interaction.request.body_segments
    assert actual_request_body == expected_request_body, (
        'Request body mismatch:\n'
        f'Actual: {actual_request_body}\n'
        f'Expected: {expected_request_body}'
    )

  def _build_response_from_replay(self, http_request: HttpRequest):
    """Serves the next recorded interaction after checking it matches.

    Raises:
      IndexError: If more requests are made than were recorded.
      AssertionError: If the request does not match the recorded one.
    """
    redact_http_request(http_request)

    interactions = self.replay_session.interactions
    if self._replay_index >= len(interactions):
      raise IndexError(
          f'Replay {self._replay_id} has only {len(interactions)} recorded'
          f' interaction(s), but request #{self._replay_index + 1} was made:'
          f' {http_request.method} {http_request.url}'
      )
    interaction = interactions[self._replay_index]
    # Replay is on the right side of the assert so the diff makes more sense.
    self._match_request(http_request, interaction)
    self._replay_index += 1
    self._sdk_response_index = 0
    errors.APIError.raise_for_response(interaction.response)
    return HttpResponse(
        headers=interaction.response.headers,
        response_stream=[
            json.dumps(segment)
            for segment in interaction.response.body_segments
        ],
    )

  def _verify_response(self, response_model: BaseModel):
    """Records (when recording) or checks the SDK-level response object.

    List responses are handled through their first element only.
    """
    if self._mode == 'api':
      return
    # replay_index is advanced in _build_response_from_replay, so we need to -1.
    interaction = self.replay_session.interactions[self._replay_index - 1]
    if isinstance(response_model, list):
      response_model = response_model[0]
    response_dict = response_model.model_dump(exclude_none=True)
    if self._should_update_replay():
      interaction.response.sdk_response_segments.append(response_dict)
      return

    _debug_print('response_model: %s' % response_dict)
    actual = json.dumps(
        response_dict,
        cls=ResponseJsonEncoder,
        sort_keys=True,
    )
    expected = json.dumps(
        interaction.response.sdk_response_segments[self._sdk_response_index],
        sort_keys=True,
    )
    assert (
        actual == expected
    ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
    self._sdk_response_index += 1

  def _request(
      self,
      http_request: HttpRequest,
      stream: bool = False,
  ) -> HttpResponse:
    self._initialize_replay_session_if_not_loaded()
    if not self._should_call_api():
      return self._build_response_from_replay(http_request)

    _debug_print('api mode request: %s' % http_request)
    try:
      result = super()._request(http_request, stream)
    except errors.APIError as e:
      self._record_interaction(http_request, e)
      raise
    if stream:
      # The stream can be consumed only once, so drain it and rebuild an
      # HttpResponse whose segments can be both recorded and returned.
      result_segments = [json.dumps(segment) for segment in result.segments()]
      result = HttpResponse(result.headers, result_segments)
      self._record_interaction(http_request, result)
    else:
      self._record_interaction(http_request, result)
      _debug_print('api mode result: %s' % result.text)
    return result

  def upload_file(self, file_path: str, upload_url: str, upload_size: int):
    """Uploads `file_path`, or replays a recorded upload.

    Failed uploads are recorded too, so a later replay sees the same sequence
    of interactions. The original HTTPError is then re-raised.
    """
    # upload_file may be the first call on the client, so the session must be
    # ready before recording or replaying.
    self._initialize_replay_session_if_not_loaded()
    request = HttpRequest(
        method='POST', url='', data={'file_path': file_path}, headers={}
    )
    if not self._should_call_api():
      return self._build_response_from_replay(request).text

    try:
      result = super().upload_file(file_path, upload_url, upload_size)
    except HTTPError as e:
      error_response = HttpResponse(
          e.response.headers, [json.dumps({'reason': e.response.reason})]
      )
      error_response.status_code = e.response.status_code
      self._record_interaction(request, error_response)
      raise
    self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
    return result
|
445
|
+
|
446
|
+
|
447
|
+
class ResponseJsonEncoder(json.JSONEncoder):
  """JSON encoder for SDK response objects compared in replay tests.

  Two encoders exist because requests and responses carry bytes differently:

  1. Production code only needs RequestJsonEncoder. Request bytes fields are
     already base64-encoded strings, so that encoder never has to cope with
     bytes that are not valid UTF-8.
  2. Replay tests also serialize the objects returned by SDK methods. Those
     responses hold raw bytes, because the SDK has already decoded the
     base64 from the wire. This encoder therefore writes bytes as (lossy)
     UTF-8 text instead.

  datetime values are written in ISO 8601, using a 'Z' suffix for UTC to
  match the format stored in replay files.
  """

  def default(self, o):
    if isinstance(o, datetime.datetime):
      iso_text = o.isoformat()
      # isoformat() renders UTC as '+00:00'; replay files expect 'Z'.
      if not iso_text.endswith('+00:00'):
        return iso_text
      return iso_text.replace('+00:00', 'Z')
    if isinstance(o, bytes):
      # Emit the raw bytes as text (invalid sequences become U+FFFD) rather
      # than base64, so a replay file shows that the response was already
      # decoded from base64.
      return o.decode(encoding='utf-8', errors='replace')
    return super().default(o)
|
@@ -0,0 +1,149 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
import time
|
18
|
+
from unittest.mock import MagicMock, patch
|
19
|
+
import pytest
|
20
|
+
from .api_client import ApiClient
|
21
|
+
|
22
|
+
|
23
|
+
# patch.object targets the ApiClient class imported above directly, instead
# of a hard-coded module path that may not match where ApiClient is defined.
@patch.object(ApiClient, '_build_request')
@patch.object(ApiClient, '_request')
def test_request_streamed_non_blocking(mock_request, mock_build_request):
  """request_streamed yields every segment produced by the response."""
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  def delayed_segments():
    chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
    for chunk in chunks:
      time.sleep(0.1)  # 100ms delay
      yield chunk

  mock_response = MagicMock()
  mock_response.segments.side_effect = delayed_segments
  mock_request.return_value = mock_response

  chunks = []
  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  for chunk in api_client.request_streamed(http_method, path, request_dict):
    chunks.append(chunk)
    assert len(chunks) <= 3
  end_time = time.perf_counter()

  mock_build_request.assert_called_once_with(
      http_method, path, request_dict, None
  )
  mock_request.assert_called_once_with(mock_http_request, stream=True)
  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
  # Three 100ms sleeps; time.sleep sleeps at least the requested duration.
  assert end_time - start_time >= 0.3
|
57
|
+
|
58
|
+
|
59
|
+
# patch.object targets the ApiClient class imported above directly, instead
# of a hard-coded module path that may not match where ApiClient is defined.
@patch.object(ApiClient, '_build_request')
@patch.object(ApiClient, '_async_request')
@pytest.mark.asyncio
async def test_async_request(mock_async_request, mock_build_request):
  """Concurrent async_request calls overlap instead of running serially."""
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  class MockResponse:

    def __init__(self, text):
      self.text = text

  async def delayed_response(http_request, stream):
    await asyncio.sleep(0.1)  # 100ms delay
    return MockResponse('value')

  mock_async_request.side_effect = delayed_response

  async_coroutine1 = api_client.async_request(http_method, path, request_dict)
  async_coroutine2 = api_client.async_request(http_method, path, request_dict)
  async_coroutine3 = api_client.async_request(http_method, path, request_dict)

  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  results = await asyncio.gather(
      async_coroutine1, async_coroutine2, async_coroutine3
  )
  end_time = time.perf_counter()

  mock_build_request.assert_called_with(http_method, path, request_dict, None)
  assert mock_build_request.call_count == 3
  mock_async_request.assert_called_with(
      http_request=mock_http_request, stream=False
  )
  assert mock_async_request.call_count == 3
  assert results == ['value', 'value', 'value']
  # Each call sleeps 100ms. Run serially they would take >= 0.3s, so anything
  # below that shows the calls overlapped. A tight bound such as 0.15s is
  # flaky on loaded machines.
  assert 0.1 <= end_time - start_time < 0.3
|
100
|
+
|
101
|
+
|
102
|
+
# patch.object targets the ApiClient class imported above directly, instead
# of a hard-coded module path that may not match where ApiClient is defined.
@patch.object(ApiClient, '_build_request')
@patch.object(ApiClient, '_async_request')
@pytest.mark.asyncio
async def test_async_request_streamed_non_blocking(
    mock_async_request, mock_build_request
):
  """async_request_streamed yields every segment produced by the response."""
  api_client = ApiClient(api_key='test_api_key')
  http_method = 'GET'
  path = 'test/path'
  request_dict = {'key': 'value'}

  mock_http_request = MagicMock()
  mock_build_request.return_value = mock_http_request

  class MockResponse:

    def __init__(self, segments):
      self._segments = segments

    # This should mock an async generator, but the source code combines sync
    # and async streaming in one segments() method.
    # TODO: fix the above
    def segments(self):
      for segment in self._segments:
        time.sleep(0.1)  # 100ms delay
        yield segment

  async def delayed_response(http_request, stream):
    return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'])

  mock_async_request.side_effect = delayed_response

  chunks = []
  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  async for chunk in api_client.async_request_streamed(
      http_method, path, request_dict
  ):
    chunks.append(chunk)
    assert len(chunks) <= 3
  end_time = time.perf_counter()

  mock_build_request.assert_called_once_with(
      http_method, path, request_dict, None
  )
  mock_async_request.assert_called_once_with(
      http_request=mock_http_request, stream=True
  )
  assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
  # Three 100ms sleeps; time.sleep sleeps at least the requested duration.
  assert end_time - start_time >= 0.3
|