google-genai 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +2 -0
- google/genai/_api_client.py +14 -6
- google/genai/_automatic_function_calling_util.py +0 -44
- google/genai/_extra_utils.py +15 -0
- google/genai/_transformers.py +3 -2
- google/genai/batches.py +254 -4
- google/genai/caches.py +10 -0
- google/genai/chats.py +14 -2
- google/genai/files.py +6 -0
- google/genai/live.py +74 -42
- google/genai/models.py +110 -11
- google/genai/tunings.py +317 -4
- google/genai/types.py +482 -85
- {google_genai-0.0.1.dist-info → google_genai-0.2.0.dist-info}/METADATA +75 -58
- google_genai-0.2.0.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/RECORD +0 -24
- {google_genai-0.0.1.dist-info → google_genai-0.2.0.dist-info}/LICENSE +0 -0
- {google_genai-0.0.1.dist-info → google_genai-0.2.0.dist-info}/WHEEL +0 -0
- {google_genai-0.0.1.dist-info → google_genai-0.2.0.dist-info}/top_level.txt +0 -0
google/genai/tunings.py
CHANGED
@@ -24,6 +24,30 @@ from ._common import set_value_by_path as setv
|
|
24
24
|
from .pagers import AsyncPager, Pager
|
25
25
|
|
26
26
|
|
27
|
+
def _GetTuningJobConfig_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a GetTuningJobConfig into the mldev request representation.

  Only `http_options` is carried over, stored under the camelCase key
  `httpOptions`. `api_client` and `parent_object` are accepted for signature
  parity with the other converters but are not used here.

  Args:
    api_client: The API client (unused).
    from_object: The config, as a dict or model object.
    parent_object: The enclosing request object (unused).

  Returns:
    A new dict holding the converted fields.
  """
  converted: dict = {}
  http_options = getv(from_object, ['http_options'])
  if http_options is not None:
    setv(converted, ['httpOptions'], http_options)
  return converted
|
37
|
+
|
38
|
+
|
39
|
+
def _GetTuningJobConfig_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a GetTuningJobConfig into the Vertex request representation.

  Only `http_options` is carried over, stored under the camelCase key
  `httpOptions`. `api_client` and `parent_object` are accepted for signature
  parity with the other converters but are not used here.

  Args:
    api_client: The API client (unused).
    from_object: The config, as a dict or model object.
    parent_object: The enclosing request object (unused).

  Returns:
    A new dict holding the converted fields.
  """
  converted: dict = {}
  http_options = getv(from_object, ['http_options'])
  if http_options is not None:
    setv(converted, ['httpOptions'], http_options)
  return converted
|
49
|
+
|
50
|
+
|
27
51
|
def _GetTuningJobParameters_to_mldev(
|
28
52
|
api_client: ApiClient,
|
29
53
|
from_object: Union[dict, object],
|
@@ -33,6 +57,15 @@ def _GetTuningJobParameters_to_mldev(
|
|
33
57
|
if getv(from_object, ['name']) is not None:
|
34
58
|
setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
|
35
59
|
|
60
|
+
if getv(from_object, ['config']) is not None:
|
61
|
+
setv(
|
62
|
+
to_object,
|
63
|
+
['config'],
|
64
|
+
_GetTuningJobConfig_to_mldev(
|
65
|
+
api_client, getv(from_object, ['config']), to_object
|
66
|
+
),
|
67
|
+
)
|
68
|
+
|
36
69
|
return to_object
|
37
70
|
|
38
71
|
|
@@ -45,6 +78,15 @@ def _GetTuningJobParameters_to_vertex(
|
|
45
78
|
if getv(from_object, ['name']) is not None:
|
46
79
|
setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
|
47
80
|
|
81
|
+
if getv(from_object, ['config']) is not None:
|
82
|
+
setv(
|
83
|
+
to_object,
|
84
|
+
['config'],
|
85
|
+
_GetTuningJobConfig_to_vertex(
|
86
|
+
api_client, getv(from_object, ['config']), to_object
|
87
|
+
),
|
88
|
+
)
|
89
|
+
|
48
90
|
return to_object
|
49
91
|
|
50
92
|
|
@@ -233,6 +275,9 @@ def _CreateTuningJobConfig_to_mldev(
|
|
233
275
|
parent_object: dict = None,
|
234
276
|
) -> dict:
|
235
277
|
to_object = {}
|
278
|
+
if getv(from_object, ['http_options']) is not None:
|
279
|
+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
280
|
+
|
236
281
|
if getv(from_object, ['validation_dataset']):
|
237
282
|
raise ValueError(
|
238
283
|
'validation_dataset parameter is not supported in Google AI.'
|
@@ -288,6 +333,9 @@ def _CreateTuningJobConfig_to_vertex(
|
|
288
333
|
parent_object: dict = None,
|
289
334
|
) -> dict:
|
290
335
|
to_object = {}
|
336
|
+
if getv(from_object, ['http_options']) is not None:
|
337
|
+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
338
|
+
|
291
339
|
if getv(from_object, ['validation_dataset']) is not None:
|
292
340
|
setv(
|
293
341
|
parent_object,
|
@@ -455,6 +503,9 @@ def _CreateDistillationJobConfig_to_mldev(
|
|
455
503
|
parent_object: dict = None,
|
456
504
|
) -> dict:
|
457
505
|
to_object = {}
|
506
|
+
if getv(from_object, ['http_options']) is not None:
|
507
|
+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
508
|
+
|
458
509
|
if getv(from_object, ['validation_dataset']):
|
459
510
|
raise ValueError(
|
460
511
|
'validation_dataset parameter is not supported in Google AI.'
|
@@ -498,6 +549,9 @@ def _CreateDistillationJobConfig_to_vertex(
|
|
498
549
|
parent_object: dict = None,
|
499
550
|
) -> dict:
|
500
551
|
to_object = {}
|
552
|
+
if getv(from_object, ['http_options']) is not None:
|
553
|
+
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
554
|
+
|
501
555
|
if getv(from_object, ['validation_dataset']) is not None:
|
502
556
|
setv(
|
503
557
|
parent_object,
|
@@ -896,7 +950,12 @@ def _TuningJobOrOperation_from_vertex(
|
|
896
950
|
|
897
951
|
class Tunings(_common.BaseModule):
|
898
952
|
|
899
|
-
def
|
953
|
+
def _get(
|
954
|
+
self,
|
955
|
+
*,
|
956
|
+
name: str,
|
957
|
+
config: Optional[types.GetTuningJobConfigOrDict] = None,
|
958
|
+
) -> types.TuningJob:
|
900
959
|
"""Gets a TuningJob.
|
901
960
|
|
902
961
|
Args:
|
@@ -908,6 +967,7 @@ class Tunings(_common.BaseModule):
|
|
908
967
|
|
909
968
|
parameter_model = types._GetTuningJobParameters(
|
910
969
|
name=name,
|
970
|
+
config=config,
|
911
971
|
)
|
912
972
|
|
913
973
|
if self.api_client.vertexai:
|
@@ -998,7 +1058,7 @@ class Tunings(_common.BaseModule):
|
|
998
1058
|
self.api_client._verify_response(return_value)
|
999
1059
|
return return_value
|
1000
1060
|
|
1001
|
-
def
|
1061
|
+
def _tune(
|
1002
1062
|
self,
|
1003
1063
|
*,
|
1004
1064
|
base_model: str,
|
@@ -1129,10 +1189,45 @@ class Tunings(_common.BaseModule):
|
|
1129
1189
|
config,
|
1130
1190
|
)
|
1131
1191
|
|
1192
|
+
def get(
    self,
    *,
    name: str,
    config: Optional[types.GetTuningJobConfigOrDict] = None,
) -> types.TuningJob:
  """Gets a TuningJob, showing an experiment link in notebooks on Vertex AI.

  Args:
    name: Name of the tuning job to fetch.
    config: Optional request configuration.

  Returns:
    The TuningJob returned by `_get`.
  """
  tuning_job = self._get(name=name, config=config)
  # The experiment button only applies to Vertex AI jobs that have one.
  if not tuning_job.experiment or not self.api_client.vertexai:
    return tuning_job
  _IpythonUtils.display_experiment_button(
      experiment=tuning_job.experiment,
      project=self.api_client.project,
  )
  return tuning_job
|
1205
|
+
|
1206
|
+
def tune(
    self,
    *,
    base_model: str,
    training_dataset: types.TuningDatasetOrDict,
    config: Optional[types.CreateTuningJobConfigOrDict] = None,
) -> types.TuningJobOrOperation:
  """Starts a tuning job; on Vertex AI, also shows a console link to it.

  Args:
    base_model: The model to tune.
    training_dataset: The dataset used for tuning.
    config: Optional tuning job configuration.

  Returns:
    The TuningJobOrOperation returned by `_tune`.
  """
  tuning_job = self._tune(
      base_model=base_model,
      training_dataset=training_dataset,
      config=config,
  )
  # A console link can only be built for a named job on Vertex AI.
  if not tuning_job.name or not self.api_client.vertexai:
    return tuning_job
  _IpythonUtils.display_model_tuning_button(tuning_job_resource=tuning_job.name)
  return tuning_job
|
1221
|
+
|
1132
1222
|
|
1133
1223
|
class AsyncTunings(_common.BaseModule):
|
1134
1224
|
|
1135
|
-
async def
|
1225
|
+
async def _get(
|
1226
|
+
self,
|
1227
|
+
*,
|
1228
|
+
name: str,
|
1229
|
+
config: Optional[types.GetTuningJobConfigOrDict] = None,
|
1230
|
+
) -> types.TuningJob:
|
1136
1231
|
"""Gets a TuningJob.
|
1137
1232
|
|
1138
1233
|
Args:
|
@@ -1144,6 +1239,7 @@ class AsyncTunings(_common.BaseModule):
|
|
1144
1239
|
|
1145
1240
|
parameter_model = types._GetTuningJobParameters(
|
1146
1241
|
name=name,
|
1242
|
+
config=config,
|
1147
1243
|
)
|
1148
1244
|
|
1149
1245
|
if self.api_client.vertexai:
|
@@ -1234,7 +1330,7 @@ class AsyncTunings(_common.BaseModule):
|
|
1234
1330
|
self.api_client._verify_response(return_value)
|
1235
1331
|
return return_value
|
1236
1332
|
|
1237
|
-
async def
|
1333
|
+
async def _tune(
|
1238
1334
|
self,
|
1239
1335
|
*,
|
1240
1336
|
base_model: str,
|
@@ -1364,3 +1460,220 @@ class AsyncTunings(_common.BaseModule):
|
|
1364
1460
|
await self._list(config=config),
|
1365
1461
|
config,
|
1366
1462
|
)
|
1463
|
+
|
1464
|
+
async def get(
    self,
    *,
    name: str,
    config: Optional[types.GetTuningJobConfigOrDict] = None,
) -> types.TuningJob:
  """Asynchronously gets a TuningJob, showing an experiment link on Vertex AI.

  Args:
    name: Name of the tuning job to fetch.
    config: Optional request configuration.

  Returns:
    The TuningJob returned by `_get`.
  """
  tuning_job = await self._get(name=name, config=config)
  # The experiment button only applies to Vertex AI jobs that have one.
  if not tuning_job.experiment or not self.api_client.vertexai:
    return tuning_job
  _IpythonUtils.display_experiment_button(
      experiment=tuning_job.experiment,
      project=self.api_client.project,
  )
  return tuning_job
|
1477
|
+
|
1478
|
+
async def tune(
    self,
    *,
    base_model: str,
    training_dataset: types.TuningDatasetOrDict,
    config: Optional[types.CreateTuningJobConfigOrDict] = None,
) -> types.TuningJobOrOperation:
  """Asynchronously starts a tuning job; on Vertex AI, shows a console link.

  Args:
    base_model: The model to tune.
    training_dataset: The dataset used for tuning.
    config: Optional tuning job configuration.

  Returns:
    The TuningJobOrOperation returned by `_tune`.
  """
  tuning_job = await self._tune(
      base_model=base_model,
      training_dataset=training_dataset,
      config=config,
  )
  # A console link can only be built for a named job on Vertex AI.
  if not tuning_job.name or not self.api_client.vertexai:
    return tuning_job
  _IpythonUtils.display_model_tuning_button(tuning_job_resource=tuning_job.name)
  return tuning_job
|
1493
|
+
|
1494
|
+
|
1495
|
+
class _IpythonUtils:
|
1496
|
+
"""Temporary class to hold the IPython related functions."""
|
1497
|
+
|
1498
|
+
displayed_experiments = set()
|
1499
|
+
|
1500
|
+
@staticmethod
|
1501
|
+
def _get_ipython_shell_name() -> str:
|
1502
|
+
import sys
|
1503
|
+
|
1504
|
+
if 'IPython' in sys.modules:
|
1505
|
+
from IPython import get_ipython
|
1506
|
+
|
1507
|
+
return get_ipython().__class__.__name__
|
1508
|
+
return ''
|
1509
|
+
|
1510
|
+
@staticmethod
|
1511
|
+
def is_ipython_available() -> bool:
|
1512
|
+
return bool(_IpythonUtils._get_ipython_shell_name())
|
1513
|
+
|
1514
|
+
@staticmethod
|
1515
|
+
def _get_styles() -> None:
|
1516
|
+
"""Returns the HTML style markup to support custom buttons."""
|
1517
|
+
return """
|
1518
|
+
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
1519
|
+
<style>
|
1520
|
+
.view-vertex-resource,
|
1521
|
+
.view-vertex-resource:hover,
|
1522
|
+
.view-vertex-resource:visited {
|
1523
|
+
position: relative;
|
1524
|
+
display: inline-flex;
|
1525
|
+
flex-direction: row;
|
1526
|
+
height: 32px;
|
1527
|
+
padding: 0 12px;
|
1528
|
+
margin: 4px 18px;
|
1529
|
+
gap: 4px;
|
1530
|
+
border-radius: 4px;
|
1531
|
+
|
1532
|
+
align-items: center;
|
1533
|
+
justify-content: center;
|
1534
|
+
background-color: rgb(255, 255, 255);
|
1535
|
+
color: rgb(51, 103, 214);
|
1536
|
+
|
1537
|
+
font-family: Roboto,"Helvetica Neue",sans-serif;
|
1538
|
+
font-size: 13px;
|
1539
|
+
font-weight: 500;
|
1540
|
+
text-transform: uppercase;
|
1541
|
+
text-decoration: none !important;
|
1542
|
+
|
1543
|
+
transition: box-shadow 280ms cubic-bezier(0.4, 0, 0.2, 1) 0s;
|
1544
|
+
box-shadow: 0px 3px 1px -2px rgba(0,0,0,0.2), 0px 2px 2px 0px rgba(0,0,0,0.14), 0px 1px 5px 0px rgba(0,0,0,0.12);
|
1545
|
+
}
|
1546
|
+
.view-vertex-resource:active {
|
1547
|
+
box-shadow: 0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12);
|
1548
|
+
}
|
1549
|
+
.view-vertex-resource:active .view-vertex-ripple::before {
|
1550
|
+
position: absolute;
|
1551
|
+
top: 0;
|
1552
|
+
bottom: 0;
|
1553
|
+
left: 0;
|
1554
|
+
right: 0;
|
1555
|
+
border-radius: 4px;
|
1556
|
+
pointer-events: none;
|
1557
|
+
|
1558
|
+
content: '';
|
1559
|
+
background-color: rgb(51, 103, 214);
|
1560
|
+
opacity: 0.12;
|
1561
|
+
}
|
1562
|
+
.view-vertex-icon {
|
1563
|
+
font-size: 18px;
|
1564
|
+
}
|
1565
|
+
</style>
|
1566
|
+
"""
|
1567
|
+
|
1568
|
+
@staticmethod
|
1569
|
+
def _parse_resource_name(marker: str, resource_parts: list[str]) -> str:
|
1570
|
+
"""Returns the part after the marker text part."""
|
1571
|
+
for i in range(len(resource_parts)):
|
1572
|
+
if resource_parts[i] == marker and i + 1 < len(resource_parts):
|
1573
|
+
return resource_parts[i + 1]
|
1574
|
+
return ''
|
1575
|
+
|
1576
|
+
@staticmethod
|
1577
|
+
def _display_link(
|
1578
|
+
text: str, url: str, icon: Optional[str] = 'open_in_new'
|
1579
|
+
) -> None:
|
1580
|
+
"""Creates and displays the link to open the Vertex resource.
|
1581
|
+
|
1582
|
+
Args:
|
1583
|
+
text: The text displayed on the clickable button.
|
1584
|
+
url: The url that the button will lead to. Only cloud console URIs are
|
1585
|
+
allowed.
|
1586
|
+
icon: The icon name on the button (from material-icons library)
|
1587
|
+
"""
|
1588
|
+
CLOUD_UI_URL = 'https://console.cloud.google.com' # pylint: disable=invalid-name
|
1589
|
+
if not url.startswith(CLOUD_UI_URL):
|
1590
|
+
raise ValueError(f'Only urls starting with {CLOUD_UI_URL} are allowed.')
|
1591
|
+
|
1592
|
+
import uuid
|
1593
|
+
|
1594
|
+
button_id = f'view-vertex-resource-{str(uuid.uuid4())}'
|
1595
|
+
|
1596
|
+
# Add the markup for the CSS and link component
|
1597
|
+
html = f"""
|
1598
|
+
{_IpythonUtils._get_styles()}
|
1599
|
+
<a class="view-vertex-resource" id="{button_id}" href="#view-{button_id}">
|
1600
|
+
<span class="material-icons view-vertex-icon">{icon}</span>
|
1601
|
+
<span>{text}</span>
|
1602
|
+
</a>
|
1603
|
+
"""
|
1604
|
+
|
1605
|
+
# Add the click handler for the link
|
1606
|
+
html += f"""
|
1607
|
+
<script>
|
1608
|
+
(function () {{
|
1609
|
+
const link = document.getElementById('{button_id}');
|
1610
|
+
link.addEventListener('click', (e) => {{
|
1611
|
+
if (window.google?.colab?.openUrl) {{
|
1612
|
+
window.google.colab.openUrl('{url}');
|
1613
|
+
}} else {{
|
1614
|
+
window.open('{url}', '_blank');
|
1615
|
+
}}
|
1616
|
+
e.stopPropagation();
|
1617
|
+
e.preventDefault();
|
1618
|
+
}});
|
1619
|
+
}})();
|
1620
|
+
</script>
|
1621
|
+
"""
|
1622
|
+
|
1623
|
+
from IPython.core.display import display
|
1624
|
+
from IPython.display import HTML
|
1625
|
+
|
1626
|
+
display(HTML(html))
|
1627
|
+
|
1628
|
+
@staticmethod
|
1629
|
+
def display_experiment_button(experiment: str, project: str) -> None:
|
1630
|
+
"""Function to generate a link bound to the Vertex experiment.
|
1631
|
+
|
1632
|
+
Args:
|
1633
|
+
experiment: The Vertex experiment name. Example format:
|
1634
|
+
projects/{project_id}/locations/{location}/metadataStores/default/contexts/{experiment_name}
|
1635
|
+
project: The project (alphanumeric) name.
|
1636
|
+
"""
|
1637
|
+
if (
|
1638
|
+
not _IpythonUtils.is_ipython_available()
|
1639
|
+
or experiment in _IpythonUtils.displayed_experiments
|
1640
|
+
):
|
1641
|
+
return
|
1642
|
+
# Experiment gives the numeric id, but we need the alphanumeric project
|
1643
|
+
# name. So we get the project from the api client object as an argument.
|
1644
|
+
resource_parts = experiment.split('/')
|
1645
|
+
location = resource_parts[3]
|
1646
|
+
experiment_name = resource_parts[-1]
|
1647
|
+
|
1648
|
+
uri = (
|
1649
|
+
'https://console.cloud.google.com/vertex-ai/experiments/locations/'
|
1650
|
+
+ f'{location}/experiments/{experiment_name}/'
|
1651
|
+
+ f'runs?project={project}'
|
1652
|
+
)
|
1653
|
+
_IpythonUtils._display_link('View Experiment', uri, 'science')
|
1654
|
+
|
1655
|
+
# Avoid repeatedly showing the button
|
1656
|
+
_IpythonUtils.displayed_experiments.add(experiment)
|
1657
|
+
|
1658
|
+
@staticmethod
|
1659
|
+
def display_model_tuning_button(tuning_job_resource: str) -> None:
|
1660
|
+
"""Function to generate a link bound to the Vertex model tuning job.
|
1661
|
+
|
1662
|
+
Args:
|
1663
|
+
tuning_job_resource: The Vertex tuning job name. Example format:
|
1664
|
+
projects/{project_id}/locations/{location}/tuningJobs/{tuning_job_id}
|
1665
|
+
"""
|
1666
|
+
if not _IpythonUtils.is_ipython_available():
|
1667
|
+
return
|
1668
|
+
|
1669
|
+
resource_parts = tuning_job_resource.split('/')
|
1670
|
+
project = resource_parts[1]
|
1671
|
+
location = resource_parts[3]
|
1672
|
+
tuning_job_id = resource_parts[-1]
|
1673
|
+
|
1674
|
+
uri = (
|
1675
|
+
'https://console.cloud.google.com/vertex-ai/generative/language/'
|
1676
|
+
+ f'locations/{location}/tuning/tuningJob/{tuning_job_id}'
|
1677
|
+
+ f'?project={project}'
|
1678
|
+
)
|
1679
|
+
_IpythonUtils._display_link('View Tuning Job', uri, 'tune')
|