google-genai 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -24,6 +24,30 @@ from ._common import set_value_by_path as setv
24
24
  from .pagers import AsyncPager, Pager
25
25
 
26
26
 
27
def _GetTuningJobConfig_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the Google AI (mldev) request fragment for a GetTuningJobConfig.

  Only `http_options` is carried over, renamed to `httpOptions`. The
  `api_client` and `parent_object` arguments are accepted for parity with the
  other converters in this module and are not used here.
  """
  converted = {}
  http_options = getv(from_object, ['http_options'])
  if http_options is not None:
    setv(converted, ['httpOptions'], http_options)
  return converted
37
+
38
+
39
def _GetTuningJobConfig_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the Vertex AI request fragment for a GetTuningJobConfig.

  Only `http_options` is carried over, renamed to `httpOptions`. The
  `api_client` and `parent_object` arguments are accepted for parity with the
  other converters in this module and are not used here.
  """
  converted = {}
  http_options = getv(from_object, ['http_options'])
  if http_options is not None:
    setv(converted, ['httpOptions'], http_options)
  return converted
49
+
50
+
27
51
  def _GetTuningJobParameters_to_mldev(
28
52
  api_client: ApiClient,
29
53
  from_object: Union[dict, object],
@@ -33,6 +57,15 @@ def _GetTuningJobParameters_to_mldev(
33
57
  if getv(from_object, ['name']) is not None:
34
58
  setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
35
59
 
60
+ if getv(from_object, ['config']) is not None:
61
+ setv(
62
+ to_object,
63
+ ['config'],
64
+ _GetTuningJobConfig_to_mldev(
65
+ api_client, getv(from_object, ['config']), to_object
66
+ ),
67
+ )
68
+
36
69
  return to_object
37
70
 
38
71
 
@@ -45,6 +78,15 @@ def _GetTuningJobParameters_to_vertex(
45
78
  if getv(from_object, ['name']) is not None:
46
79
  setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
47
80
 
81
+ if getv(from_object, ['config']) is not None:
82
+ setv(
83
+ to_object,
84
+ ['config'],
85
+ _GetTuningJobConfig_to_vertex(
86
+ api_client, getv(from_object, ['config']), to_object
87
+ ),
88
+ )
89
+
48
90
  return to_object
49
91
 
50
92
 
@@ -233,6 +275,9 @@ def _CreateTuningJobConfig_to_mldev(
233
275
  parent_object: dict = None,
234
276
  ) -> dict:
235
277
  to_object = {}
278
+ if getv(from_object, ['http_options']) is not None:
279
+ setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
280
+
236
281
  if getv(from_object, ['validation_dataset']):
237
282
  raise ValueError(
238
283
  'validation_dataset parameter is not supported in Google AI.'
@@ -288,6 +333,9 @@ def _CreateTuningJobConfig_to_vertex(
288
333
  parent_object: dict = None,
289
334
  ) -> dict:
290
335
  to_object = {}
336
+ if getv(from_object, ['http_options']) is not None:
337
+ setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
338
+
291
339
  if getv(from_object, ['validation_dataset']) is not None:
292
340
  setv(
293
341
  parent_object,
@@ -455,6 +503,9 @@ def _CreateDistillationJobConfig_to_mldev(
455
503
  parent_object: dict = None,
456
504
  ) -> dict:
457
505
  to_object = {}
506
+ if getv(from_object, ['http_options']) is not None:
507
+ setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
508
+
458
509
  if getv(from_object, ['validation_dataset']):
459
510
  raise ValueError(
460
511
  'validation_dataset parameter is not supported in Google AI.'
@@ -498,6 +549,9 @@ def _CreateDistillationJobConfig_to_vertex(
498
549
  parent_object: dict = None,
499
550
  ) -> dict:
500
551
  to_object = {}
552
+ if getv(from_object, ['http_options']) is not None:
553
+ setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
554
+
501
555
  if getv(from_object, ['validation_dataset']) is not None:
502
556
  setv(
503
557
  parent_object,
@@ -896,7 +950,12 @@ def _TuningJobOrOperation_from_vertex(
896
950
 
897
951
  class Tunings(_common.BaseModule):
898
952
 
899
- def get(self, *, name: str) -> types.TuningJob:
953
+ def _get(
954
+ self,
955
+ *,
956
+ name: str,
957
+ config: Optional[types.GetTuningJobConfigOrDict] = None,
958
+ ) -> types.TuningJob:
900
959
  """Gets a TuningJob.
901
960
 
902
961
  Args:
@@ -908,6 +967,7 @@ class Tunings(_common.BaseModule):
908
967
 
909
968
  parameter_model = types._GetTuningJobParameters(
910
969
  name=name,
970
+ config=config,
911
971
  )
912
972
 
913
973
  if self.api_client.vertexai:
@@ -998,7 +1058,7 @@ class Tunings(_common.BaseModule):
998
1058
  self.api_client._verify_response(return_value)
999
1059
  return return_value
1000
1060
 
1001
- def tune(
1061
+ def _tune(
1002
1062
  self,
1003
1063
  *,
1004
1064
  base_model: str,
@@ -1129,10 +1189,45 @@ class Tunings(_common.BaseModule):
1129
1189
  config,
1130
1190
  )
1131
1191
 
1192
+ def get(
1193
+ self,
1194
+ *,
1195
+ name: str,
1196
+ config: Optional[types.GetTuningJobConfigOrDict] = None,
1197
+ ) -> types.TuningJob:
1198
+ job = self._get(name=name, config=config)
1199
+ if job.experiment and self.api_client.vertexai:
1200
+ _IpythonUtils.display_experiment_button(
1201
+ experiment=job.experiment,
1202
+ project=self.api_client.project,
1203
+ )
1204
+ return job
1205
+
1206
+ def tune(
1207
+ self,
1208
+ *,
1209
+ base_model: str,
1210
+ training_dataset: types.TuningDatasetOrDict,
1211
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
1212
+ ) -> types.TuningJobOrOperation:
1213
+ result = self._tune(
1214
+ base_model=base_model,
1215
+ training_dataset=training_dataset,
1216
+ config=config,
1217
+ )
1218
+ if result.name and self.api_client.vertexai:
1219
+ _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name)
1220
+ return result
1221
+
1132
1222
 
1133
1223
  class AsyncTunings(_common.BaseModule):
1134
1224
 
1135
- async def get(self, *, name: str) -> types.TuningJob:
1225
+ async def _get(
1226
+ self,
1227
+ *,
1228
+ name: str,
1229
+ config: Optional[types.GetTuningJobConfigOrDict] = None,
1230
+ ) -> types.TuningJob:
1136
1231
  """Gets a TuningJob.
1137
1232
 
1138
1233
  Args:
@@ -1144,6 +1239,7 @@ class AsyncTunings(_common.BaseModule):
1144
1239
 
1145
1240
  parameter_model = types._GetTuningJobParameters(
1146
1241
  name=name,
1242
+ config=config,
1147
1243
  )
1148
1244
 
1149
1245
  if self.api_client.vertexai:
@@ -1234,7 +1330,7 @@ class AsyncTunings(_common.BaseModule):
1234
1330
  self.api_client._verify_response(return_value)
1235
1331
  return return_value
1236
1332
 
1237
- async def tune(
1333
+ async def _tune(
1238
1334
  self,
1239
1335
  *,
1240
1336
  base_model: str,
@@ -1364,3 +1460,220 @@ class AsyncTunings(_common.BaseModule):
1364
1460
  await self._list(config=config),
1365
1461
  config,
1366
1462
  )
1463
+
1464
+ async def get(
1465
+ self,
1466
+ *,
1467
+ name: str,
1468
+ config: Optional[types.GetTuningJobConfigOrDict] = None,
1469
+ ) -> types.TuningJob:
1470
+ job = await self._get(name=name, config=config)
1471
+ if job.experiment and self.api_client.vertexai:
1472
+ _IpythonUtils.display_experiment_button(
1473
+ experiment=job.experiment,
1474
+ project=self.api_client.project,
1475
+ )
1476
+ return job
1477
+
1478
+ async def tune(
1479
+ self,
1480
+ *,
1481
+ base_model: str,
1482
+ training_dataset: types.TuningDatasetOrDict,
1483
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
1484
+ ) -> types.TuningJobOrOperation:
1485
+ result = await self._tune(
1486
+ base_model=base_model,
1487
+ training_dataset=training_dataset,
1488
+ config=config,
1489
+ )
1490
+ if result.name and self.api_client.vertexai:
1491
+ _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name)
1492
+ return result
1493
+
1494
+
1495
+ class _IpythonUtils:
1496
+ """Temporary class to hold the IPython related functions."""
1497
+
1498
+ displayed_experiments = set()
1499
+
1500
+ @staticmethod
1501
+ def _get_ipython_shell_name() -> str:
1502
+ import sys
1503
+
1504
+ if 'IPython' in sys.modules:
1505
+ from IPython import get_ipython
1506
+
1507
+ return get_ipython().__class__.__name__
1508
+ return ''
1509
+
1510
+ @staticmethod
1511
+ def is_ipython_available() -> bool:
1512
+ return bool(_IpythonUtils._get_ipython_shell_name())
1513
+
1514
+ @staticmethod
1515
+ def _get_styles() -> None:
1516
+ """Returns the HTML style markup to support custom buttons."""
1517
+ return """
1518
+ <link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
1519
+ <style>
1520
+ .view-vertex-resource,
1521
+ .view-vertex-resource:hover,
1522
+ .view-vertex-resource:visited {
1523
+ position: relative;
1524
+ display: inline-flex;
1525
+ flex-direction: row;
1526
+ height: 32px;
1527
+ padding: 0 12px;
1528
+ margin: 4px 18px;
1529
+ gap: 4px;
1530
+ border-radius: 4px;
1531
+
1532
+ align-items: center;
1533
+ justify-content: center;
1534
+ background-color: rgb(255, 255, 255);
1535
+ color: rgb(51, 103, 214);
1536
+
1537
+ font-family: Roboto,"Helvetica Neue",sans-serif;
1538
+ font-size: 13px;
1539
+ font-weight: 500;
1540
+ text-transform: uppercase;
1541
+ text-decoration: none !important;
1542
+
1543
+ transition: box-shadow 280ms cubic-bezier(0.4, 0, 0.2, 1) 0s;
1544
+ box-shadow: 0px 3px 1px -2px rgba(0,0,0,0.2), 0px 2px 2px 0px rgba(0,0,0,0.14), 0px 1px 5px 0px rgba(0,0,0,0.12);
1545
+ }
1546
+ .view-vertex-resource:active {
1547
+ box-shadow: 0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12);
1548
+ }
1549
+ .view-vertex-resource:active .view-vertex-ripple::before {
1550
+ position: absolute;
1551
+ top: 0;
1552
+ bottom: 0;
1553
+ left: 0;
1554
+ right: 0;
1555
+ border-radius: 4px;
1556
+ pointer-events: none;
1557
+
1558
+ content: '';
1559
+ background-color: rgb(51, 103, 214);
1560
+ opacity: 0.12;
1561
+ }
1562
+ .view-vertex-icon {
1563
+ font-size: 18px;
1564
+ }
1565
+ </style>
1566
+ """
1567
+
1568
+ @staticmethod
1569
+ def _parse_resource_name(marker: str, resource_parts: list[str]) -> str:
1570
+ """Returns the part after the marker text part."""
1571
+ for i in range(len(resource_parts)):
1572
+ if resource_parts[i] == marker and i + 1 < len(resource_parts):
1573
+ return resource_parts[i + 1]
1574
+ return ''
1575
+
1576
+ @staticmethod
1577
+ def _display_link(
1578
+ text: str, url: str, icon: Optional[str] = 'open_in_new'
1579
+ ) -> None:
1580
+ """Creates and displays the link to open the Vertex resource.
1581
+
1582
+ Args:
1583
+ text: The text displayed on the clickable button.
1584
+ url: The url that the button will lead to. Only cloud console URIs are
1585
+ allowed.
1586
+ icon: The icon name on the button (from material-icons library)
1587
+ """
1588
+ CLOUD_UI_URL = 'https://console.cloud.google.com' # pylint: disable=invalid-name
1589
+ if not url.startswith(CLOUD_UI_URL):
1590
+ raise ValueError(f'Only urls starting with {CLOUD_UI_URL} are allowed.')
1591
+
1592
+ import uuid
1593
+
1594
+ button_id = f'view-vertex-resource-{str(uuid.uuid4())}'
1595
+
1596
+ # Add the markup for the CSS and link component
1597
+ html = f"""
1598
+ {_IpythonUtils._get_styles()}
1599
+ <a class="view-vertex-resource" id="{button_id}" href="#view-{button_id}">
1600
+ <span class="material-icons view-vertex-icon">{icon}</span>
1601
+ <span>{text}</span>
1602
+ </a>
1603
+ """
1604
+
1605
+ # Add the click handler for the link
1606
+ html += f"""
1607
+ <script>
1608
+ (function () {{
1609
+ const link = document.getElementById('{button_id}');
1610
+ link.addEventListener('click', (e) => {{
1611
+ if (window.google?.colab?.openUrl) {{
1612
+ window.google.colab.openUrl('{url}');
1613
+ }} else {{
1614
+ window.open('{url}', '_blank');
1615
+ }}
1616
+ e.stopPropagation();
1617
+ e.preventDefault();
1618
+ }});
1619
+ }})();
1620
+ </script>
1621
+ """
1622
+
1623
+ from IPython.core.display import display
1624
+ from IPython.display import HTML
1625
+
1626
+ display(HTML(html))
1627
+
1628
+ @staticmethod
1629
+ def display_experiment_button(experiment: str, project: str) -> None:
1630
+ """Function to generate a link bound to the Vertex experiment.
1631
+
1632
+ Args:
1633
+ experiment: The Vertex experiment name. Example format:
1634
+ projects/{project_id}/locations/{location}/metadataStores/default/contexts/{experiment_name}
1635
+ project: The project (alphanumeric) name.
1636
+ """
1637
+ if (
1638
+ not _IpythonUtils.is_ipython_available()
1639
+ or experiment in _IpythonUtils.displayed_experiments
1640
+ ):
1641
+ return
1642
+ # Experiment gives the numeric id, but we need the alphanumeric project
1643
+ # name. So we get the project from the api client object as an argument.
1644
+ resource_parts = experiment.split('/')
1645
+ location = resource_parts[3]
1646
+ experiment_name = resource_parts[-1]
1647
+
1648
+ uri = (
1649
+ 'https://console.cloud.google.com/vertex-ai/experiments/locations/'
1650
+ + f'{location}/experiments/{experiment_name}/'
1651
+ + f'runs?project={project}'
1652
+ )
1653
+ _IpythonUtils._display_link('View Experiment', uri, 'science')
1654
+
1655
+ # Avoid repeatedly showing the button
1656
+ _IpythonUtils.displayed_experiments.add(experiment)
1657
+
1658
+ @staticmethod
1659
+ def display_model_tuning_button(tuning_job_resource: str) -> None:
1660
+ """Function to generate a link bound to the Vertex model tuning job.
1661
+
1662
+ Args:
1663
+ tuning_job_resource: The Vertex tuning job name. Example format:
1664
+ projects/{project_id}/locations/{location}/tuningJobs/{tuning_job_id}
1665
+ """
1666
+ if not _IpythonUtils.is_ipython_available():
1667
+ return
1668
+
1669
+ resource_parts = tuning_job_resource.split('/')
1670
+ project = resource_parts[1]
1671
+ location = resource_parts[3]
1672
+ tuning_job_id = resource_parts[-1]
1673
+
1674
+ uri = (
1675
+ 'https://console.cloud.google.com/vertex-ai/generative/language/'
1676
+ + f'locations/{location}/tuning/tuningJob/{tuning_job_id}'
1677
+ + f'?project={project}'
1678
+ )
1679
+ _IpythonUtils._display_link('View Tuning Job', uri, 'tune')