zenml-nightly 0.61.0.dev20240711__py3-none-any.whl → 0.61.0.dev20240713__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. zenml/VERSION +1 -1
  2. zenml/cli/stack.py +220 -71
  3. zenml/constants.py +6 -0
  4. zenml/enums.py +16 -0
  5. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +54 -6
  6. zenml/logging/step_logging.py +34 -35
  7. zenml/models/__init__.py +2 -0
  8. zenml/models/v2/core/server_settings.py +0 -20
  9. zenml/models/v2/misc/stack_deployment.py +20 -0
  10. zenml/orchestrators/step_launcher.py +1 -0
  11. zenml/stack_deployments/aws_stack_deployment.py +56 -91
  12. zenml/stack_deployments/gcp_stack_deployment.py +260 -0
  13. zenml/stack_deployments/stack_deployment.py +103 -25
  14. zenml/stack_deployments/utils.py +4 -0
  15. zenml/zen_server/routers/devices_endpoints.py +4 -1
  16. zenml/zen_server/routers/server_endpoints.py +29 -2
  17. zenml/zen_server/routers/stack_deployment_endpoints.py +34 -20
  18. zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py +167 -0
  19. zenml/zen_stores/rest_zen_store.py +45 -21
  20. zenml/zen_stores/schemas/server_settings_schemas.py +23 -11
  21. zenml/zen_stores/sql_zen_store.py +117 -19
  22. zenml/zen_stores/zen_store_interface.py +6 -5
  23. {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/METADATA +1 -1
  24. {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/RECORD +27 -25
  25. {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/LICENSE +0 -0
  26. {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/WHEEL +0 -0
  27. {zenml_nightly-0.61.0.dev20240711.dist-info → zenml_nightly-0.61.0.dev20240713.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.61.0.dev20240711
1
+ 0.61.0.dev20240713
zenml/cli/stack.py CHANGED
@@ -35,6 +35,7 @@ import click
35
35
  from rich.console import Console
36
36
  from rich.markdown import Markdown
37
37
  from rich.prompt import Confirm
38
+ from rich.style import Style
38
39
  from rich.syntax import Syntax
39
40
 
40
41
  import zenml
@@ -287,6 +288,19 @@ def register_stack(
287
288
 
288
289
  client = Client()
289
290
 
291
+ if provider is not None or connector is not None:
292
+ if client.zen_store.is_local_store():
293
+ cli_utils.error(
294
+ "You are registering a stack using a service connector, but "
295
+ "this feature cannot be used with a local ZenML deployment. "
296
+ "ZenML needs to be accessible from the cloud provider to allow the "
297
+ "stack and its components to be registered automatically. "
298
+ "Please deploy ZenML in a remote environment as described in the "
299
+ "documentation: https://docs.zenml.io/getting-started/deploying-zenml "
300
+ "or use a managed ZenML Pro server instance for quick access to "
301
+ "this feature and more: https://www.zenml.io/pro"
302
+ )
303
+
290
304
  try:
291
305
  client.get_stack(
292
306
  name_id_or_prefix=stack_name,
@@ -377,6 +391,7 @@ def register_stack(
377
391
  if provider:
378
392
  labels["zenml:provider"] = provider
379
393
  service_connector_resource_model = None
394
+ can_generate_long_tokens = False
380
395
  # create components
381
396
  needed_components = (
382
397
  (StackComponentType.ARTIFACT_STORE, artifact_store),
@@ -419,26 +434,38 @@ def register_stack(
419
434
 
420
435
  if component_selected is None:
421
436
  if service_connector_resource_model is None:
422
- if isinstance(service_connector, UUID):
423
- service_connector_resource_model = (
424
- client.verify_service_connector(
425
- service_connector
437
+ with console.status(
438
+ "Exploring resources available to the service connector...\n"
439
+ ):
440
+ if isinstance(service_connector, UUID):
441
+ service_connector_resource_model = (
442
+ client.verify_service_connector(
443
+ service_connector
444
+ )
426
445
  )
427
- )
428
- else:
429
- _, service_connector_resource_model = (
430
- client.create_service_connector(
431
- name=stack_name,
432
- connector_type=service_connector.type,
433
- auth_method=service_connector.auth_method,
434
- configuration=service_connector.configuration,
435
- register=False,
446
+ existing_service_connector_info = (
447
+ client.get_service_connector(
448
+ service_connector
449
+ )
436
450
  )
437
- )
438
- if service_connector_resource_model is None:
439
- cli_utils.error(
440
- f"Failed to validate service connector {service_connector}..."
451
+ can_generate_long_tokens = not existing_service_connector_info.configuration.get(
452
+ "generate_temporary_tokens", True
453
+ )
454
+ else:
455
+ _, service_connector_resource_model = (
456
+ client.create_service_connector(
457
+ name=stack_name,
458
+ connector_type=service_connector.type,
459
+ auth_method=service_connector.auth_method,
460
+ configuration=service_connector.configuration,
461
+ register=False,
462
+ )
441
463
  )
464
+ can_generate_long_tokens = True
465
+ if service_connector_resource_model is None:
466
+ cli_utils.error(
467
+ f"Failed to validate service connector {service_connector}..."
468
+ )
442
469
  if provider is None:
443
470
  if isinstance(
444
471
  service_connector_resource_model.connector_type,
@@ -455,6 +482,7 @@ def register_stack(
455
482
  cloud_provider=provider,
456
483
  service_connector_resource_models=service_connector_resource_model.resources,
457
484
  service_connector_index=0,
485
+ can_generate_long_tokens=can_generate_long_tokens,
458
486
  )
459
487
  component_name = stack_name
460
488
  created_objects.add(component_type.value)
@@ -470,6 +498,18 @@ def register_stack(
470
498
  artifact_store = component_name
471
499
  if component_type == StackComponentType.ORCHESTRATOR:
472
500
  orchestrator = component_name
501
+ if not isinstance(
502
+ component_info, UUID
503
+ ) and component_info.flavor.startswith("vm"):
504
+ if isinstance(
505
+ service_connector, ServiceConnectorInfo
506
+ ) and service_connector.auth_method in {
507
+ "service-account",
508
+ "external-account",
509
+ }:
510
+ service_connector.configuration[
511
+ "generate_temporary_tokens"
512
+ ] = False
473
513
  if component_type == StackComponentType.CONTAINER_REGISTRY:
474
514
  container_registry = component_name
475
515
 
@@ -1687,6 +1727,12 @@ def deploy(
1687
1727
  provider=StackDeploymentProvider(provider),
1688
1728
  )
1689
1729
 
1730
+ if location and location not in deployment.locations.values():
1731
+ cli_utils.error(
1732
+ f"Invalid location '{location}' for provider '{provider}'. "
1733
+ f"Valid locations are: {', '.join(deployment.locations.values())}"
1734
+ )
1735
+
1690
1736
  console.print(
1691
1737
  Markdown(
1692
1738
  f"# {provider.upper()} ZenML Cloud Stack Deployment\n"
@@ -1695,55 +1741,71 @@ def deploy(
1695
1741
  )
1696
1742
  console.print(Markdown("## Instructions\n" + deployment.instructions))
1697
1743
 
1744
+ deployment_config = client.zen_store.get_stack_deployment_config(
1745
+ provider=StackDeploymentProvider(provider),
1746
+ stack_name=stack_name,
1747
+ location=location,
1748
+ )
1749
+
1750
+ if deployment_config.configuration:
1751
+ console.print(
1752
+ Markdown(
1753
+ "## Configuration\n"
1754
+ "You will be asked to provide the following configuration "
1755
+ "values during the deployment process:\n"
1756
+ )
1757
+ )
1758
+
1759
+ console.print(
1760
+ "\n",
1761
+ deployment_config.configuration,
1762
+ no_wrap=True,
1763
+ overflow="ignore",
1764
+ crop=False,
1765
+ style=Style(bgcolor="grey15"),
1766
+ )
1767
+
1698
1768
  if not cli_utils.confirmation(
1699
1769
  "\n\nProceed to continue with the deployment. You will be "
1700
1770
  f"automatically redirected to {provider.upper()} in your browser.",
1701
1771
  ):
1702
1772
  raise click.Abort()
1703
1773
 
1704
- deployment_url, deployment_url_title = (
1705
- client.zen_store.get_stack_deployment_url(
1706
- provider=StackDeploymentProvider(provider),
1707
- stack_name=stack_name,
1708
- location=location,
1709
- )
1710
- )
1711
-
1712
1774
  date_start = datetime.utcnow()
1713
1775
 
1714
- webbrowser.open(deployment_url)
1776
+ webbrowser.open(deployment_config.deployment_url)
1715
1777
  console.print(
1716
1778
  Markdown(
1717
1779
  f"If your browser did not open automatically, please open "
1718
1780
  f"the following URL into your browser to deploy the stack to "
1719
1781
  f"{provider.upper()}: "
1720
- f"[{deployment_url_title}]({deployment_url}).\n\n"
1782
+ f"[{deployment_config.deployment_url_text}]"
1783
+ f"({deployment_config.deployment_url}).\n\n"
1721
1784
  )
1722
1785
  )
1723
1786
 
1724
1787
  try:
1725
- with console.status(
1726
- "Waiting for the deployment to complete and the stack to be "
1788
+ cli_utils.declare(
1789
+ "\n\nWaiting for the deployment to complete and the stack to be "
1727
1790
  "registered. Press CTRL+C to abort...\n"
1728
- ):
1729
- while True:
1730
- deployed_stack = (
1731
- client.zen_store.get_stack_deployment_stack(
1732
- provider=StackDeploymentProvider(provider),
1733
- stack_name=stack_name,
1734
- location=location,
1735
- date_start=date_start,
1736
- )
1737
- )
1738
- if deployed_stack:
1739
- break
1740
- time.sleep(10)
1741
-
1742
- analytics_handler.metadata.update(
1743
- {
1744
- "stack_id": deployed_stack.stack.id,
1745
- }
1791
+ )
1792
+
1793
+ while True:
1794
+ deployed_stack = client.zen_store.get_stack_deployment_stack(
1795
+ provider=StackDeploymentProvider(provider),
1796
+ stack_name=stack_name,
1797
+ location=location,
1798
+ date_start=date_start,
1746
1799
  )
1800
+ if deployed_stack:
1801
+ break
1802
+ time.sleep(10)
1803
+
1804
+ analytics_handler.metadata.update(
1805
+ {
1806
+ "stack_id": deployed_stack.stack.id,
1807
+ }
1808
+ )
1747
1809
 
1748
1810
  except KeyboardInterrupt:
1749
1811
  cli_utils.declare("Stack deployment aborted.")
@@ -1767,15 +1829,28 @@ Stack [{deployed_stack.stack.name}]({get_stack_url(deployed_stack.stack)}):\n"""
1767
1829
 
1768
1830
  console.print(Markdown(stack_desc))
1769
1831
 
1770
- console.print(
1771
- Markdown("## Follow-up\n" + deployment.post_deploy_instructions)
1772
- )
1832
+ follow_up = f"""
1833
+ ## Follow-up
1834
+
1835
+ {deployment.post_deploy_instructions}
1836
+
1837
+ To use the `{deployed_stack.stack.name}` stack to run pipelines:
1773
1838
 
1839
+ * install the required ZenML integrations by running: `zenml integration install {" ".join(deployment.integrations)}`
1840
+ """
1774
1841
  if set_stack:
1775
1842
  client.activate_stack(deployed_stack.stack.id)
1776
- cli_utils.declare(
1777
- f"\nStack `{deployed_stack.stack.name}` set as active"
1778
- )
1843
+ follow_up += f"""
1844
+ * the `{deployed_stack.stack.name}` stack has already been set as active
1845
+ """
1846
+ else:
1847
+ follow_up += f"""
1848
+ * set the `{deployed_stack.stack.name}` stack as active by running: `zenml stack set {deployed_stack.stack.name}`
1849
+ """
1850
+
1851
+ console.print(
1852
+ Markdown(follow_up),
1853
+ )
1779
1854
 
1780
1855
 
1781
1856
  @stack.command(help="[DEPRECATED] Deploy a stack using mlstacks.")
@@ -2248,7 +2323,7 @@ def _get_service_connector_info(
2248
2323
  """
2249
2324
  from rich.prompt import Prompt
2250
2325
 
2251
- if cloud_provider not in {"aws"}:
2326
+ if cloud_provider not in {"aws", "gcp"}:
2252
2327
  raise ValueError(f"Unknown cloud provider {cloud_provider}")
2253
2328
 
2254
2329
  client = Client()
@@ -2277,7 +2352,7 @@ def _get_service_connector_info(
2277
2352
  object_type=f"authentication methods for {cloud_provider}",
2278
2353
  choices=choices,
2279
2354
  headers=headers,
2280
- prompt_text="Please choose one of the authentication option above.",
2355
+ prompt_text="Please choose one of the authentication option above",
2281
2356
  )
2282
2357
  if selected_auth_idx is None:
2283
2358
  cli_utils.error("No authentication method selected.")
@@ -2307,7 +2382,6 @@ def _get_service_connector_info(
2307
2382
  password="format" in properties[req_field]
2308
2383
  and properties[req_field]["format"] == "password",
2309
2384
  )
2310
- Console().print("All mandatory configuration parameters received!")
2311
2385
 
2312
2386
  return ServiceConnectorInfo(
2313
2387
  type=cloud_provider,
@@ -2322,6 +2396,7 @@ def _get_stack_component_info(
2322
2396
  service_connector_resource_models: List[
2323
2397
  ServiceConnectorTypedResourcesModel
2324
2398
  ],
2399
+ can_generate_long_tokens: bool,
2325
2400
  service_connector_index: Optional[int] = None,
2326
2401
  ) -> ComponentInfo:
2327
2402
  """Get a stack component info with given type and service connector.
@@ -2330,6 +2405,7 @@ def _get_stack_component_info(
2330
2405
  component_type: The type of component to create.
2331
2406
  cloud_provider: The cloud provider to use.
2332
2407
  service_connector_resource_models: The list of the available service connector resource models.
2408
+ can_generate_long_tokens: Whether connector can generate long-living tokens.
2333
2409
  service_connector_index: The index of the service connector to use.
2334
2410
 
2335
2411
  Returns:
@@ -2347,6 +2423,9 @@ def _get_stack_component_info(
2347
2423
  AWS_DOCS = (
2348
2424
  "https://docs.zenml.io/how-to/auth-management/aws-service-connector"
2349
2425
  )
2426
+ GCP_DOCS = (
2427
+ "https://docs.zenml.io/how-to/auth-management/gcp-service-connector"
2428
+ )
2350
2429
 
2351
2430
  flavor = "undefined"
2352
2431
  service_connector_resource_id = None
@@ -2370,7 +2449,19 @@ def _get_stack_component_info(
2370
2449
  elif cloud_provider == "azure":
2371
2450
  flavor = "azure"
2372
2451
  elif cloud_provider == "gcp":
2373
- flavor = "gcs"
2452
+ flavor = "gcp"
2453
+ for each in service_connector_resource_models:
2454
+ if each.resource_type == "gcs-bucket":
2455
+ available_storages = each.resource_ids or []
2456
+ if not available_storages:
2457
+ cli_utils.error(
2458
+ "We were unable to find any GCS buckets available "
2459
+ "to configured service connector. Please, verify "
2460
+ "that needed permission are granted for the "
2461
+ "service connector.\nDocumentation for the GCS "
2462
+ "Buckets configuration can be found at "
2463
+ f"{GCP_DOCS}#gcs-bucket"
2464
+ )
2374
2465
 
2375
2466
  selected_storage_idx = cli_utils.multi_choice_prompt(
2376
2467
  object_type=f"{cloud_provider.upper()} storages",
@@ -2386,12 +2477,29 @@ def _get_stack_component_info(
2386
2477
  config = {"path": selected_storage}
2387
2478
  service_connector_resource_id = selected_storage
2388
2479
  elif component_type == "orchestrator":
2480
+
2481
+ def query_gcp_region(compute_type: str) -> str:
2482
+ region = Prompt.ask(
2483
+ f"Select the location for your {compute_type}:",
2484
+ choices=sorted(
2485
+ Client()
2486
+ .zen_store.get_stack_deployment_info(
2487
+ StackDeploymentProvider.GCP
2488
+ )
2489
+ .locations.values()
2490
+ ),
2491
+ show_choices=True,
2492
+ )
2493
+ return region
2494
+
2389
2495
  if cloud_provider == "aws":
2390
2496
  available_orchestrators = []
2391
2497
  for each in service_connector_resource_models:
2392
2498
  types = []
2393
2499
  if each.resource_type == "aws-generic":
2394
- types = ["Sagemaker", "Skypilot (EC2)"]
2500
+ types = ["Sagemaker"]
2501
+ if can_generate_long_tokens:
2502
+ types.append("Skypilot (EC2)")
2395
2503
  if each.resource_type == "kubernetes-cluster":
2396
2504
  types = ["Kubernetes"]
2397
2505
 
@@ -2412,7 +2520,32 @@ def _get_stack_component_info(
2412
2520
  f"{AWS_DOCS}#eks-kubernetes-cluster"
2413
2521
  )
2414
2522
  elif cloud_provider == "gcp":
2415
- pass
2523
+ available_orchestrators = []
2524
+ for each in service_connector_resource_models:
2525
+ types = []
2526
+ if each.resource_type == "gcp-generic":
2527
+ types = ["Vertex AI"]
2528
+ if can_generate_long_tokens:
2529
+ types.append("Skypilot (Compute)")
2530
+ if each.resource_type == "kubernetes-cluster":
2531
+ types = ["Kubernetes"]
2532
+
2533
+ if each.resource_ids:
2534
+ for orchestrator in each.resource_ids:
2535
+ for t in types:
2536
+ available_orchestrators.append([t, orchestrator])
2537
+ if not available_orchestrators:
2538
+ cli_utils.error(
2539
+ "We were unable to find any orchestrator engines "
2540
+ "available to the service connector. Please, verify "
2541
+ "that needed permission are granted for the "
2542
+ "service connector.\nDocumentation for the Generic "
2543
+ "GCP resource configuration can be found at "
2544
+ f"{GCP_DOCS}#generic-gcp-resource\n"
2545
+ "Documentation for the GKE Kubernetes resource "
2546
+ "configuration can be found at "
2547
+ f"{GCP_DOCS}#gke-kubernetes-cluster"
2548
+ )
2416
2549
  elif cloud_provider == "azure":
2417
2550
  pass
2418
2551
 
@@ -2429,25 +2562,31 @@ def _get_stack_component_info(
2429
2562
  selected_orchestrator_idx
2430
2563
  ]
2431
2564
 
2565
+ config = {}
2432
2566
  if selected_orchestrator[0] == "Sagemaker":
2433
2567
  flavor = "sagemaker"
2434
- execution_role = Prompt.ask("Please enter an execution role ARN:")
2435
- config = {"execution_role": execution_role}
2568
+ execution_role = Prompt.ask("Enter an execution role ARN:")
2569
+ config["execution_role"] = execution_role
2436
2570
  elif selected_orchestrator[0] == "Skypilot (EC2)":
2437
2571
  flavor = "vm_aws"
2438
- config = {"region": selected_orchestrator[1]}
2572
+ config["region"] = selected_orchestrator[1]
2573
+ elif selected_orchestrator[0] == "Skypilot (Compute)":
2574
+ flavor = "vm_gcp"
2575
+ config["region"] = query_gcp_region("Skypilot cluster")
2576
+ elif selected_orchestrator[0] == "Vertex AI":
2577
+ flavor = "vertex"
2578
+ config["location"] = query_gcp_region("Vertex AI job")
2439
2579
  elif selected_orchestrator[0] == "Kubernetes":
2440
2580
  flavor = "kubernetes"
2441
- config = {}
2442
2581
  else:
2443
2582
  raise ValueError(
2444
2583
  f"Unknown orchestrator type {selected_orchestrator[0]}"
2445
2584
  )
2446
2585
  service_connector_resource_id = selected_orchestrator[1]
2447
2586
  elif component_type == "container_registry":
2448
- available_registries: List[str] = []
2449
- if cloud_provider == "aws":
2450
- flavor = "aws"
2587
+
2588
+ def _get_registries(registry_name: str, docs_link: str) -> List[str]:
2589
+ available_registries: List[str] = []
2451
2590
  for each in service_connector_resource_models:
2452
2591
  if each.resource_type == "docker-registry":
2453
2592
  available_registries = each.resource_ids or []
@@ -2456,14 +2595,24 @@ def _get_stack_component_info(
2456
2595
  "We were unable to find any container registries "
2457
2596
  "available to the service connector. Please, verify "
2458
2597
  "that needed permission are granted for the "
2459
- "service connector.\nDocumentation for the ECR "
2598
+ f"service connector.\nDocumentation for the {registry_name} "
2460
2599
  "container registry resource configuration can "
2461
- f"be found at {AWS_DOCS}#ecr-container-registry"
2600
+ f"be found at {docs_link}"
2462
2601
  )
2463
- elif cloud_provider == "azure":
2464
- flavor = "azure"
2465
- elif cloud_provider == "gcp":
2602
+ return available_registries
2603
+
2604
+ if cloud_provider == "aws":
2605
+ flavor = "aws"
2606
+ available_registries = _get_registries(
2607
+ "ECR", f"{AWS_DOCS}#ecr-container-registry"
2608
+ )
2609
+ if cloud_provider == "gcp":
2466
2610
  flavor = "gcp"
2611
+ available_registries = _get_registries(
2612
+ "GCR", f"{GCP_DOCS}#gcr-container-registry"
2613
+ )
2614
+ if cloud_provider == "azure":
2615
+ flavor = "azure"
2467
2616
 
2468
2617
  selected_registry_idx = cli_utils.multi_choice_prompt(
2469
2618
  object_type=f"{cloud_provider.upper()} registries",
zenml/constants.py CHANGED
@@ -263,6 +263,7 @@ DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS = 3
263
263
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT = 60 * 5 # 5 minutes
264
264
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING = 5 # seconds
265
265
  DEFAULT_HTTP_TIMEOUT = 30
266
+ SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT = 120 # seconds
266
267
  ZENML_API_KEY_PREFIX = "ZENKEY_"
267
268
  DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
268
269
  DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
@@ -337,6 +338,7 @@ ARTIFACT_VISUALIZATIONS = "/artifact_visualizations"
337
338
  CODE_REFERENCES = "/code_references"
338
339
  CODE_REPOSITORIES = "/code_repositories"
339
340
  COMPONENT_TYPES = "/component-types"
341
+ CONFIG = "/config"
340
342
  CURRENT_USER = "/current-user"
341
343
  DEACTIVATE = "/deactivate"
342
344
  DEVICES = "/devices"
@@ -390,6 +392,7 @@ STEPS = "/steps"
390
392
  TAGS = "/tags"
391
393
  TRIGGERS = "/triggers"
392
394
  TRIGGER_EXECUTIONS = "/trigger_executions"
395
+ ONBOARDING_STATE = "/onboarding_state"
393
396
  USERS = "/users"
394
397
  URL = "/url"
395
398
  VERSION_1 = "/v1"
@@ -484,3 +487,6 @@ FINISHED_ONBOARDING_SURVEY_KEY = "awareness_channels"
484
487
 
485
488
  # Name validation
486
489
  BANNED_NAME_CHARACTERS = "\t\n\r\v\f"
490
+
491
+
492
+ STACK_DEPLOYMENT_API_TOKEN_EXPIRATION = 60 * 6 # 6 hours
zenml/enums.py CHANGED
@@ -386,7 +386,23 @@ class PluginSubType(StrEnum):
386
386
  PIPELINE_RUN = "pipeline_run"
387
387
 
388
388
 
389
+ class OnboardingStep(StrEnum):
390
+ """All onboarding steps."""
391
+
392
+ DEVICE_VERIFIED = "device_verified"
393
+ PIPELINE_RUN = "pipeline_run"
394
+ STARTER_SETUP_COMPLETED = "starter_setup_completed"
395
+ STACK_WITH_REMOTE_ORCHESTRATOR_CREATED = (
396
+ "stack_with_remote_orchestrator_created"
397
+ )
398
+ PIPELINE_RUN_WITH_REMOTE_ORCHESTRATOR = (
399
+ "pipeline_run_with_remote_orchestrator"
400
+ )
401
+ PRODUCTION_SETUP_COMPLETED = "production_setup_completed"
402
+
403
+
389
404
  class StackDeploymentProvider(StrEnum):
390
405
  """All possible stack deployment providers."""
391
406
 
392
407
  AWS = "aws"
408
+ GCP = "gcp"
zenml/integrations/gcp/service_connectors/gcp_service_connector.py CHANGED
@@ -20,6 +20,7 @@ services:
20
20
 
21
21
  """
22
22
 
23
+ import base64
23
24
  import datetime
24
25
  import json
25
26
  import os
@@ -89,7 +90,7 @@ class GCPUserAccountCredentials(AuthenticationConfig):
89
90
  """GCP user account credentials."""
90
91
 
91
92
  user_account_json: PlainSerializedSecretStr = Field(
92
- title="GCP User Account Credentials JSON",
93
+ title="GCP User Account Credentials JSON optionally base64 encoded.",
93
94
  )
94
95
 
95
96
  generate_temporary_tokens: bool = Field(
@@ -113,9 +114,24 @@ class GCPUserAccountCredentials(AuthenticationConfig):
113
114
 
114
115
  Returns:
115
116
  The validated configuration values.
117
+
118
+ Raises:
119
+ ValueError: If the user account credentials JSON is invalid.
116
120
  """
117
- if isinstance(data.get("user_account_json"), dict):
121
+ user_account_json = data.get("user_account_json")
122
+ if isinstance(user_account_json, dict):
118
123
  data["user_account_json"] = json.dumps(data["user_account_json"])
124
+ elif isinstance(user_account_json, str):
125
+ # Check if the user account JSON is base64 encoded and decode it
126
+ if re.match(r"^[A-Za-z0-9+/=]+$", user_account_json):
127
+ try:
128
+ data["user_account_json"] = base64.b64decode(
129
+ user_account_json
130
+ ).decode("utf-8")
131
+ except Exception as e:
132
+ raise ValueError(
133
+ f"Failed to decode base64 encoded user account JSON: {e}"
134
+ )
119
135
  return data
120
136
 
121
137
  @field_validator("user_account_json")
@@ -170,7 +186,7 @@ class GCPServiceAccountCredentials(AuthenticationConfig):
170
186
  """GCP service account credentials."""
171
187
 
172
188
  service_account_json: PlainSerializedSecretStr = Field(
173
- title="GCP Service Account Key JSON",
189
+ title="GCP Service Account Key JSON optionally base64 encoded.",
174
190
  )
175
191
 
176
192
  generate_temporary_tokens: bool = Field(
@@ -194,11 +210,27 @@ class GCPServiceAccountCredentials(AuthenticationConfig):
194
210
 
195
211
  Returns:
196
212
  The validated configuration values.
213
+
214
+ Raises:
215
+ ValueError: If the service account credentials JSON is invalid.
197
216
  """
198
- if isinstance(data.get("service_account_json"), dict):
217
+ service_account_json = data.get("service_account_json")
218
+ if isinstance(service_account_json, dict):
199
219
  data["service_account_json"] = json.dumps(
200
220
  data["service_account_json"]
201
221
  )
222
+ elif isinstance(service_account_json, str):
223
+ # Check if the service account JSON is base64 encoded and decode it
224
+ if re.match(r"^[A-Za-z0-9+/=]+$", service_account_json):
225
+ try:
226
+ data["service_account_json"] = base64.b64decode(
227
+ service_account_json
228
+ ).decode("utf-8")
229
+ except Exception as e:
230
+ raise ValueError(
231
+ f"Failed to decode base64 encoded service account JSON: {e}"
232
+ )
233
+
202
234
  return data
203
235
 
204
236
  @field_validator("service_account_json")
@@ -261,7 +293,7 @@ class GCPExternalAccountCredentials(AuthenticationConfig):
261
293
  """GCP external account credentials."""
262
294
 
263
295
  external_account_json: PlainSerializedSecretStr = Field(
264
- title="GCP External Account JSON",
296
+ title="GCP External Account JSON optionally base64 encoded.",
265
297
  )
266
298
 
267
299
  generate_temporary_tokens: bool = Field(
@@ -285,11 +317,27 @@ class GCPExternalAccountCredentials(AuthenticationConfig):
285
317
 
286
318
  Returns:
287
319
  The validated configuration values.
320
+
321
+ Raises:
322
+ ValueError: If the external account credentials JSON is invalid.
288
323
  """
289
- if isinstance(data.get("external_account_json"), dict):
324
+ external_account_json = data.get("external_account_json")
325
+ if isinstance(external_account_json, dict):
290
326
  data["external_account_json"] = json.dumps(
291
327
  data["external_account_json"]
292
328
  )
329
+ elif isinstance(external_account_json, str):
330
+ # Check if the external account JSON is base64 encoded and decode it
331
+ if re.match(r"^[A-Za-z0-9+/=]+$", external_account_json):
332
+ try:
333
+ data["external_account_json"] = base64.b64decode(
334
+ external_account_json
335
+ ).decode("utf-8")
336
+ except Exception as e:
337
+ raise ValueError(
338
+ f"Failed to decode base64 encoded external account JSON: {e}"
339
+ )
340
+
293
341
  return data
294
342
 
295
343
  @field_validator("external_account_json")