zenml-nightly 0.61.0.dev20240711__py3-none-any.whl → 0.61.0.dev20240713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/stack.py +220 -71
- zenml/constants.py +6 -0
- zenml/enums.py +16 -0
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +54 -6
- zenml/logging/step_logging.py +34 -35
- zenml/models/__init__.py +2 -0
- zenml/models/v2/core/server_settings.py +0 -20
- zenml/models/v2/misc/stack_deployment.py +20 -0
- zenml/orchestrators/step_launcher.py +1 -0
- zenml/stack_deployments/aws_stack_deployment.py +56 -91
- zenml/stack_deployments/gcp_stack_deployment.py +260 -0
- zenml/stack_deployments/stack_deployment.py +103 -25
- zenml/stack_deployments/utils.py +4 -0
- zenml/zen_server/routers/devices_endpoints.py +4 -1
- zenml/zen_server/routers/server_endpoints.py +29 -2
- zenml/zen_server/routers/stack_deployment_endpoints.py +34 -20
- zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py +167 -0
- zenml/zen_stores/rest_zen_store.py +45 -21
- zenml/zen_stores/schemas/server_settings_schemas.py +23 -11
- zenml/zen_stores/sql_zen_store.py +117 -19
- zenml/zen_stores/zen_store_interface.py +6 -5
- {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/METADATA +1 -1
- {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/RECORD +27 -25
- {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.61.0.
|
1
|
+
0.61.0.dev20240713
|
zenml/cli/stack.py
CHANGED
@@ -35,6 +35,7 @@ import click
|
|
35
35
|
from rich.console import Console
|
36
36
|
from rich.markdown import Markdown
|
37
37
|
from rich.prompt import Confirm
|
38
|
+
from rich.style import Style
|
38
39
|
from rich.syntax import Syntax
|
39
40
|
|
40
41
|
import zenml
|
@@ -287,6 +288,19 @@ def register_stack(
|
|
287
288
|
|
288
289
|
client = Client()
|
289
290
|
|
291
|
+
if provider is not None or connector is not None:
|
292
|
+
if client.zen_store.is_local_store():
|
293
|
+
cli_utils.error(
|
294
|
+
"You are registering a stack using a service connector, but "
|
295
|
+
"this feature cannot be used with a local ZenML deployment. "
|
296
|
+
"ZenML needs to be accessible from the cloud provider to allow the "
|
297
|
+
"stack and its components to be registered automatically. "
|
298
|
+
"Please deploy ZenML in a remote environment as described in the "
|
299
|
+
"documentation: https://docs.zenml.io/getting-started/deploying-zenml "
|
300
|
+
"or use a managed ZenML Pro server instance for quick access to "
|
301
|
+
"this feature and more: https://www.zenml.io/pro"
|
302
|
+
)
|
303
|
+
|
290
304
|
try:
|
291
305
|
client.get_stack(
|
292
306
|
name_id_or_prefix=stack_name,
|
@@ -377,6 +391,7 @@ def register_stack(
|
|
377
391
|
if provider:
|
378
392
|
labels["zenml:provider"] = provider
|
379
393
|
service_connector_resource_model = None
|
394
|
+
can_generate_long_tokens = False
|
380
395
|
# create components
|
381
396
|
needed_components = (
|
382
397
|
(StackComponentType.ARTIFACT_STORE, artifact_store),
|
@@ -419,26 +434,38 @@ def register_stack(
|
|
419
434
|
|
420
435
|
if component_selected is None:
|
421
436
|
if service_connector_resource_model is None:
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
437
|
+
with console.status(
|
438
|
+
"Exploring resources available to the service connector...\n"
|
439
|
+
):
|
440
|
+
if isinstance(service_connector, UUID):
|
441
|
+
service_connector_resource_model = (
|
442
|
+
client.verify_service_connector(
|
443
|
+
service_connector
|
444
|
+
)
|
426
445
|
)
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
name=stack_name,
|
432
|
-
connector_type=service_connector.type,
|
433
|
-
auth_method=service_connector.auth_method,
|
434
|
-
configuration=service_connector.configuration,
|
435
|
-
register=False,
|
446
|
+
existing_service_connector_info = (
|
447
|
+
client.get_service_connector(
|
448
|
+
service_connector
|
449
|
+
)
|
436
450
|
)
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
451
|
+
can_generate_long_tokens = not existing_service_connector_info.configuration.get(
|
452
|
+
"generate_temporary_tokens", True
|
453
|
+
)
|
454
|
+
else:
|
455
|
+
_, service_connector_resource_model = (
|
456
|
+
client.create_service_connector(
|
457
|
+
name=stack_name,
|
458
|
+
connector_type=service_connector.type,
|
459
|
+
auth_method=service_connector.auth_method,
|
460
|
+
configuration=service_connector.configuration,
|
461
|
+
register=False,
|
462
|
+
)
|
441
463
|
)
|
464
|
+
can_generate_long_tokens = True
|
465
|
+
if service_connector_resource_model is None:
|
466
|
+
cli_utils.error(
|
467
|
+
f"Failed to validate service connector {service_connector}..."
|
468
|
+
)
|
442
469
|
if provider is None:
|
443
470
|
if isinstance(
|
444
471
|
service_connector_resource_model.connector_type,
|
@@ -455,6 +482,7 @@ def register_stack(
|
|
455
482
|
cloud_provider=provider,
|
456
483
|
service_connector_resource_models=service_connector_resource_model.resources,
|
457
484
|
service_connector_index=0,
|
485
|
+
can_generate_long_tokens=can_generate_long_tokens,
|
458
486
|
)
|
459
487
|
component_name = stack_name
|
460
488
|
created_objects.add(component_type.value)
|
@@ -470,6 +498,18 @@ def register_stack(
|
|
470
498
|
artifact_store = component_name
|
471
499
|
if component_type == StackComponentType.ORCHESTRATOR:
|
472
500
|
orchestrator = component_name
|
501
|
+
if not isinstance(
|
502
|
+
component_info, UUID
|
503
|
+
) and component_info.flavor.startswith("vm"):
|
504
|
+
if isinstance(
|
505
|
+
service_connector, ServiceConnectorInfo
|
506
|
+
) and service_connector.auth_method in {
|
507
|
+
"service-account",
|
508
|
+
"external-account",
|
509
|
+
}:
|
510
|
+
service_connector.configuration[
|
511
|
+
"generate_temporary_tokens"
|
512
|
+
] = False
|
473
513
|
if component_type == StackComponentType.CONTAINER_REGISTRY:
|
474
514
|
container_registry = component_name
|
475
515
|
|
@@ -1687,6 +1727,12 @@ def deploy(
|
|
1687
1727
|
provider=StackDeploymentProvider(provider),
|
1688
1728
|
)
|
1689
1729
|
|
1730
|
+
if location and location not in deployment.locations.values():
|
1731
|
+
cli_utils.error(
|
1732
|
+
f"Invalid location '{location}' for provider '{provider}'. "
|
1733
|
+
f"Valid locations are: {', '.join(deployment.locations.values())}"
|
1734
|
+
)
|
1735
|
+
|
1690
1736
|
console.print(
|
1691
1737
|
Markdown(
|
1692
1738
|
f"# {provider.upper()} ZenML Cloud Stack Deployment\n"
|
@@ -1695,55 +1741,71 @@ def deploy(
|
|
1695
1741
|
)
|
1696
1742
|
console.print(Markdown("## Instructions\n" + deployment.instructions))
|
1697
1743
|
|
1744
|
+
deployment_config = client.zen_store.get_stack_deployment_config(
|
1745
|
+
provider=StackDeploymentProvider(provider),
|
1746
|
+
stack_name=stack_name,
|
1747
|
+
location=location,
|
1748
|
+
)
|
1749
|
+
|
1750
|
+
if deployment_config.configuration:
|
1751
|
+
console.print(
|
1752
|
+
Markdown(
|
1753
|
+
"## Configuration\n"
|
1754
|
+
"You will be asked to provide the following configuration "
|
1755
|
+
"values during the deployment process:\n"
|
1756
|
+
)
|
1757
|
+
)
|
1758
|
+
|
1759
|
+
console.print(
|
1760
|
+
"\n",
|
1761
|
+
deployment_config.configuration,
|
1762
|
+
no_wrap=True,
|
1763
|
+
overflow="ignore",
|
1764
|
+
crop=False,
|
1765
|
+
style=Style(bgcolor="grey15"),
|
1766
|
+
)
|
1767
|
+
|
1698
1768
|
if not cli_utils.confirmation(
|
1699
1769
|
"\n\nProceed to continue with the deployment. You will be "
|
1700
1770
|
f"automatically redirected to {provider.upper()} in your browser.",
|
1701
1771
|
):
|
1702
1772
|
raise click.Abort()
|
1703
1773
|
|
1704
|
-
deployment_url, deployment_url_title = (
|
1705
|
-
client.zen_store.get_stack_deployment_url(
|
1706
|
-
provider=StackDeploymentProvider(provider),
|
1707
|
-
stack_name=stack_name,
|
1708
|
-
location=location,
|
1709
|
-
)
|
1710
|
-
)
|
1711
|
-
|
1712
1774
|
date_start = datetime.utcnow()
|
1713
1775
|
|
1714
|
-
webbrowser.open(deployment_url)
|
1776
|
+
webbrowser.open(deployment_config.deployment_url)
|
1715
1777
|
console.print(
|
1716
1778
|
Markdown(
|
1717
1779
|
f"If your browser did not open automatically, please open "
|
1718
1780
|
f"the following URL into your browser to deploy the stack to "
|
1719
1781
|
f"{provider.upper()}: "
|
1720
|
-
f"[{
|
1782
|
+
f"[{deployment_config.deployment_url_text}]"
|
1783
|
+
f"({deployment_config.deployment_url}).\n\n"
|
1721
1784
|
)
|
1722
1785
|
)
|
1723
1786
|
|
1724
1787
|
try:
|
1725
|
-
|
1726
|
-
"
|
1788
|
+
cli_utils.declare(
|
1789
|
+
"\n\nWaiting for the deployment to complete and the stack to be "
|
1727
1790
|
"registered. Press CTRL+C to abort...\n"
|
1728
|
-
)
|
1729
|
-
|
1730
|
-
|
1731
|
-
|
1732
|
-
|
1733
|
-
|
1734
|
-
|
1735
|
-
|
1736
|
-
)
|
1737
|
-
)
|
1738
|
-
if deployed_stack:
|
1739
|
-
break
|
1740
|
-
time.sleep(10)
|
1741
|
-
|
1742
|
-
analytics_handler.metadata.update(
|
1743
|
-
{
|
1744
|
-
"stack_id": deployed_stack.stack.id,
|
1745
|
-
}
|
1791
|
+
)
|
1792
|
+
|
1793
|
+
while True:
|
1794
|
+
deployed_stack = client.zen_store.get_stack_deployment_stack(
|
1795
|
+
provider=StackDeploymentProvider(provider),
|
1796
|
+
stack_name=stack_name,
|
1797
|
+
location=location,
|
1798
|
+
date_start=date_start,
|
1746
1799
|
)
|
1800
|
+
if deployed_stack:
|
1801
|
+
break
|
1802
|
+
time.sleep(10)
|
1803
|
+
|
1804
|
+
analytics_handler.metadata.update(
|
1805
|
+
{
|
1806
|
+
"stack_id": deployed_stack.stack.id,
|
1807
|
+
}
|
1808
|
+
)
|
1747
1809
|
|
1748
1810
|
except KeyboardInterrupt:
|
1749
1811
|
cli_utils.declare("Stack deployment aborted.")
|
@@ -1767,15 +1829,28 @@ Stack [{deployed_stack.stack.name}]({get_stack_url(deployed_stack.stack)}):\n"""
|
|
1767
1829
|
|
1768
1830
|
console.print(Markdown(stack_desc))
|
1769
1831
|
|
1770
|
-
|
1771
|
-
|
1772
|
-
|
1832
|
+
follow_up = f"""
|
1833
|
+
## Follow-up
|
1834
|
+
|
1835
|
+
{deployment.post_deploy_instructions}
|
1836
|
+
|
1837
|
+
To use the `{deployed_stack.stack.name}` stack to run pipelines:
|
1773
1838
|
|
1839
|
+
* install the required ZenML integrations by running: `zenml integration install {" ".join(deployment.integrations)}`
|
1840
|
+
"""
|
1774
1841
|
if set_stack:
|
1775
1842
|
client.activate_stack(deployed_stack.stack.id)
|
1776
|
-
|
1777
|
-
|
1778
|
-
|
1843
|
+
follow_up += f"""
|
1844
|
+
* the `{deployed_stack.stack.name}` stack has already been set as active
|
1845
|
+
"""
|
1846
|
+
else:
|
1847
|
+
follow_up += f"""
|
1848
|
+
* set the `{deployed_stack.stack.name}` stack as active by running: `zenml stack set {deployed_stack.stack.name}`
|
1849
|
+
"""
|
1850
|
+
|
1851
|
+
console.print(
|
1852
|
+
Markdown(follow_up),
|
1853
|
+
)
|
1779
1854
|
|
1780
1855
|
|
1781
1856
|
@stack.command(help="[DEPRECATED] Deploy a stack using mlstacks.")
|
@@ -2248,7 +2323,7 @@ def _get_service_connector_info(
|
|
2248
2323
|
"""
|
2249
2324
|
from rich.prompt import Prompt
|
2250
2325
|
|
2251
|
-
if cloud_provider not in {"aws"}:
|
2326
|
+
if cloud_provider not in {"aws", "gcp"}:
|
2252
2327
|
raise ValueError(f"Unknown cloud provider {cloud_provider}")
|
2253
2328
|
|
2254
2329
|
client = Client()
|
@@ -2277,7 +2352,7 @@ def _get_service_connector_info(
|
|
2277
2352
|
object_type=f"authentication methods for {cloud_provider}",
|
2278
2353
|
choices=choices,
|
2279
2354
|
headers=headers,
|
2280
|
-
prompt_text="Please choose one of the authentication option above
|
2355
|
+
prompt_text="Please choose one of the authentication option above",
|
2281
2356
|
)
|
2282
2357
|
if selected_auth_idx is None:
|
2283
2358
|
cli_utils.error("No authentication method selected.")
|
@@ -2307,7 +2382,6 @@ def _get_service_connector_info(
|
|
2307
2382
|
password="format" in properties[req_field]
|
2308
2383
|
and properties[req_field]["format"] == "password",
|
2309
2384
|
)
|
2310
|
-
Console().print("All mandatory configuration parameters received!")
|
2311
2385
|
|
2312
2386
|
return ServiceConnectorInfo(
|
2313
2387
|
type=cloud_provider,
|
@@ -2322,6 +2396,7 @@ def _get_stack_component_info(
|
|
2322
2396
|
service_connector_resource_models: List[
|
2323
2397
|
ServiceConnectorTypedResourcesModel
|
2324
2398
|
],
|
2399
|
+
can_generate_long_tokens: bool,
|
2325
2400
|
service_connector_index: Optional[int] = None,
|
2326
2401
|
) -> ComponentInfo:
|
2327
2402
|
"""Get a stack component info with given type and service connector.
|
@@ -2330,6 +2405,7 @@ def _get_stack_component_info(
|
|
2330
2405
|
component_type: The type of component to create.
|
2331
2406
|
cloud_provider: The cloud provider to use.
|
2332
2407
|
service_connector_resource_models: The list of the available service connector resource models.
|
2408
|
+
can_generate_long_tokens: Whether connector can generate long-living tokens.
|
2333
2409
|
service_connector_index: The index of the service connector to use.
|
2334
2410
|
|
2335
2411
|
Returns:
|
@@ -2347,6 +2423,9 @@ def _get_stack_component_info(
|
|
2347
2423
|
AWS_DOCS = (
|
2348
2424
|
"https://docs.zenml.io/how-to/auth-management/aws-service-connector"
|
2349
2425
|
)
|
2426
|
+
GCP_DOCS = (
|
2427
|
+
"https://docs.zenml.io/how-to/auth-management/gcp-service-connector"
|
2428
|
+
)
|
2350
2429
|
|
2351
2430
|
flavor = "undefined"
|
2352
2431
|
service_connector_resource_id = None
|
@@ -2370,7 +2449,19 @@ def _get_stack_component_info(
|
|
2370
2449
|
elif cloud_provider == "azure":
|
2371
2450
|
flavor = "azure"
|
2372
2451
|
elif cloud_provider == "gcp":
|
2373
|
-
flavor = "
|
2452
|
+
flavor = "gcp"
|
2453
|
+
for each in service_connector_resource_models:
|
2454
|
+
if each.resource_type == "gcs-bucket":
|
2455
|
+
available_storages = each.resource_ids or []
|
2456
|
+
if not available_storages:
|
2457
|
+
cli_utils.error(
|
2458
|
+
"We were unable to find any GCS buckets available "
|
2459
|
+
"to configured service connector. Please, verify "
|
2460
|
+
"that needed permission are granted for the "
|
2461
|
+
"service connector.\nDocumentation for the GCS "
|
2462
|
+
"Buckets configuration can be found at "
|
2463
|
+
f"{GCP_DOCS}#gcs-bucket"
|
2464
|
+
)
|
2374
2465
|
|
2375
2466
|
selected_storage_idx = cli_utils.multi_choice_prompt(
|
2376
2467
|
object_type=f"{cloud_provider.upper()} storages",
|
@@ -2386,12 +2477,29 @@ def _get_stack_component_info(
|
|
2386
2477
|
config = {"path": selected_storage}
|
2387
2478
|
service_connector_resource_id = selected_storage
|
2388
2479
|
elif component_type == "orchestrator":
|
2480
|
+
|
2481
|
+
def query_gcp_region(compute_type: str) -> str:
|
2482
|
+
region = Prompt.ask(
|
2483
|
+
f"Select the location for your {compute_type}:",
|
2484
|
+
choices=sorted(
|
2485
|
+
Client()
|
2486
|
+
.zen_store.get_stack_deployment_info(
|
2487
|
+
StackDeploymentProvider.GCP
|
2488
|
+
)
|
2489
|
+
.locations.values()
|
2490
|
+
),
|
2491
|
+
show_choices=True,
|
2492
|
+
)
|
2493
|
+
return region
|
2494
|
+
|
2389
2495
|
if cloud_provider == "aws":
|
2390
2496
|
available_orchestrators = []
|
2391
2497
|
for each in service_connector_resource_models:
|
2392
2498
|
types = []
|
2393
2499
|
if each.resource_type == "aws-generic":
|
2394
|
-
types = ["Sagemaker"
|
2500
|
+
types = ["Sagemaker"]
|
2501
|
+
if can_generate_long_tokens:
|
2502
|
+
types.append("Skypilot (EC2)")
|
2395
2503
|
if each.resource_type == "kubernetes-cluster":
|
2396
2504
|
types = ["Kubernetes"]
|
2397
2505
|
|
@@ -2412,7 +2520,32 @@ def _get_stack_component_info(
|
|
2412
2520
|
f"{AWS_DOCS}#eks-kubernetes-cluster"
|
2413
2521
|
)
|
2414
2522
|
elif cloud_provider == "gcp":
|
2415
|
-
|
2523
|
+
available_orchestrators = []
|
2524
|
+
for each in service_connector_resource_models:
|
2525
|
+
types = []
|
2526
|
+
if each.resource_type == "gcp-generic":
|
2527
|
+
types = ["Vertex AI"]
|
2528
|
+
if can_generate_long_tokens:
|
2529
|
+
types.append("Skypilot (Compute)")
|
2530
|
+
if each.resource_type == "kubernetes-cluster":
|
2531
|
+
types = ["Kubernetes"]
|
2532
|
+
|
2533
|
+
if each.resource_ids:
|
2534
|
+
for orchestrator in each.resource_ids:
|
2535
|
+
for t in types:
|
2536
|
+
available_orchestrators.append([t, orchestrator])
|
2537
|
+
if not available_orchestrators:
|
2538
|
+
cli_utils.error(
|
2539
|
+
"We were unable to find any orchestrator engines "
|
2540
|
+
"available to the service connector. Please, verify "
|
2541
|
+
"that needed permission are granted for the "
|
2542
|
+
"service connector.\nDocumentation for the Generic "
|
2543
|
+
"GCP resource configuration can be found at "
|
2544
|
+
f"{GCP_DOCS}#generic-gcp-resource\n"
|
2545
|
+
"Documentation for the GKE Kubernetes resource "
|
2546
|
+
"configuration can be found at "
|
2547
|
+
f"{GCP_DOCS}#gke-kubernetes-cluster"
|
2548
|
+
)
|
2416
2549
|
elif cloud_provider == "azure":
|
2417
2550
|
pass
|
2418
2551
|
|
@@ -2429,25 +2562,31 @@ def _get_stack_component_info(
|
|
2429
2562
|
selected_orchestrator_idx
|
2430
2563
|
]
|
2431
2564
|
|
2565
|
+
config = {}
|
2432
2566
|
if selected_orchestrator[0] == "Sagemaker":
|
2433
2567
|
flavor = "sagemaker"
|
2434
|
-
execution_role = Prompt.ask("
|
2435
|
-
config
|
2568
|
+
execution_role = Prompt.ask("Enter an execution role ARN:")
|
2569
|
+
config["execution_role"] = execution_role
|
2436
2570
|
elif selected_orchestrator[0] == "Skypilot (EC2)":
|
2437
2571
|
flavor = "vm_aws"
|
2438
|
-
config
|
2572
|
+
config["region"] = selected_orchestrator[1]
|
2573
|
+
elif selected_orchestrator[0] == "Skypilot (Compute)":
|
2574
|
+
flavor = "vm_gcp"
|
2575
|
+
config["region"] = query_gcp_region("Skypilot cluster")
|
2576
|
+
elif selected_orchestrator[0] == "Vertex AI":
|
2577
|
+
flavor = "vertex"
|
2578
|
+
config["location"] = query_gcp_region("Vertex AI job")
|
2439
2579
|
elif selected_orchestrator[0] == "Kubernetes":
|
2440
2580
|
flavor = "kubernetes"
|
2441
|
-
config = {}
|
2442
2581
|
else:
|
2443
2582
|
raise ValueError(
|
2444
2583
|
f"Unknown orchestrator type {selected_orchestrator[0]}"
|
2445
2584
|
)
|
2446
2585
|
service_connector_resource_id = selected_orchestrator[1]
|
2447
2586
|
elif component_type == "container_registry":
|
2448
|
-
|
2449
|
-
|
2450
|
-
|
2587
|
+
|
2588
|
+
def _get_registries(registry_name: str, docs_link: str) -> List[str]:
|
2589
|
+
available_registries: List[str] = []
|
2451
2590
|
for each in service_connector_resource_models:
|
2452
2591
|
if each.resource_type == "docker-registry":
|
2453
2592
|
available_registries = each.resource_ids or []
|
@@ -2456,14 +2595,24 @@ def _get_stack_component_info(
|
|
2456
2595
|
"We were unable to find any container registries "
|
2457
2596
|
"available to the service connector. Please, verify "
|
2458
2597
|
"that needed permission are granted for the "
|
2459
|
-
"service connector.\nDocumentation for the
|
2598
|
+
f"service connector.\nDocumentation for the {registry_name} "
|
2460
2599
|
"container registry resource configuration can "
|
2461
|
-
f"be found at {
|
2600
|
+
f"be found at {docs_link}"
|
2462
2601
|
)
|
2463
|
-
|
2464
|
-
|
2465
|
-
|
2602
|
+
return available_registries
|
2603
|
+
|
2604
|
+
if cloud_provider == "aws":
|
2605
|
+
flavor = "aws"
|
2606
|
+
available_registries = _get_registries(
|
2607
|
+
"ECR", f"{AWS_DOCS}#ecr-container-registry"
|
2608
|
+
)
|
2609
|
+
if cloud_provider == "gcp":
|
2466
2610
|
flavor = "gcp"
|
2611
|
+
available_registries = _get_registries(
|
2612
|
+
"GCR", f"{GCP_DOCS}#gcr-container-registry"
|
2613
|
+
)
|
2614
|
+
if cloud_provider == "azure":
|
2615
|
+
flavor = "azure"
|
2467
2616
|
|
2468
2617
|
selected_registry_idx = cli_utils.multi_choice_prompt(
|
2469
2618
|
object_type=f"{cloud_provider.upper()} registries",
|
zenml/constants.py
CHANGED
@@ -263,6 +263,7 @@ DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS = 3
|
|
263
263
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT = 60 * 5 # 5 minutes
|
264
264
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING = 5 # seconds
|
265
265
|
DEFAULT_HTTP_TIMEOUT = 30
|
266
|
+
SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT = 120 # seconds
|
266
267
|
ZENML_API_KEY_PREFIX = "ZENKEY_"
|
267
268
|
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
|
268
269
|
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
|
@@ -337,6 +338,7 @@ ARTIFACT_VISUALIZATIONS = "/artifact_visualizations"
|
|
337
338
|
CODE_REFERENCES = "/code_references"
|
338
339
|
CODE_REPOSITORIES = "/code_repositories"
|
339
340
|
COMPONENT_TYPES = "/component-types"
|
341
|
+
CONFIG = "/config"
|
340
342
|
CURRENT_USER = "/current-user"
|
341
343
|
DEACTIVATE = "/deactivate"
|
342
344
|
DEVICES = "/devices"
|
@@ -390,6 +392,7 @@ STEPS = "/steps"
|
|
390
392
|
TAGS = "/tags"
|
391
393
|
TRIGGERS = "/triggers"
|
392
394
|
TRIGGER_EXECUTIONS = "/trigger_executions"
|
395
|
+
ONBOARDING_STATE = "/onboarding_state"
|
393
396
|
USERS = "/users"
|
394
397
|
URL = "/url"
|
395
398
|
VERSION_1 = "/v1"
|
@@ -484,3 +487,6 @@ FINISHED_ONBOARDING_SURVEY_KEY = "awareness_channels"
|
|
484
487
|
|
485
488
|
# Name validation
|
486
489
|
BANNED_NAME_CHARACTERS = "\t\n\r\v\f"
|
490
|
+
|
491
|
+
|
492
|
+
STACK_DEPLOYMENT_API_TOKEN_EXPIRATION = 60 * 6 # 6 hours
|
zenml/enums.py
CHANGED
@@ -386,7 +386,23 @@ class PluginSubType(StrEnum):
|
|
386
386
|
PIPELINE_RUN = "pipeline_run"
|
387
387
|
|
388
388
|
|
389
|
+
class OnboardingStep(StrEnum):
|
390
|
+
"""All onboarding steps."""
|
391
|
+
|
392
|
+
DEVICE_VERIFIED = "device_verified"
|
393
|
+
PIPELINE_RUN = "pipeline_run"
|
394
|
+
STARTER_SETUP_COMPLETED = "starter_setup_completed"
|
395
|
+
STACK_WITH_REMOTE_ORCHESTRATOR_CREATED = (
|
396
|
+
"stack_with_remote_orchestrator_created"
|
397
|
+
)
|
398
|
+
PIPELINE_RUN_WITH_REMOTE_ORCHESTRATOR = (
|
399
|
+
"pipeline_run_with_remote_orchestrator"
|
400
|
+
)
|
401
|
+
PRODUCTION_SETUP_COMPLETED = "production_setup_completed"
|
402
|
+
|
403
|
+
|
389
404
|
class StackDeploymentProvider(StrEnum):
|
390
405
|
"""All possible stack deployment providers."""
|
391
406
|
|
392
407
|
AWS = "aws"
|
408
|
+
GCP = "gcp"
|
@@ -20,6 +20,7 @@ services:
|
|
20
20
|
|
21
21
|
"""
|
22
22
|
|
23
|
+
import base64
|
23
24
|
import datetime
|
24
25
|
import json
|
25
26
|
import os
|
@@ -89,7 +90,7 @@ class GCPUserAccountCredentials(AuthenticationConfig):
|
|
89
90
|
"""GCP user account credentials."""
|
90
91
|
|
91
92
|
user_account_json: PlainSerializedSecretStr = Field(
|
92
|
-
title="GCP User Account Credentials JSON",
|
93
|
+
title="GCP User Account Credentials JSON optionally base64 encoded.",
|
93
94
|
)
|
94
95
|
|
95
96
|
generate_temporary_tokens: bool = Field(
|
@@ -113,9 +114,24 @@ class GCPUserAccountCredentials(AuthenticationConfig):
|
|
113
114
|
|
114
115
|
Returns:
|
115
116
|
The validated configuration values.
|
117
|
+
|
118
|
+
Raises:
|
119
|
+
ValueError: If the user account credentials JSON is invalid.
|
116
120
|
"""
|
117
|
-
|
121
|
+
user_account_json = data.get("user_account_json")
|
122
|
+
if isinstance(user_account_json, dict):
|
118
123
|
data["user_account_json"] = json.dumps(data["user_account_json"])
|
124
|
+
elif isinstance(user_account_json, str):
|
125
|
+
# Check if the user account JSON is base64 encoded and decode it
|
126
|
+
if re.match(r"^[A-Za-z0-9+/=]+$", user_account_json):
|
127
|
+
try:
|
128
|
+
data["user_account_json"] = base64.b64decode(
|
129
|
+
user_account_json
|
130
|
+
).decode("utf-8")
|
131
|
+
except Exception as e:
|
132
|
+
raise ValueError(
|
133
|
+
f"Failed to decode base64 encoded user account JSON: {e}"
|
134
|
+
)
|
119
135
|
return data
|
120
136
|
|
121
137
|
@field_validator("user_account_json")
|
@@ -170,7 +186,7 @@ class GCPServiceAccountCredentials(AuthenticationConfig):
|
|
170
186
|
"""GCP service account credentials."""
|
171
187
|
|
172
188
|
service_account_json: PlainSerializedSecretStr = Field(
|
173
|
-
title="GCP Service Account Key JSON",
|
189
|
+
title="GCP Service Account Key JSON optionally base64 encoded.",
|
174
190
|
)
|
175
191
|
|
176
192
|
generate_temporary_tokens: bool = Field(
|
@@ -194,11 +210,27 @@ class GCPServiceAccountCredentials(AuthenticationConfig):
|
|
194
210
|
|
195
211
|
Returns:
|
196
212
|
The validated configuration values.
|
213
|
+
|
214
|
+
Raises:
|
215
|
+
ValueError: If the service account credentials JSON is invalid.
|
197
216
|
"""
|
198
|
-
|
217
|
+
service_account_json = data.get("service_account_json")
|
218
|
+
if isinstance(service_account_json, dict):
|
199
219
|
data["service_account_json"] = json.dumps(
|
200
220
|
data["service_account_json"]
|
201
221
|
)
|
222
|
+
elif isinstance(service_account_json, str):
|
223
|
+
# Check if the service account JSON is base64 encoded and decode it
|
224
|
+
if re.match(r"^[A-Za-z0-9+/=]+$", service_account_json):
|
225
|
+
try:
|
226
|
+
data["service_account_json"] = base64.b64decode(
|
227
|
+
service_account_json
|
228
|
+
).decode("utf-8")
|
229
|
+
except Exception as e:
|
230
|
+
raise ValueError(
|
231
|
+
f"Failed to decode base64 encoded service account JSON: {e}"
|
232
|
+
)
|
233
|
+
|
202
234
|
return data
|
203
235
|
|
204
236
|
@field_validator("service_account_json")
|
@@ -261,7 +293,7 @@ class GCPExternalAccountCredentials(AuthenticationConfig):
|
|
261
293
|
"""GCP external account credentials."""
|
262
294
|
|
263
295
|
external_account_json: PlainSerializedSecretStr = Field(
|
264
|
-
title="GCP External Account JSON",
|
296
|
+
title="GCP External Account JSON optionally base64 encoded.",
|
265
297
|
)
|
266
298
|
|
267
299
|
generate_temporary_tokens: bool = Field(
|
@@ -285,11 +317,27 @@ class GCPExternalAccountCredentials(AuthenticationConfig):
|
|
285
317
|
|
286
318
|
Returns:
|
287
319
|
The validated configuration values.
|
320
|
+
|
321
|
+
Raises:
|
322
|
+
ValueError: If the external account credentials JSON is invalid.
|
288
323
|
"""
|
289
|
-
|
324
|
+
external_account_json = data.get("external_account_json")
|
325
|
+
if isinstance(external_account_json, dict):
|
290
326
|
data["external_account_json"] = json.dumps(
|
291
327
|
data["external_account_json"]
|
292
328
|
)
|
329
|
+
elif isinstance(external_account_json, str):
|
330
|
+
# Check if the external account JSON is base64 encoded and decode it
|
331
|
+
if re.match(r"^[A-Za-z0-9+/=]+$", external_account_json):
|
332
|
+
try:
|
333
|
+
data["external_account_json"] = base64.b64decode(
|
334
|
+
external_account_json
|
335
|
+
).decode("utf-8")
|
336
|
+
except Exception as e:
|
337
|
+
raise ValueError(
|
338
|
+
f"Failed to decode base64 encoded external account JSON: {e}"
|
339
|
+
)
|
340
|
+
|
293
341
|
return data
|
294
342
|
|
295
343
|
@field_validator("external_account_json")
|