google-genai 1.32.0__py3-none-any.whl → 1.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +32 -9
- google/genai/_common.py +5 -1
- google/genai/_extra_utils.py +5 -8
- google/genai/_live_converters.py +100 -44
- google/genai/_replay_api_client.py +15 -0
- google/genai/_tokens_converters.py +24 -3
- google/genai/_transformers.py +55 -15
- google/genai/batches.py +680 -142
- google/genai/caches.py +50 -6
- google/genai/files.py +2 -0
- google/genai/local_tokenizer.py +29 -1
- google/genai/models.py +127 -191
- google/genai/types.py +431 -278
- google/genai/version.py +1 -1
- {google_genai-1.32.0.dist-info → google_genai-1.34.0.dist-info}/METADATA +18 -1
- {google_genai-1.32.0.dist-info → google_genai-1.34.0.dist-info}/RECORD +19 -19
- {google_genai-1.32.0.dist-info → google_genai-1.34.0.dist-info}/WHEEL +0 -0
- {google_genai-1.32.0.dist-info → google_genai-1.34.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.32.0.dist-info → google_genai-1.34.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -584,13 +584,9 @@ class BaseApiClient:
|
|
584
584
|
# Initialize the lock. This lock will be used to protect access to the
|
585
585
|
# credentials. This is crucial for thread safety when multiple coroutines
|
586
586
|
# might be accessing the credentials at the same time.
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
except RuntimeError:
|
591
|
-
asyncio.set_event_loop(asyncio.new_event_loop())
|
592
|
-
self._sync_auth_lock = threading.Lock()
|
593
|
-
self._async_auth_lock = asyncio.Lock()
|
587
|
+
self._sync_auth_lock = threading.Lock()
|
588
|
+
self._async_auth_lock: Optional[asyncio.Lock] = None
|
589
|
+
self._async_auth_lock_creation_lock: Optional[asyncio.Lock] = None
|
594
590
|
|
595
591
|
# Handle when to use Vertex AI in express mode (api key).
|
596
592
|
# Explicit initializer arguments are already validated above.
|
@@ -903,10 +899,36 @@ class BaseApiClient:
|
|
903
899
|
else:
|
904
900
|
raise RuntimeError('Could not resolve API token from the environment')
|
905
901
|
|
902
|
+
async def _get_async_auth_lock(self) -> asyncio.Lock:
|
903
|
+
"""Lazily initializes and returns an asyncio.Lock for async authentication.
|
904
|
+
|
905
|
+
This method ensures that a single `asyncio.Lock` instance is created and
|
906
|
+
shared among all asynchronous operations that require authentication,
|
907
|
+
preventing race conditions when accessing or refreshing credentials.
|
908
|
+
|
909
|
+
The lock is created on the first call to this method. An internal async lock
|
910
|
+
is used to protect the creation of the main authentication lock to ensure
|
911
|
+
it's a singleton within the client instance.
|
912
|
+
|
913
|
+
Returns:
|
914
|
+
The asyncio.Lock instance for asynchronous authentication operations.
|
915
|
+
"""
|
916
|
+
if self._async_auth_lock is None:
|
917
|
+
# Create async creation lock if needed
|
918
|
+
if self._async_auth_lock_creation_lock is None:
|
919
|
+
self._async_auth_lock_creation_lock = asyncio.Lock()
|
920
|
+
|
921
|
+
async with self._async_auth_lock_creation_lock:
|
922
|
+
if self._async_auth_lock is None:
|
923
|
+
self._async_auth_lock = asyncio.Lock()
|
924
|
+
|
925
|
+
return self._async_auth_lock
|
926
|
+
|
906
927
|
async def _async_access_token(self) -> Union[str, Any]:
|
907
928
|
"""Retrieves the access token for the credentials asynchronously."""
|
908
929
|
if not self._credentials:
|
909
|
-
|
930
|
+
async_auth_lock = await self._get_async_auth_lock()
|
931
|
+
async with async_auth_lock:
|
910
932
|
# This ensures that only one coroutine can execute the auth logic at a
|
911
933
|
# time for thread safety.
|
912
934
|
if not self._credentials:
|
@@ -920,7 +942,8 @@ class BaseApiClient:
|
|
920
942
|
if self._credentials:
|
921
943
|
if self._credentials.expired or not self._credentials.token:
|
922
944
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
923
|
-
|
945
|
+
async_auth_lock = await self._get_async_auth_lock()
|
946
|
+
async with async_auth_lock:
|
924
947
|
if self._credentials.expired or not self._credentials.token:
|
925
948
|
# Double check that the credentials expired before refreshing.
|
926
949
|
await asyncio.to_thread(refresh_auth, self._credentials)
|
google/genai/_common.py
CHANGED
@@ -100,7 +100,11 @@ def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: An
|
|
100
100
|
f' Existing value: {existing_data}; New value: {value}.'
|
101
101
|
)
|
102
102
|
else:
|
103
|
-
|
103
|
+
if (keys[-1] == '_self' and isinstance(data, dict)
|
104
|
+
and isinstance(value, dict)):
|
105
|
+
data.update(value)
|
106
|
+
else:
|
107
|
+
data[keys[-1]] = value
|
104
108
|
|
105
109
|
|
106
110
|
def get_value_by_path(data: Any, keys: list[str]) -> Any:
|
google/genai/_extra_utils.py
CHANGED
@@ -90,14 +90,12 @@ def _get_bigquery_uri(
|
|
90
90
|
|
91
91
|
|
92
92
|
def format_destination(
|
93
|
-
src: Union[str, types.
|
94
|
-
config: Optional[types.
|
93
|
+
src: Union[str, types.BatchJobSource],
|
94
|
+
config: Optional[types.CreateBatchJobConfig] = None,
|
95
95
|
) -> types.CreateBatchJobConfig:
|
96
96
|
"""Formats the destination uri based on the source uri for Vertex AI."""
|
97
|
-
config
|
98
|
-
|
99
|
-
or types.CreateBatchJobConfig()
|
100
|
-
)
|
97
|
+
if config is None:
|
98
|
+
config = types.CreateBatchJobConfig()
|
101
99
|
|
102
100
|
unique_name = None
|
103
101
|
if not config.display_name:
|
@@ -113,8 +111,7 @@ def format_destination(
|
|
113
111
|
elif bigquery_source_uri:
|
114
112
|
unique_name = unique_name or _common.timestamped_unique_name()
|
115
113
|
config.dest = f'{bigquery_source_uri}_dest_{unique_name}'
|
116
|
-
|
117
|
-
raise ValueError(f'The source {src} is not supported.')
|
114
|
+
|
118
115
|
return config
|
119
116
|
|
120
117
|
|
google/genai/_live_converters.py
CHANGED
@@ -165,6 +165,23 @@ def _FileData_to_mldev(
|
|
165
165
|
return to_object
|
166
166
|
|
167
167
|
|
168
|
+
def _FunctionCall_to_mldev(
|
169
|
+
from_object: Union[dict[str, Any], object],
|
170
|
+
parent_object: Optional[dict[str, Any]] = None,
|
171
|
+
) -> dict[str, Any]:
|
172
|
+
to_object: dict[str, Any] = {}
|
173
|
+
if getv(from_object, ['id']) is not None:
|
174
|
+
setv(to_object, ['id'], getv(from_object, ['id']))
|
175
|
+
|
176
|
+
if getv(from_object, ['args']) is not None:
|
177
|
+
setv(to_object, ['args'], getv(from_object, ['args']))
|
178
|
+
|
179
|
+
if getv(from_object, ['name']) is not None:
|
180
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
181
|
+
|
182
|
+
return to_object
|
183
|
+
|
184
|
+
|
168
185
|
def _Part_to_mldev(
|
169
186
|
from_object: Union[dict[str, Any], object],
|
170
187
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -203,6 +220,13 @@ def _Part_to_mldev(
|
|
203
220
|
getv(from_object, ['thought_signature']),
|
204
221
|
)
|
205
222
|
|
223
|
+
if getv(from_object, ['function_call']) is not None:
|
224
|
+
setv(
|
225
|
+
to_object,
|
226
|
+
['functionCall'],
|
227
|
+
_FunctionCall_to_mldev(getv(from_object, ['function_call']), to_object),
|
228
|
+
)
|
229
|
+
|
206
230
|
if getv(from_object, ['code_execution_result']) is not None:
|
207
231
|
setv(
|
208
232
|
to_object,
|
@@ -213,9 +237,6 @@ def _Part_to_mldev(
|
|
213
237
|
if getv(from_object, ['executable_code']) is not None:
|
214
238
|
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
|
215
239
|
|
216
|
-
if getv(from_object, ['function_call']) is not None:
|
217
|
-
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
|
218
|
-
|
219
240
|
if getv(from_object, ['function_response']) is not None:
|
220
241
|
setv(
|
221
242
|
to_object,
|
@@ -1317,6 +1338,23 @@ def _FileData_to_vertex(
|
|
1317
1338
|
return to_object
|
1318
1339
|
|
1319
1340
|
|
1341
|
+
def _FunctionCall_to_vertex(
|
1342
|
+
from_object: Union[dict[str, Any], object],
|
1343
|
+
parent_object: Optional[dict[str, Any]] = None,
|
1344
|
+
) -> dict[str, Any]:
|
1345
|
+
to_object: dict[str, Any] = {}
|
1346
|
+
if getv(from_object, ['id']) is not None:
|
1347
|
+
raise ValueError('id parameter is not supported in Vertex AI.')
|
1348
|
+
|
1349
|
+
if getv(from_object, ['args']) is not None:
|
1350
|
+
setv(to_object, ['args'], getv(from_object, ['args']))
|
1351
|
+
|
1352
|
+
if getv(from_object, ['name']) is not None:
|
1353
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
1354
|
+
|
1355
|
+
return to_object
|
1356
|
+
|
1357
|
+
|
1320
1358
|
def _Part_to_vertex(
|
1321
1359
|
from_object: Union[dict[str, Any], object],
|
1322
1360
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -1355,6 +1393,15 @@ def _Part_to_vertex(
|
|
1355
1393
|
getv(from_object, ['thought_signature']),
|
1356
1394
|
)
|
1357
1395
|
|
1396
|
+
if getv(from_object, ['function_call']) is not None:
|
1397
|
+
setv(
|
1398
|
+
to_object,
|
1399
|
+
['functionCall'],
|
1400
|
+
_FunctionCall_to_vertex(
|
1401
|
+
getv(from_object, ['function_call']), to_object
|
1402
|
+
),
|
1403
|
+
)
|
1404
|
+
|
1358
1405
|
if getv(from_object, ['code_execution_result']) is not None:
|
1359
1406
|
setv(
|
1360
1407
|
to_object,
|
@@ -1365,9 +1412,6 @@ def _Part_to_vertex(
|
|
1365
1412
|
if getv(from_object, ['executable_code']) is not None:
|
1366
1413
|
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
|
1367
1414
|
|
1368
|
-
if getv(from_object, ['function_call']) is not None:
|
1369
|
-
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
|
1370
|
-
|
1371
1415
|
if getv(from_object, ['function_response']) is not None:
|
1372
1416
|
setv(
|
1373
1417
|
to_object,
|
@@ -2394,6 +2438,23 @@ def _FileData_from_mldev(
|
|
2394
2438
|
return to_object
|
2395
2439
|
|
2396
2440
|
|
2441
|
+
def _FunctionCall_from_mldev(
|
2442
|
+
from_object: Union[dict[str, Any], object],
|
2443
|
+
parent_object: Optional[dict[str, Any]] = None,
|
2444
|
+
) -> dict[str, Any]:
|
2445
|
+
to_object: dict[str, Any] = {}
|
2446
|
+
if getv(from_object, ['id']) is not None:
|
2447
|
+
setv(to_object, ['id'], getv(from_object, ['id']))
|
2448
|
+
|
2449
|
+
if getv(from_object, ['args']) is not None:
|
2450
|
+
setv(to_object, ['args'], getv(from_object, ['args']))
|
2451
|
+
|
2452
|
+
if getv(from_object, ['name']) is not None:
|
2453
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
2454
|
+
|
2455
|
+
return to_object
|
2456
|
+
|
2457
|
+
|
2397
2458
|
def _Part_from_mldev(
|
2398
2459
|
from_object: Union[dict[str, Any], object],
|
2399
2460
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -2432,6 +2493,15 @@ def _Part_from_mldev(
|
|
2432
2493
|
getv(from_object, ['thoughtSignature']),
|
2433
2494
|
)
|
2434
2495
|
|
2496
|
+
if getv(from_object, ['functionCall']) is not None:
|
2497
|
+
setv(
|
2498
|
+
to_object,
|
2499
|
+
['function_call'],
|
2500
|
+
_FunctionCall_from_mldev(
|
2501
|
+
getv(from_object, ['functionCall']), to_object
|
2502
|
+
),
|
2503
|
+
)
|
2504
|
+
|
2435
2505
|
if getv(from_object, ['codeExecutionResult']) is not None:
|
2436
2506
|
setv(
|
2437
2507
|
to_object,
|
@@ -2442,9 +2512,6 @@ def _Part_from_mldev(
|
|
2442
2512
|
if getv(from_object, ['executableCode']) is not None:
|
2443
2513
|
setv(to_object, ['executable_code'], getv(from_object, ['executableCode']))
|
2444
2514
|
|
2445
|
-
if getv(from_object, ['functionCall']) is not None:
|
2446
|
-
setv(to_object, ['function_call'], getv(from_object, ['functionCall']))
|
2447
|
-
|
2448
2515
|
if getv(from_object, ['functionResponse']) is not None:
|
2449
2516
|
setv(
|
2450
2517
|
to_object,
|
@@ -2591,23 +2658,6 @@ def _LiveServerContent_from_mldev(
|
|
2591
2658
|
return to_object
|
2592
2659
|
|
2593
2660
|
|
2594
|
-
def _FunctionCall_from_mldev(
|
2595
|
-
from_object: Union[dict[str, Any], object],
|
2596
|
-
parent_object: Optional[dict[str, Any]] = None,
|
2597
|
-
) -> dict[str, Any]:
|
2598
|
-
to_object: dict[str, Any] = {}
|
2599
|
-
if getv(from_object, ['id']) is not None:
|
2600
|
-
setv(to_object, ['id'], getv(from_object, ['id']))
|
2601
|
-
|
2602
|
-
if getv(from_object, ['args']) is not None:
|
2603
|
-
setv(to_object, ['args'], getv(from_object, ['args']))
|
2604
|
-
|
2605
|
-
if getv(from_object, ['name']) is not None:
|
2606
|
-
setv(to_object, ['name'], getv(from_object, ['name']))
|
2607
|
-
|
2608
|
-
return to_object
|
2609
|
-
|
2610
|
-
|
2611
2661
|
def _LiveServerToolCall_from_mldev(
|
2612
2662
|
from_object: Union[dict[str, Any], object],
|
2613
2663
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -3111,6 +3161,21 @@ def _FileData_from_vertex(
|
|
3111
3161
|
return to_object
|
3112
3162
|
|
3113
3163
|
|
3164
|
+
def _FunctionCall_from_vertex(
|
3165
|
+
from_object: Union[dict[str, Any], object],
|
3166
|
+
parent_object: Optional[dict[str, Any]] = None,
|
3167
|
+
) -> dict[str, Any]:
|
3168
|
+
to_object: dict[str, Any] = {}
|
3169
|
+
|
3170
|
+
if getv(from_object, ['args']) is not None:
|
3171
|
+
setv(to_object, ['args'], getv(from_object, ['args']))
|
3172
|
+
|
3173
|
+
if getv(from_object, ['name']) is not None:
|
3174
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
3175
|
+
|
3176
|
+
return to_object
|
3177
|
+
|
3178
|
+
|
3114
3179
|
def _Part_from_vertex(
|
3115
3180
|
from_object: Union[dict[str, Any], object],
|
3116
3181
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -3149,6 +3214,15 @@ def _Part_from_vertex(
|
|
3149
3214
|
getv(from_object, ['thoughtSignature']),
|
3150
3215
|
)
|
3151
3216
|
|
3217
|
+
if getv(from_object, ['functionCall']) is not None:
|
3218
|
+
setv(
|
3219
|
+
to_object,
|
3220
|
+
['function_call'],
|
3221
|
+
_FunctionCall_from_vertex(
|
3222
|
+
getv(from_object, ['functionCall']), to_object
|
3223
|
+
),
|
3224
|
+
)
|
3225
|
+
|
3152
3226
|
if getv(from_object, ['codeExecutionResult']) is not None:
|
3153
3227
|
setv(
|
3154
3228
|
to_object,
|
@@ -3159,9 +3233,6 @@ def _Part_from_vertex(
|
|
3159
3233
|
if getv(from_object, ['executableCode']) is not None:
|
3160
3234
|
setv(to_object, ['executable_code'], getv(from_object, ['executableCode']))
|
3161
3235
|
|
3162
|
-
if getv(from_object, ['functionCall']) is not None:
|
3163
|
-
setv(to_object, ['function_call'], getv(from_object, ['functionCall']))
|
3164
|
-
|
3165
3236
|
if getv(from_object, ['functionResponse']) is not None:
|
3166
3237
|
setv(
|
3167
3238
|
to_object,
|
@@ -3263,21 +3334,6 @@ def _LiveServerContent_from_vertex(
|
|
3263
3334
|
return to_object
|
3264
3335
|
|
3265
3336
|
|
3266
|
-
def _FunctionCall_from_vertex(
|
3267
|
-
from_object: Union[dict[str, Any], object],
|
3268
|
-
parent_object: Optional[dict[str, Any]] = None,
|
3269
|
-
) -> dict[str, Any]:
|
3270
|
-
to_object: dict[str, Any] = {}
|
3271
|
-
|
3272
|
-
if getv(from_object, ['args']) is not None:
|
3273
|
-
setv(to_object, ['args'], getv(from_object, ['args']))
|
3274
|
-
|
3275
|
-
if getv(from_object, ['name']) is not None:
|
3276
|
-
setv(to_object, ['name'], getv(from_object, ['name']))
|
3277
|
-
|
3278
|
-
return to_object
|
3279
|
-
|
3280
|
-
|
3281
3337
|
def _LiveServerToolCall_from_vertex(
|
3282
3338
|
from_object: Union[dict[str, Any], object],
|
3283
3339
|
parent_object: Optional[dict[str, Any]] = None,
|
google/genai/_replay_api_client.py
CHANGED
@@ -471,6 +471,21 @@ class ReplayApiClient(BaseApiClient):
|
|
471
471
|
expected = interaction.response.sdk_response_segments[
|
472
472
|
self._sdk_response_index
|
473
473
|
]
|
474
|
+
# The sdk_http_response.body has format in the string, need to get rid of
|
475
|
+
# the format information before comparing.
|
476
|
+
if isinstance(expected, dict):
|
477
|
+
if 'sdk_http_response' in expected and isinstance(
|
478
|
+
expected['sdk_http_response'], dict
|
479
|
+
):
|
480
|
+
if 'body' in expected['sdk_http_response']:
|
481
|
+
raw_body = expected['sdk_http_response']['body']
|
482
|
+
print('raw_body length: ', len(raw_body))
|
483
|
+
print('raw_body: ', raw_body)
|
484
|
+
if isinstance(raw_body, str) and raw_body != '':
|
485
|
+
raw_body = json.loads(raw_body)
|
486
|
+
raw_body = json.dumps(raw_body)
|
487
|
+
expected['sdk_http_response']['body'] = raw_body
|
488
|
+
|
474
489
|
assert (
|
475
490
|
actual == expected
|
476
491
|
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
|
google/genai/_tokens_converters.py
CHANGED
@@ -165,6 +165,23 @@ def _FileData_to_mldev(
|
|
165
165
|
return to_object
|
166
166
|
|
167
167
|
|
168
|
+
def _FunctionCall_to_mldev(
|
169
|
+
from_object: Union[dict[str, Any], object],
|
170
|
+
parent_object: Optional[dict[str, Any]] = None,
|
171
|
+
) -> dict[str, Any]:
|
172
|
+
to_object: dict[str, Any] = {}
|
173
|
+
if getv(from_object, ['id']) is not None:
|
174
|
+
setv(to_object, ['id'], getv(from_object, ['id']))
|
175
|
+
|
176
|
+
if getv(from_object, ['args']) is not None:
|
177
|
+
setv(to_object, ['args'], getv(from_object, ['args']))
|
178
|
+
|
179
|
+
if getv(from_object, ['name']) is not None:
|
180
|
+
setv(to_object, ['name'], getv(from_object, ['name']))
|
181
|
+
|
182
|
+
return to_object
|
183
|
+
|
184
|
+
|
168
185
|
def _Part_to_mldev(
|
169
186
|
from_object: Union[dict[str, Any], object],
|
170
187
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -203,6 +220,13 @@ def _Part_to_mldev(
|
|
203
220
|
getv(from_object, ['thought_signature']),
|
204
221
|
)
|
205
222
|
|
223
|
+
if getv(from_object, ['function_call']) is not None:
|
224
|
+
setv(
|
225
|
+
to_object,
|
226
|
+
['functionCall'],
|
227
|
+
_FunctionCall_to_mldev(getv(from_object, ['function_call']), to_object),
|
228
|
+
)
|
229
|
+
|
206
230
|
if getv(from_object, ['code_execution_result']) is not None:
|
207
231
|
setv(
|
208
232
|
to_object,
|
@@ -213,9 +237,6 @@ def _Part_to_mldev(
|
|
213
237
|
if getv(from_object, ['executable_code']) is not None:
|
214
238
|
setv(to_object, ['executableCode'], getv(from_object, ['executable_code']))
|
215
239
|
|
216
|
-
if getv(from_object, ['function_call']) is not None:
|
217
|
-
setv(to_object, ['functionCall'], getv(from_object, ['function_call']))
|
218
|
-
|
219
240
|
if getv(from_object, ['function_response']) is not None:
|
220
241
|
setv(
|
221
242
|
to_object,
|
google/genai/_transformers.py
CHANGED
@@ -964,30 +964,30 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
|
|
964
964
|
|
965
965
|
def t_batch_job_source(
|
966
966
|
client: _api_client.BaseApiClient,
|
967
|
-
src:
|
968
|
-
str, List[types.InlinedRequestOrDict], types.BatchJobSourceOrDict
|
969
|
-
],
|
967
|
+
src: types.BatchJobSourceUnionDict,
|
970
968
|
) -> types.BatchJobSource:
|
971
969
|
if isinstance(src, dict):
|
972
970
|
src = types.BatchJobSource(**src)
|
973
971
|
if isinstance(src, types.BatchJobSource):
|
972
|
+
vertex_sources = sum(
|
973
|
+
[src.gcs_uri is not None, src.bigquery_uri is not None]
|
974
|
+
)
|
975
|
+
mldev_sources = sum([
|
976
|
+
src.inlined_requests is not None,
|
977
|
+
src.file_name is not None,
|
978
|
+
])
|
974
979
|
if client.vertexai:
|
975
|
-
if
|
980
|
+
if mldev_sources or vertex_sources != 1:
|
976
981
|
raise ValueError(
|
977
|
-
'
|
978
|
-
|
979
|
-
elif not src.gcs_uri and not src.bigquery_uri:
|
980
|
-
raise ValueError(
|
981
|
-
'One of `gcs_uri` or `bigquery_uri` must be set.'
|
982
|
+
'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
|
983
|
+
'sources are not supported in Vertex AI.'
|
982
984
|
)
|
983
985
|
else:
|
984
|
-
if
|
986
|
+
if vertex_sources or mldev_sources != 1:
|
985
987
|
raise ValueError(
|
986
|
-
'
|
987
|
-
|
988
|
-
|
989
|
-
raise ValueError(
|
990
|
-
'One of `inlined_requests` or `file_name` must be set.'
|
988
|
+
'Exactly one of `inlined_requests`, `file_name`, '
|
989
|
+
'`inlined_embed_content_requests`, or `embed_content_file_name` '
|
990
|
+
'must be set, other sources are not supported in Gemini API.'
|
991
991
|
)
|
992
992
|
return src
|
993
993
|
|
@@ -1012,6 +1012,29 @@ def t_batch_job_source(
|
|
1012
1012
|
raise ValueError(f'Unsupported source: {src}')
|
1013
1013
|
|
1014
1014
|
|
1015
|
+
def t_embedding_batch_job_source(
|
1016
|
+
client: _api_client.BaseApiClient,
|
1017
|
+
src: types.EmbeddingsBatchJobSourceOrDict,
|
1018
|
+
) -> types.EmbeddingsBatchJobSource:
|
1019
|
+
if isinstance(src, dict):
|
1020
|
+
src = types.EmbeddingsBatchJobSource(**src)
|
1021
|
+
|
1022
|
+
if isinstance(src, types.EmbeddingsBatchJobSource):
|
1023
|
+
mldev_sources = sum([
|
1024
|
+
src.inlined_requests is not None,
|
1025
|
+
src.file_name is not None,
|
1026
|
+
])
|
1027
|
+
if mldev_sources != 1:
|
1028
|
+
raise ValueError(
|
1029
|
+
'Exactly one of `inlined_requests`, `file_name`, '
|
1030
|
+
'`inlined_embed_content_requests`, or `embed_content_file_name` '
|
1031
|
+
'must be set, other sources are not supported in Gemini API.'
|
1032
|
+
)
|
1033
|
+
return src
|
1034
|
+
else:
|
1035
|
+
raise ValueError(f'Unsupported source type: {type(src)}')
|
1036
|
+
|
1037
|
+
|
1015
1038
|
def t_batch_job_destination(
|
1016
1039
|
dest: Union[str, types.BatchJobDestinationOrDict],
|
1017
1040
|
) -> types.BatchJobDestination:
|
@@ -1037,6 +1060,23 @@ def t_batch_job_destination(
|
|
1037
1060
|
raise ValueError(f'Unsupported destination: {dest}')
|
1038
1061
|
|
1039
1062
|
|
1063
|
+
def t_recv_batch_job_destination(dest: dict[str, Any]) -> dict[str, Any]:
|
1064
|
+
# Rename inlinedResponses if it looks like an embedding response.
|
1065
|
+
inline_responses = dest.get('inlinedResponses', {}).get(
|
1066
|
+
'inlinedResponses', []
|
1067
|
+
)
|
1068
|
+
if not inline_responses:
|
1069
|
+
return dest
|
1070
|
+
for response in inline_responses:
|
1071
|
+
inner_response = response.get('response', {})
|
1072
|
+
if not inner_response:
|
1073
|
+
continue
|
1074
|
+
if 'embedding' in inner_response:
|
1075
|
+
dest['inlinedEmbedContentResponses'] = dest.pop('inlinedResponses')
|
1076
|
+
break
|
1077
|
+
return dest
|
1078
|
+
|
1079
|
+
|
1040
1080
|
def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
|
1041
1081
|
if not client.vertexai:
|
1042
1082
|
mldev_pattern = r'batches/[^/]+$'
|