xpk 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. integration/README.md +19 -0
  2. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  3. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  4. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  5. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  6. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  7. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  8. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  9. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  10. xpk/blueprints/a4/storage_crd.yaml +52 -0
  11. xpk/commands/cluster.py +33 -12
  12. xpk/commands/cluster_gcluster_test.py +5 -1
  13. xpk/commands/cluster_test.py +125 -0
  14. xpk/commands/config.py +3 -3
  15. xpk/commands/inspector.py +5 -3
  16. xpk/commands/kind.py +2 -0
  17. xpk/commands/managed_ml_diagnostics.py +249 -0
  18. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  19. xpk/commands/workload.py +124 -139
  20. xpk/commands/workload_test.py +160 -118
  21. xpk/core/blueprint/blueprint_generator.py +3 -0
  22. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  23. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  24. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  25. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  26. xpk/core/capacity.py +2 -0
  27. xpk/core/cluster.py +18 -47
  28. xpk/core/cluster_test.py +76 -1
  29. xpk/core/config.py +81 -7
  30. xpk/core/config_test.py +67 -11
  31. xpk/core/docker_container.py +3 -1
  32. xpk/core/docker_image.py +10 -6
  33. xpk/core/docker_resources.py +1 -10
  34. xpk/core/kjob.py +17 -16
  35. xpk/core/kueue_manager.py +13 -19
  36. xpk/core/kueue_manager_test.py +27 -1
  37. xpk/core/nap.py +13 -14
  38. xpk/core/nodepool.py +17 -15
  39. xpk/core/nodepool_test.py +25 -4
  40. xpk/core/pathways.py +23 -0
  41. xpk/core/pathways_test.py +57 -0
  42. xpk/core/resources.py +84 -27
  43. xpk/core/scheduling.py +128 -132
  44. xpk/core/scheduling_test.py +215 -2
  45. xpk/core/system_characteristics.py +179 -0
  46. xpk/core/system_characteristics_test.py +49 -1
  47. xpk/core/telemetry.py +4 -4
  48. xpk/core/telemetry_test.py +9 -9
  49. xpk/core/vertex.py +4 -3
  50. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  51. xpk/main.py +2 -0
  52. xpk/parser/cluster.py +22 -88
  53. xpk/parser/cluster_test.py +41 -0
  54. xpk/parser/common.py +84 -0
  55. xpk/parser/storage.py +10 -0
  56. xpk/parser/storage_test.py +47 -0
  57. xpk/parser/workload.py +14 -41
  58. xpk/parser/workload_test.py +2 -48
  59. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  60. xpk/utils/feature_flags.py +3 -0
  61. xpk/utils/validation.py +2 -2
  62. xpk-0.16.0.dist-info/METADATA +127 -0
  63. {xpk-0.15.0.dist-info → xpk-0.16.0.dist-info}/RECORD +67 -48
  64. xpk-0.15.0.dist-info/METADATA +0 -1666
  65. {xpk-0.15.0.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  66. {xpk-0.15.0.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  67. {xpk-0.15.0.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  68. {xpk-0.15.0.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,7 @@ class _Mocks:
35
35
  common_print_mock: MagicMock
36
36
  commands_print_mock: MagicMock
37
37
  commands_get_reservation_deployment_type: MagicMock
38
+ commands_get_pathways_machine_types: MagicMock
38
39
  commands_tester: CommandsTester
39
40
 
40
41
 
@@ -75,15 +76,23 @@ def mocks(mocker) -> _Mocks:
75
76
  'xpk.commands.cluster.get_reservation_deployment_type',
76
77
  return_value='DENSE',
77
78
  )
79
+ commands_get_pathways_machine_types = mocker.patch(
80
+ 'xpk.commands.cluster.get_pathways_machine_types',
81
+ return_value=(0, []),
82
+ )
78
83
  return _Mocks(
79
84
  common_print_mock=common_print_mock,
80
85
  commands_get_reservation_deployment_type=commands_get_reservation_deployment_type,
81
86
  commands_print_mock=commands_print_mock,
87
+ commands_get_pathways_machine_types=commands_get_pathways_machine_types,
82
88
  commands_tester=CommandsTester(
83
89
  mocker,
84
90
  run_command_with_updates_path=(
85
91
  'xpk.commands.cluster.run_command_with_updates'
86
92
  ),
93
+ run_command_for_value_path=(
94
+ 'xpk.commands.cluster.run_command_for_value'
95
+ ),
87
96
  ),
88
97
  )
89
98
 
@@ -104,6 +113,7 @@ def construct_args(**kwargs: Any) -> Namespace:
104
113
  gke_version='',
105
114
  private=False,
106
115
  authorized_networks=None,
116
+ pathways_gce_machine_type='n2-standard-64',
107
117
  enable_pathways=False,
108
118
  enable_ray_cluster=False,
109
119
  enable_workload_identity=False,
@@ -121,6 +131,21 @@ def construct_args(**kwargs: Any) -> Namespace:
121
131
  cluster_cpu_machine_type='',
122
132
  create_vertex_tensorboard=False,
123
133
  enable_autoprovisioning=False,
134
+ sub_slicing_topology='2x2x2',
135
+ use_vertex_tensorboard=False,
136
+ env_file='',
137
+ env=None,
138
+ use_pathways=False,
139
+ debug_dump_gcs=False,
140
+ storage='',
141
+ restart_on_exit_codes=None,
142
+ ttl_seconds_after_finished=0,
143
+ max_restarts=1,
144
+ priority=0,
145
+ termination_grace_period_seconds=0,
146
+ docker_image_pull_secret='',
147
+ managed_mldiagnostics=False,
148
+ output_manifest_file='',
124
149
  )
125
150
  args_dict.update(kwargs)
126
151
  return Namespace(**args_dict)
@@ -288,6 +313,64 @@ def test_validate_cluster_create_args_for_invalid_reservation(
288
313
  )
289
314
 
290
315
 
316
def test_validate_cluster_create_args_for_enable_pathways_set_to_false(
    mocks: _Mocks,
):
    """With Pathways disabled, a failing machine-type lookup has no effect."""
    # The lookup is configured to fail; validation must still print nothing
    # and must not exit.
    mocks.commands_get_pathways_machine_types.return_value = (1, [])

    _validate_cluster_create_args(
        construct_args(enable_pathways=False), TPU_TEST_SYSTEM
    )

    mocks.commands_print_mock.assert_not_called()
325
+
326
+
327
def test_validate_cluster_create_args_for_errored_pathways_machine_types_retrieval(
    mocks: _Mocks,
):
    """A failed machine-type lookup exits after printing a single message."""
    mocks.commands_get_pathways_machine_types.return_value = (1, [])
    pathways_args = construct_args(enable_pathways=True)

    with pytest.raises(SystemExit):
        _validate_cluster_create_args(pathways_args, TPU_TEST_SYSTEM)

    print_mock = mocks.commands_print_mock
    assert print_mock.call_count == 1
    # call_args holds the most recent call; [0][0] is its first positional arg.
    printed_message = print_mock.call_args[0][0]
    assert 'Unable to retrieve' in printed_message
338
+
339
+
340
def test_validate_cluster_create_args_for_invalid_pathways_machine_type(
    mocks: _Mocks,
):
    """A machine type missing from the available list makes validation exit."""
    available_machine_types = ['n2-standard-64']
    mocks.commands_get_pathways_machine_types.return_value = (
        0,
        available_machine_types,
    )
    pathways_args = construct_args(
        enable_pathways=True, pathways_gce_machine_type='n2-standard-32'
    )

    with pytest.raises(SystemExit):
        _validate_cluster_create_args(pathways_args, TPU_TEST_SYSTEM)

    print_mock = mocks.commands_print_mock
    assert print_mock.call_count == 2
    # The last of the two printed messages lists the valid choices.
    last_message = print_mock.call_args[0][0]
    assert 'Available machine types' in last_message
356
+
357
+
358
def test_validate_cluster_create_args_for_valid_pathways_machine_type(
    mocks: _Mocks,
):
    """A machine type present in the available list validates silently."""
    requested_machine_type = 'n2-standard-32'
    mocks.commands_get_pathways_machine_types.return_value = (
        0,
        ['n2-standard-32'],
    )
    pathways_args = construct_args(
        enable_pathways=True, pathways_gce_machine_type=requested_machine_type
    )

    _validate_cluster_create_args(pathways_args, TPU_TEST_SYSTEM)

    mocks.commands_print_mock.assert_not_called()
372
+
373
+
291
374
  @patch('xpk.commands.cluster.KueueManager.install_or_upgrade')
292
375
  def test_install_kueue_returns_kueue_installation_code(
293
376
  mock_kueue_manager_install: MagicMock,
@@ -358,6 +441,48 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag
358
441
  )
359
442
 
360
443
 
444
def test_run_gke_cluster_create_command_with_lustre_runs_correct_command(
    mocks: _Mocks,
):
    """Lustre CSI without the legacy port adds the addon but not the port flag."""
    lustre_args = construct_args(
        enable_lustre_csi_driver=True, enable_legacy_lustre_port=False
    )

    return_code = run_gke_cluster_create_command(
        args=lustre_args,
        gke_control_plane_version='1.2.3',
        system=TPU_TEST_SYSTEM,
        release_channel=ReleaseChannel.REGULAR,
    )

    assert return_code == 0
    create_commands = mocks.commands_tester.get_matching_commands(
        'clusters create'
    )
    assert len(create_commands) == 1
    create_command = create_commands[0]
    assert '--addons=LustreCsiDriver' in create_command
    assert '--enable-legacy-lustre-port' not in create_command
464
+
465
+
466
def test_run_gke_cluster_create_command_with_lustre_legacy_port_adds_correct_flag(
    mocks: _Mocks,
):
    """Lustre CSI with the legacy port adds both the addon and the port flag."""
    legacy_port_args = construct_args(
        enable_lustre_csi_driver=True, enable_legacy_lustre_port=True
    )

    return_code = run_gke_cluster_create_command(
        args=legacy_port_args,
        gke_control_plane_version='1.2.3',
        system=TPU_TEST_SYSTEM,
        release_channel=ReleaseChannel.REGULAR,
    )

    assert return_code == 0
    expected_fragments = (
        'clusters create',
        '--enable-legacy-lustre-port',
        '--addons=LustreCsiDriver',
    )
    mocks.commands_tester.assert_command_run(*expected_fragments)
484
+
485
+
361
486
  def test_log_cluster_create_telemetry_does_not_log_when_feature_flag_is_disabled():
362
487
  FeatureFlags.TELEMETRY_ENABLED = False
363
488
  _log_cluster_create_telemetry(construct_args())
xpk/commands/config.py CHANGED
@@ -14,14 +14,14 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from ..core.config import xpk_config
17
+ from ..core.config import get_config as get_xpk_config
18
18
  from ..utils.console import xpk_print
19
19
 
20
20
 
21
21
  def set_config(args):
22
- xpk_config.set(args.set_config_args[0], args.set_config_args[1])
22
+ get_xpk_config().set(args.set_config_args[0], args.set_config_args[1])
23
23
 
24
24
 
25
25
  def get_config(args):
26
- value = xpk_config.get(args.get_config_key[0])
26
+ value = get_xpk_config().get(args.get_config_key[0])
27
27
  xpk_print(value)
xpk/commands/inspector.py CHANGED
@@ -18,7 +18,7 @@ from ..core.cluster import get_cluster_credentials
18
18
  from ..core.commands import run_command_for_value
19
19
  from ..core.gcloud_context import add_zone_and_project, get_cluster_location
20
20
  from ..core.kueue_manager import CLUSTER_QUEUE_NAME, LOCAL_QUEUE_NAME
21
- from ..core.resources import CLUSTER_METADATA_CONFIGMAP, CLUSTER_RESOURCES_CONFIGMAP
21
+ from ..core.resources import ConfigMapType, get_config_map_name
22
22
  from ..utils.console import xpk_exit, xpk_print
23
23
  from ..utils.file import append_tmp_file, write_tmp_file
24
24
  from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
@@ -162,14 +162,16 @@ def inspector(args) -> None:
162
162
  (
163
163
  (
164
164
  'kubectl get configmap'
165
- f' {args.cluster}-{CLUSTER_METADATA_CONFIGMAP} -o yaml'
165
+ f' {get_config_map_name(args.cluster, ConfigMapType.METADATA)} -o'
166
+ ' yaml'
166
167
  ),
167
168
  'GKE: Cluster Metadata ConfigMap Details',
168
169
  ),
169
170
  (
170
171
  (
171
172
  'kubectl get configmap'
172
- f' {args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP} -o yaml'
173
+ f' {get_config_map_name(args.cluster, ConfigMapType.RESOURCES)} -o'
174
+ ' yaml'
173
175
  ),
174
176
  'GKE: Cluster Resources ConfigMap Details',
175
177
  ),
xpk/commands/kind.py CHANGED
@@ -30,6 +30,7 @@ from ..core.storage import install_storage_crd
30
30
  from ..core.system_characteristics import (
31
31
  SystemCharacteristics,
32
32
  AcceleratorType,
33
+ DockerPlatform,
33
34
  )
34
35
  from ..utils.console import (xpk_exit, xpk_print)
35
36
  from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
@@ -97,6 +98,7 @@ def cluster_create(args) -> None:
97
98
  AcceleratorType.CPU,
98
99
  'kind',
99
100
  supports_sub_slicing=False,
101
+ docker_platform=DockerPlatform.ARM,
100
102
  )
101
103
 
102
104
  kueue_manager = KueueManager(project='', zone='')
@@ -0,0 +1,249 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from packaging.version import Version
18
+ from ..core.commands import run_command_for_value, run_command_with_updates
19
+ from ..utils.console import xpk_print
20
+ import os
21
+ import tempfile
22
+
23
+ _KUEUE_DEPLOYMENT_NAME = 'kueue-controller-manager'
24
+ _KUEUE_NAMESPACE_NAME = 'kueue-system'
25
+ _CERT_WEBHOOK_DEPLOYMENT_NAME = 'cert-manager-webhook'
26
+ _CERT_WEBHOOK_NAMESPACE_NAME = 'cert-manager'
27
+ _WEBHOOK_PACKAGE = 'mldiagnostics-injection-webhook'
28
+ _WEBHOOK_VERSION = Version('v0.5.0')
29
+ _WEBHOOK_FILENAME = f'{_WEBHOOK_PACKAGE}-v{_WEBHOOK_VERSION}.yaml'
30
+ _OPERATOR_PACKAGE = 'mldiagnostics-connection-operator'
31
+ _OPERATOR_VERSION = Version('v0.5.0')
32
+ _OPERATOR_FILENAME = f'{_OPERATOR_PACKAGE}-v{_OPERATOR_VERSION}.yaml'
33
+ _CERT_MANAGER_VERSION = Version('v1.13.0')
34
+
35
+
36
def _install_cert_manager(version: Version = _CERT_MANAGER_VERSION) -> int:
    """Apply the cert-manager release manifest from GitHub with kubectl.

    Args:
      version: cert-manager release to apply. It is rendered into the
        download URL with a leading 'v'.

    Returns:
      The return code reported by run_command_with_updates for the kubectl
      command; callers treat 0 as success.
    """
    manifest_url = (
        'https://github.com/cert-manager/cert-manager/releases/download/'
        f'v{version}/cert-manager.yaml'
    )
    return run_command_with_updates(
        f'kubectl apply -f {manifest_url}',
        f'Applying cert-manager {version} manifest...',
    )
55
+
56
+
57
def _download_mldiagnostics_yaml(package_name: str, version: Version) -> int:
    """Download one mldiagnostics manifest package from Artifact Registry.

    The artifact is written into the directory returned by
    tempfile.gettempdir(), the same directory _install_mldiagnostics_yaml
    reads from. A hard-coded '/tmp/' would leave the two steps pointing at
    different directories on hosts whose temp dir is elsewhere.

    Args:
      package_name: Artifact Registry package to download (the webhook or the
        operator package).
      version: Package version; it is passed to gcloud with a leading 'v'.

    Returns:
      0 if the download succeeded or gcloud reported that the file already
      exists locally; otherwise the non-zero return code of the command.
    """
    # Joining with '' appends a trailing separator, so the value is still
    # '/tmp/' on hosts where the temp dir is /tmp.
    destination = os.path.join(tempfile.gettempdir(), '')

    command = (
        'gcloud artifacts generic download'
        ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us'
        f' --package={package_name} --version=v{version}'
        f' --destination={destination} --project=ai-on-gke'
    )

    return_code, return_output = run_command_for_value(
        command,
        f'Download {package_name} {version}...',
    )

    # A leftover file from an earlier run is not an error: reuse it.
    if return_code != 0 and 'already exists' in return_output:
        xpk_print(
            f'Artifact file for {package_name} {version} already exists locally.'
            ' Skipping download.'
        )
        return 0

    return return_code
86
+
87
+
88
def _create_mldiagnostics_namespace() -> int:
    """Create the 'gke-mldiagnostics' namespace, tolerating an existing one.

    Returns:
      0 if the namespace was created or the command output says it already
      exists; otherwise the non-zero return code of the kubectl command.
    """
    exit_code, output = run_command_for_value(
        'kubectl create namespace gke-mldiagnostics',
        'Create gke-mldiagnostics namespace...',
    )

    if exit_code == 0:
        return exit_code

    if 'already exists' not in output:
        return exit_code

    xpk_print('Namespace already exists. Skipping creation.')
    return 0
108
+
109
+
110
def _install_mldiagnostics_yaml(artifact_filename: str) -> int:
    """Apply a downloaded mldiagnostics manifest to 'gke-mldiagnostics'.

    Args:
      artifact_filename: File name of the manifest; it is resolved against
        the directory returned by tempfile.gettempdir().

    Returns:
      The return code reported by run_command_with_updates for the kubectl
      command; callers treat 0 as success.
    """
    manifest_path = os.path.join(tempfile.gettempdir(), artifact_filename)
    apply_command = f'kubectl apply -f {manifest_path} -n gke-mldiagnostics'
    task_description = f'Install {manifest_path}...'
    return run_command_with_updates(apply_command, task_description)
125
+
126
+
127
def _label_default_namespace_mldiagnostics() -> int:
    """Label the 'default' namespace with 'managed-mldiagnostics-gke=true'.

    Returns:
      The return code reported by run_command_with_updates for the kubectl
      command; callers treat 0 as success.
    """
    return run_command_with_updates(
        'kubectl label namespace default managed-mldiagnostics-gke=true',
        'Label default namespace with managed-mldiagnostics-gke=true',
    )
141
+
142
+
143
def install_mldiagnostics_prerequisites() -> int:
    """Install everything the managed mldiagnostics feature needs.

    Steps run strictly in order and the first failure stops the sequence:
    wait for Kueue, install cert-manager, wait for the cert-manager webhook,
    then download/apply the injection webhook, label the default namespace,
    and download/apply the connection operator.

    Returns:
      0 if every step succeeded; 1 if a deployment did not become ready;
      otherwise the non-zero return code of the step that failed.
    """
    kueue_ready = _wait_for_deployment_ready(
        deployment_name=_KUEUE_DEPLOYMENT_NAME, namespace=_KUEUE_NAMESPACE_NAME
    )
    if not kueue_ready:
        xpk_print(
            f'Application {_KUEUE_DEPLOYMENT_NAME} failed to become ready within'
            ' the timeout.'
        )
        return 1

    status = _install_cert_manager()
    if status != 0:
        return status

    if not _wait_for_deployment_ready(
        deployment_name=_CERT_WEBHOOK_DEPLOYMENT_NAME,
        namespace=_CERT_WEBHOOK_NAMESPACE_NAME,
    ):
        xpk_print('The cert-manager-webhook installation failed.')
        return 1

    # Remaining steps all share the same contract: return 0 or a failure code.
    remaining_steps = (
        lambda: _download_mldiagnostics_yaml(
            package_name=_WEBHOOK_PACKAGE, version=_WEBHOOK_VERSION
        ),
        lambda: _create_mldiagnostics_namespace(),
        lambda: _install_mldiagnostics_yaml(artifact_filename=_WEBHOOK_FILENAME),
        lambda: _label_default_namespace_mldiagnostics(),
        lambda: _download_mldiagnostics_yaml(
            package_name=_OPERATOR_PACKAGE, version=_OPERATOR_VERSION
        ),
        lambda: _install_mldiagnostics_yaml(
            artifact_filename=_OPERATOR_FILENAME
        ),
    )
    for run_step in remaining_steps:
        status = run_step()
        if status != 0:
            return status

    xpk_print(
        'All mldiagnostics installation and setup steps have been'
        ' successfully completed!'
    )
    return 0
207
+
208
+
209
def _wait_for_deployment_ready(
    deployment_name: str,
    namespace: str,
    timeout_seconds: int = 300,
    stabilization_seconds: int = 30,
) -> bool:
    """Wait for a Deployment rollout to finish, then let it settle.

    Runs `kubectl rollout status` with a timeout and, on success, sleeps for
    a stabilization period before reporting the Deployment as ready.

    Args:
      deployment_name: Name of the Kubernetes Deployment
        (e.g. 'kueue-controller-manager').
      namespace: Namespace of the Deployment (e.g. 'kueue-system').
      timeout_seconds: Value passed to kubectl's --timeout, in seconds.
      stabilization_seconds: Extra time to wait after a successful rollout.
        A value of 0 or less skips the wait.

    Returns:
      True if the rollout command and the stabilization wait both returned
      0; False otherwise.
    """
    rollout_command = (
        f'kubectl rollout status deployment/{deployment_name} -n {namespace}'
        f' --timeout={timeout_seconds}s'
    )
    rollout_code = run_command_with_updates(
        rollout_command, f'Checking status of deployment {deployment_name}...'
    )
    if rollout_code != 0:
        return False

    if stabilization_seconds <= 0:
        return True

    # A Deployment that has just rolled out may need a short while before it
    # serves reliably, so wait out the configured stabilization period. The
    # wait goes through the command runner like every other step.
    stabilization_code = run_command_with_updates(
        f'sleep {stabilization_seconds}',
        f'Deployment {deployment_name} is ready. Waiting {stabilization_seconds}'
        ' seconds for full stabilization',
        verbose=True,
    )
    return stabilization_code == 0
@@ -0,0 +1,146 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from dataclasses import dataclass
18
+ from unittest.mock import MagicMock
19
+ import pytest
20
+ from xpk.commands.managed_ml_diagnostics import install_mldiagnostics_prerequisites
21
+ from xpk.core.testing.commands_tester import CommandsTester
22
+
23
+
24
@dataclass
class _Mocks:
    """Handles to everything patched for the managed_ml_diagnostics tests."""

    # Patched xpk_print of the module under test.
    commands_print_mock: MagicMock
    # Replaces the module's command runners; used to assert on commands.
    commands_tester: CommandsTester


@pytest.fixture
def mocks(mocker) -> _Mocks:
    """Patch output and command execution of xpk.commands.managed_ml_diagnostics.

    Only names that the module under test actually uses are patched; patches
    aimed at other command modules would have no effect on these tests.
    """
    commands_print_mock = mocker.patch(
        'xpk.commands.managed_ml_diagnostics.xpk_print', return_value=None
    )
    return _Mocks(
        commands_print_mock=commands_print_mock,
        commands_tester=CommandsTester(
            mocker,
            run_command_with_updates_path=(
                'xpk.commands.managed_ml_diagnostics.run_command_with_updates'
            ),
            run_command_for_value_path=(
                'xpk.commands.managed_ml_diagnostics.run_command_for_value'
            ),
        ),
    )
59
+
60
+
61
def test_install_mldiagnostics_prerequisites_commands_executed(
    mocks: _Mocks,
):
    """The happy path returns 0 and issues every expected command once."""
    import os
    import tempfile

    # The code under test resolves manifests against tempfile.gettempdir(),
    # so build the expected paths the same way instead of assuming '/tmp'.
    temp_dir = tempfile.gettempdir()
    webhook_manifest = os.path.join(
        temp_dir, 'mldiagnostics-injection-webhook-v0.5.0.yaml'
    )
    operator_manifest = os.path.join(
        temp_dir, 'mldiagnostics-connection-operator-v0.5.0.yaml'
    )

    result = install_mldiagnostics_prerequisites()

    assert result == 0

    download = ('gcloud', 'artifacts', 'generic', 'download')
    apply = ('kubectl', 'apply', '-f')
    in_namespace = ('-n', 'gke-mldiagnostics')
    rollout = ('kubectl', 'rollout', 'status')

    # (command fragments, expected number of matching commands)
    expected_commands = [
        (rollout + ('deployment/kueue-controller-manager',), 1),
        (apply + ('https://github.com/cert-manager/cert-manager/',), 1),
        (rollout + ('deployment/cert-manager-webhook',), 1),
        (
            download
            + ('--package=mldiagnostics-injection-webhook', '--version=v0.5.0'),
            1,
        ),
        (('kubectl', 'create', 'namespace', 'gke-mldiagnostics'), 1),
        (apply + (webhook_manifest,) + in_namespace, 1),
        (
            (
                'kubectl',
                'label',
                'namespace',
                'default',
                'managed-mldiagnostics-gke=true',
            ),
            1,
        ),
        (
            download
            + ('--package=mldiagnostics-connection-operator', '--version=v0.5.0'),
            1,
        ),
        (apply + (operator_manifest,) + in_namespace, 1),
        # Totals: two downloads and two namespaced applies overall.
        (download, 2),
        (apply + in_namespace, 2),
    ]

    for fragments, times in expected_commands:
        mocks.commands_tester.assert_command_run(*fragments, times=times)