xpk 0.12.0__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. xpk/commands/batch.py +17 -10
  2. xpk/commands/cluster.py +137 -123
  3. xpk/commands/cluster_gcluster.py +77 -14
  4. xpk/commands/cluster_gcluster_test.py +177 -0
  5. xpk/commands/common.py +13 -27
  6. xpk/commands/info.py +11 -9
  7. xpk/commands/inspector.py +22 -11
  8. xpk/commands/job.py +53 -9
  9. xpk/commands/kind.py +38 -40
  10. xpk/commands/kjob_common.py +4 -4
  11. xpk/commands/run.py +9 -2
  12. xpk/commands/shell.py +13 -10
  13. xpk/commands/storage.py +26 -2
  14. xpk/commands/version.py +0 -4
  15. xpk/commands/workload.py +58 -30
  16. xpk/core/blueprint/blueprint_generator.py +4 -40
  17. xpk/core/blueprint/blueprint_test.py +0 -6
  18. xpk/core/capacity.py +6 -5
  19. xpk/core/cluster.py +96 -195
  20. xpk/core/cluster_private.py +9 -12
  21. xpk/core/commands.py +21 -25
  22. xpk/core/config.py +1 -1
  23. xpk/core/docker_image.py +17 -9
  24. xpk/core/docker_resources.py +9 -4
  25. xpk/core/gcloud_context.py +26 -2
  26. xpk/core/gcloud_context_test.py +96 -0
  27. xpk/core/gcluster_manager.py +0 -3
  28. xpk/core/jobset.py +5 -8
  29. xpk/core/kjob.py +19 -29
  30. xpk/core/kueue_manager.py +383 -0
  31. xpk/core/kueue_manager_test.py +542 -0
  32. xpk/core/monitoring.py +1 -1
  33. xpk/core/nap.py +11 -16
  34. xpk/core/network.py +18 -19
  35. xpk/core/nodepool.py +65 -71
  36. xpk/core/nodepool_test.py +198 -1
  37. xpk/core/pathways.py +9 -5
  38. xpk/core/ray.py +11 -15
  39. xpk/core/resources.py +15 -10
  40. xpk/core/scheduling.py +23 -1
  41. xpk/core/scheduling_test.py +31 -0
  42. xpk/core/system_characteristics.py +335 -229
  43. xpk/core/vertex.py +1 -1
  44. xpk/core/workload.py +7 -8
  45. xpk/main.py +3 -2
  46. xpk/parser/cluster.py +50 -0
  47. xpk/parser/cluster_test.py +66 -0
  48. xpk/parser/common.py +11 -0
  49. xpk/parser/workload.py +62 -25
  50. xpk/parser/workload_test.py +82 -0
  51. xpk/utils/execution_context.py +28 -0
  52. xpk/utils/feature_flags.py +28 -0
  53. xpk/utils/file.py +25 -10
  54. xpk/utils/kueue.py +20 -0
  55. xpk/utils/network.py +4 -0
  56. xpk/utils/templates.py +2 -0
  57. xpk/utils/topology.py +37 -0
  58. xpk/utils/topology_test.py +43 -0
  59. xpk/utils/validation.py +79 -55
  60. xpk/utils/validation_test.py +37 -0
  61. {xpk-0.12.0.dist-info → xpk-0.14.0.dist-info}/METADATA +6 -1
  62. xpk-0.14.0.dist-info/RECORD +112 -0
  63. xpk/core/kueue.py +0 -545
  64. xpk-0.12.0.dist-info/RECORD +0 -100
  65. {xpk-0.12.0.dist-info → xpk-0.14.0.dist-info}/WHEEL +0 -0
  66. {xpk-0.12.0.dist-info → xpk-0.14.0.dist-info}/entry_points.txt +0 -0
  67. {xpk-0.12.0.dist-info → xpk-0.14.0.dist-info}/licenses/LICENSE +0 -0
  68. {xpk-0.12.0.dist-info → xpk-0.14.0.dist-info}/top_level.txt +0 -0
xpk/core/commands.py CHANGED
@@ -18,14 +18,14 @@ import datetime
18
18
  import subprocess
19
19
  import sys
20
20
  import time
21
- from argparse import Namespace
22
21
 
23
22
  from ..utils.objects import chunks
24
23
  from ..utils.file import make_tmp_files, write_tmp_file
25
24
  from ..utils.console import xpk_print
25
+ from ..utils.execution_context import is_dry_run
26
26
 
27
27
 
28
- def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False):
28
+ def run_commands(commands, jobname, per_command_name, batch=10):
29
29
  """Run commands in groups of `batch`.
30
30
 
31
31
  Args:
@@ -33,7 +33,6 @@ def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False):
33
33
  jobname: the name of the job.
34
34
  per_command_name: list of command names.
35
35
  batch: number of commands to run in parallel.
36
- dry_run: enables dry_run if set to true.
37
36
 
38
37
  Returns:
39
38
  0 if successful and 1 otherwise.
@@ -46,7 +45,7 @@ def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False):
46
45
  f'Breaking up a total of {len(commands)} commands into'
47
46
  f' {len(commands_batched)} batches'
48
47
  )
49
- if dry_run:
48
+ if is_dry_run():
50
49
  xpk_print('Pretending all the jobs succeeded')
51
50
  return 0
52
51
 
@@ -78,14 +77,13 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
78
77
  The max return code and a list of all the return codes.
79
78
  """
80
79
 
80
+ files = [open(f, 'w', encoding='utf-8') for f in output_logs]
81
81
  children = []
82
82
  start_time = datetime.datetime.now()
83
- for i, command in enumerate(commands):
83
+ for command, file in zip(commands, files):
84
84
  children.append(
85
85
  # subprocess managed by list pylint: disable=consider-using-with
86
- subprocess.Popen(
87
- command, stdout=output_logs[i], stderr=output_logs[i], shell=True
88
- )
86
+ subprocess.Popen(command, stdout=file, stderr=file, shell=True)
89
87
  )
90
88
 
91
89
  while True:
@@ -99,7 +97,7 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
99
97
  slow_worker_text = per_command_name[slow_worker_index]
100
98
  slow_str = (
101
99
  f', task {slow_worker_text} still working, logfile'
102
- f' {output_logs[slow_worker_index].name}'
100
+ f' {output_logs[slow_worker_index]}'
103
101
  )
104
102
  else:
105
103
  slow_str = ''
@@ -116,7 +114,7 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
116
114
  )
117
115
  xpk_print(
118
116
  f'Failure is {per_command_name[failing_index]}'
119
- f' and logfile {output_logs[failing_index].name}'
117
+ f' and logfile {output_logs[failing_index]}'
120
118
  )
121
119
  for child in children:
122
120
  child.terminate()
@@ -126,18 +124,21 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
126
124
  break
127
125
 
128
126
  time.sleep(1)
127
+
128
+ for file in files:
129
+ file.close()
130
+
129
131
  return max_returncode, returncodes
130
132
 
131
133
 
132
134
  def run_command_with_updates_retry(
133
- command, task, args, verbose=True, num_retry_attempts=5, wait_seconds=10
135
+ command, task, verbose=True, num_retry_attempts=5, wait_seconds=10
134
136
  ) -> int:
135
137
  """Generic run commands function with updates and retry logic.
136
138
 
137
139
  Args:
138
140
  command: command to execute
139
141
  task: user-facing name of the task
140
- args: user provided arguments for running the command.
141
142
  verbose: shows stdout and stderr if set to true. Set to True by default.
142
143
  num_retry_attempts: number of attempts to retry the command.
143
144
  This has a default value in the function arguments.
@@ -157,23 +158,22 @@ def run_command_with_updates_retry(
157
158
  time.sleep(wait_seconds)
158
159
  i += 1
159
160
  xpk_print(f'Try {i}: {task}')
160
- return_code = run_command_with_updates(command, task, args, verbose=verbose)
161
+ return_code = run_command_with_updates(command, task, verbose=verbose)
161
162
  return return_code
162
163
 
163
164
 
164
- def run_command_with_updates(command, task, global_args, verbose=True) -> int:
165
+ def run_command_with_updates(command, task, verbose=True) -> int:
165
166
  """Generic run commands function with updates.
166
167
 
167
168
  Args:
168
169
  command: command to execute
169
170
  task: user-facing name of the task
170
- global_args: user provided arguments for running the command.
171
171
  verbose: shows stdout and stderr if set to true. Set to True by default.
172
172
 
173
173
  Returns:
174
174
  0 if successful and 1 otherwise.
175
175
  """
176
- if global_args.dry_run:
176
+ if is_dry_run():
177
177
  xpk_print(
178
178
  f'Task: `{task}` is implemented by the following command'
179
179
  ' not running since it is a dry run.'
@@ -223,7 +223,6 @@ def run_command_with_updates(command, task, global_args, verbose=True) -> int:
223
223
  def run_command_for_value(
224
224
  command,
225
225
  task,
226
- global_args,
227
226
  dry_run_return_val='0',
228
227
  print_timer=False,
229
228
  hide_error=False,
@@ -236,7 +235,6 @@ def run_command_for_value(
236
235
  Args:
237
236
  command: user provided command to run.
238
237
  task: user provided task name for running the command.
239
- global_args: user provided arguments for running the command.
240
238
  dry_run_return_val: return value of this command for dry run.
241
239
  print_timer: print out the time the command is running.
242
240
  hide_error: hide the error from the command output upon success.
@@ -246,7 +244,7 @@ def run_command_for_value(
246
244
  int: return_code, default is 0
247
245
  str: return_val, default is '0'
248
246
  """
249
- if global_args is not None and global_args.dry_run:
247
+ if is_dry_run():
250
248
  xpk_print(
251
249
  f'Task: `{task}` is implemented by the following command'
252
250
  ' not running since it is a dry run.'
@@ -302,7 +300,6 @@ def run_command_for_value(
302
300
  def run_command_with_full_controls(
303
301
  command: str,
304
302
  task: str,
305
- global_args: Namespace,
306
303
  instructions: str | None = None,
307
304
  ) -> int:
308
305
  """Run command in current shell with system out, in and error handles. Wait
@@ -311,13 +308,12 @@ def run_command_with_full_controls(
311
308
  Args:
312
309
  command: command to execute
313
310
  task: user-facing name of the task
314
- global_args: user provided arguments for running the command.
315
311
  verbose: shows stdout and stderr if set to true. Set to True by default.
316
312
 
317
313
  Returns:
318
314
  0 if successful and 1 otherwise.
319
315
  """
320
- if global_args.dry_run:
316
+ if is_dry_run():
321
317
  xpk_print(
322
318
  f'Task: `{task}` is implemented by the following command'
323
319
  ' not running since it is a dry run.'
@@ -349,8 +345,8 @@ def run_command_with_full_controls(
349
345
  return return_code
350
346
 
351
347
 
352
- def run_kubectl_apply(yml_string: str, task: str, args: Namespace) -> int:
348
+ def run_kubectl_apply(yml_string: str, task: str) -> int:
353
349
  tmp = write_tmp_file(yml_string)
354
- command = f'kubectl apply -f {str(tmp.file.name)}'
355
- err_code = run_command_with_updates(command, task, args)
350
+ command = f'kubectl apply -f {str(tmp)}'
351
+ err_code = run_command_with_updates(command, task)
356
352
  return err_code
xpk/core/config.py CHANGED
@@ -22,7 +22,7 @@ from ..utils import file
22
22
  from ..utils.console import xpk_print
23
23
 
24
24
  # This is the version for XPK PyPI package
25
- __version__ = 'v0.12.0'
25
+ __version__ = 'v0.14.0'
26
26
  XPK_CURRENT_VERSION = __version__
27
27
  XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
28
28
 
xpk/core/docker_image.py CHANGED
@@ -21,6 +21,7 @@ import string
21
21
 
22
22
  from ..utils.console import xpk_exit, xpk_print
23
23
  from ..utils.file import write_tmp_file
24
+ from ..utils.execution_context import is_dry_run
24
25
  from .commands import run_command_with_updates
25
26
 
26
27
  DEFAULT_DOCKER_IMAGE = 'python:3.10'
@@ -48,7 +49,7 @@ def validate_docker_image(docker_image, args) -> int:
48
49
  f'gcloud container images describe {docker_image} --project {project}'
49
50
  )
50
51
  return_code = run_command_with_updates(
51
- command, 'Validate Docker Image', args, verbose=False
52
+ command, 'Validate Docker Image', verbose=False
52
53
  )
53
54
  if return_code != 0:
54
55
  xpk_print(
@@ -75,7 +76,9 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
75
76
  """
76
77
 
77
78
  # Pick a name for the docker image.
78
- docker_image_prefix = os.getenv('USER', 'unknown')
79
+ docker_image_prefix = (
80
+ 'dry-run' if is_dry_run() else os.getenv('USER', 'unknown')
81
+ )
79
82
  docker_name = f'{docker_image_prefix}-runner'
80
83
 
81
84
  script_dir_dockerfile = """FROM {base_docker_image}
@@ -94,14 +97,13 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
94
97
  )
95
98
  tmp = write_tmp_file(docker_file)
96
99
  docker_build_command = (
97
- f'docker buildx build --platform={PLATFORM} -f {str(tmp.file.name)} -t'
100
+ f'docker buildx build --platform={PLATFORM} -f {str(tmp)} -t'
98
101
  f' {docker_name} {args.script_dir}'
99
102
  )
100
103
  xpk_print(f'Building {args.script_dir} into docker image.')
101
104
  return_code = run_command_with_updates(
102
105
  docker_build_command,
103
106
  'Building script_dir into docker image',
104
- args,
105
107
  verbose=verbose,
106
108
  )
107
109
  if return_code != 0:
@@ -114,10 +116,16 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
114
116
 
115
117
  # Pick a randomly generated `tag_length` character docker tag.
116
118
  tag_length = 4
117
- tag_random_prefix = ''.join(
118
- random.choices(string.ascii_lowercase, k=tag_length)
119
+ tag_random_prefix = (
120
+ 'prefix'
121
+ if is_dry_run()
122
+ else ''.join(random.choices(string.ascii_lowercase, k=tag_length))
123
+ )
124
+ tag_datetime = (
125
+ 'current'
126
+ if is_dry_run()
127
+ else datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
119
128
  )
120
- tag_datetime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
121
129
  tag_name = f'{tag_random_prefix}-{tag_datetime}'
122
130
  cloud_docker_image = f'gcr.io/{args.project}/{docker_name}:{tag_name}'
123
131
  xpk_print(f'Adding Docker Image: {cloud_docker_image} to {args.project}')
@@ -125,7 +133,7 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
125
133
  # Tag the docker image.
126
134
  tag_docker_image_command = f'docker tag {docker_name} {cloud_docker_image}'
127
135
  return_code = run_command_with_updates(
128
- tag_docker_image_command, 'Tag Docker Image', args, verbose=verbose
136
+ tag_docker_image_command, 'Tag Docker Image', verbose=verbose
129
137
  )
130
138
  if return_code != 0:
131
139
  xpk_print(
@@ -138,7 +146,7 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
138
146
  # Upload image to Artifact Registry.
139
147
  upload_docker_image_command = f'docker push {cloud_docker_image}'
140
148
  return_code = run_command_with_updates(
141
- upload_docker_image_command, 'Upload Docker Image', args, verbose=verbose
149
+ upload_docker_image_command, 'Upload Docker Image', verbose=verbose
142
150
  )
143
151
  if return_code != 0:
144
152
  xpk_print(
@@ -20,6 +20,7 @@ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
20
20
  from .cluster import setup_k8s_env
21
21
  from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, PARALLELSTORE_TYPE, GCE_PD_TYPE, LUSTRE_TYPE, Storage, get_storages_to_mount
22
22
  from .system_characteristics import AcceleratorType, SystemCharacteristics
23
+ from ..utils.execution_context import is_dry_run
23
24
 
24
25
 
25
26
  def get_main_container_resources(
@@ -272,8 +273,10 @@ def get_volumes(args, system: SystemCharacteristics) -> str:
272
273
  - name: shared-data
273
274
  """
274
275
 
275
- storages: list[Storage] = get_storages_to_mount(
276
- setup_k8s_env(args), args.storage
276
+ storages: list[Storage] = (
277
+ []
278
+ if is_dry_run()
279
+ else get_storages_to_mount(setup_k8s_env(args), args.storage)
277
280
  )
278
281
  for storage in storages:
279
282
  if storage.type in {
@@ -325,8 +328,10 @@ def get_volume_mounts(args, system: SystemCharacteristics) -> str:
325
328
  elif system.accelerator_type == AcceleratorType['GPU']:
326
329
  volume_mount_yaml = ''
327
330
 
328
- storages: list[Storage] = get_storages_to_mount(
329
- setup_k8s_env(args), args.storage
331
+ storages: list[Storage] = (
332
+ []
333
+ if is_dry_run()
334
+ else get_storages_to_mount(setup_k8s_env(args), args.storage)
330
335
  )
331
336
  for storage in storages:
332
337
  if storage.type in {
@@ -18,8 +18,9 @@ import subprocess
18
18
  import sys
19
19
  from dataclasses import dataclass
20
20
 
21
- from ..utils.console import xpk_print
21
+ from ..utils.console import xpk_print, xpk_exit
22
22
  from .commands import run_command_for_value
23
+ from functools import lru_cache
23
24
 
24
25
 
25
26
  def get_project():
@@ -85,9 +86,33 @@ def zone_to_region(zone: str) -> str:
85
86
  The region name.
86
87
  """
87
88
  zone_terms = zone.split('-')
89
+ if len(zone_terms) != 2 and len(zone_terms) != 3:
90
+ raise ValueError(f'Invalid zone name: {zone}')
88
91
  return zone_terms[0] + '-' + zone_terms[1]
89
92
 
90
93
 
94
+ @lru_cache()
95
+ def get_cluster_location(project: str, name: str, zone: str) -> str:
96
+ """Helper function to resolve location for a given cluster"""
97
+ return_code, result = run_command_for_value(
98
+ command=(
99
+ 'gcloud container clusters list '
100
+ f'--project={project} '
101
+ f'--filter=name={name} '
102
+ '--format="value(location)"'
103
+ ),
104
+ task='Find cluster region or zone',
105
+ dry_run_return_val=zone_to_region(zone),
106
+ )
107
+
108
+ if return_code != 0:
109
+ xpk_print('Error: Unable to determine cluster region or zone')
110
+ xpk_exit(return_code)
111
+
112
+ regions = result.strip().splitlines()
113
+ return zone if zone in regions else zone_to_region(zone)
114
+
115
+
91
116
  @dataclass
92
117
  class GkeServerConfig:
93
118
  """Stores the valid gke versions based on gcloud recommendations."""
@@ -139,7 +164,6 @@ def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
139
164
  return_code, cmd_output = run_command_for_value(
140
165
  command,
141
166
  command_description,
142
- args,
143
167
  hide_error=True,
144
168
  )
145
169
  if return_code != 0:
@@ -0,0 +1,96 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import pytest
18
+ from .gcloud_context import get_cluster_location, zone_to_region
19
+
20
+
21
+ def test_zone_to_region_raises_when_zone_is_invalid():
22
+ with pytest.raises(ValueError):
23
+ zone_to_region("us")
24
+
25
+
26
+ def test_zone_to_region_returns_region_when_region_given():
27
+ assert zone_to_region("us-central1") == "us-central1"
28
+
29
+
30
+ def test_zone_to_region_returns_region_when_zone_is_valid():
31
+ assert zone_to_region("us-central1-a") == "us-central1"
32
+
33
+
34
+ def test_get_cluster_location_returns_cluster_region_when_cluster_is_regional(
35
+ mocker,
36
+ ):
37
+ mocker.patch(
38
+ "xpk.core.gcloud_context.run_command_for_value",
39
+ return_value=(0, "us-central1"),
40
+ )
41
+
42
+ result = get_cluster_location(
43
+ project="project1", name="name1", zone="us-central1-a"
44
+ )
45
+
46
+ assert result == "us-central1"
47
+
48
+
49
+ def test_get_cluster_location_returns_cluster_zone_when_both_regional_and_zonal_clusters_exist(
50
+ mocker,
51
+ ):
52
+ mocker.patch(
53
+ "xpk.core.gcloud_context.run_command_for_value",
54
+ return_value=(0, "us-central1\nus-central1-a"),
55
+ )
56
+
57
+ result = get_cluster_location(
58
+ project="project2", name="name2", zone="us-central1-a"
59
+ )
60
+
61
+ assert result == "us-central1-a"
62
+
63
+
64
+ def test_get_cluster_location_returns_given_zone_converted_to_region_when_cluster_is_not_found(
65
+ mocker,
66
+ ):
67
+ mocker.patch(
68
+ "xpk.core.gcloud_context.run_command_for_value", return_value=(0, "")
69
+ )
70
+
71
+ result = get_cluster_location(
72
+ project="project3", name="name3", zone="us-central1-a"
73
+ )
74
+
75
+ assert result == "us-central1"
76
+
77
+
78
+ def test_get_cluster_location_caches_previous_command_result(mocker):
79
+ mock = mocker.patch(
80
+ "xpk.core.gcloud_context.run_command_for_value", return_value=(0, "")
81
+ )
82
+
83
+ get_cluster_location(project="project4", name="name4", zone="us-central1-a")
84
+
85
+ assert mock.call_count == 1
86
+
87
+
88
+ def test_get_cluster_location_invokes_command_for_different_input_args(mocker):
89
+ mock = mocker.patch(
90
+ "xpk.core.gcloud_context.run_command_for_value", return_value=(0, "")
91
+ )
92
+
93
+ get_cluster_location(project="project5", name="name5", zone="us-central1-a")
94
+ get_cluster_location(project="project6", name="name6", zone="us-central1-a")
95
+
96
+ assert mock.call_count == 2
@@ -27,9 +27,6 @@ blueprint_file_name = 'xpk_blueprint.yaml'
27
27
  deployment_module = '/out/xpk-deployment'
28
28
  a3_utils_dir_name = 'a3-mega-xpk'
29
29
  config_map_repo_path = 'src/xpk/blueprints/a3-mega-xpk/config-map.yaml.tftpl'
30
- kueue_config_repo_path = (
31
- 'src/xpk/blueprints/a3-mega-xpk/kueue-xpk-configuration.yaml.tftpl'
32
- )
33
30
 
34
31
 
35
32
  class GclusterManager:
xpk/core/jobset.py CHANGED
@@ -18,7 +18,7 @@ import math
18
18
 
19
19
  from ..utils.console import xpk_exit, xpk_print
20
20
  from ..utils.file import write_tmp_file
21
- from ..core.kueue import (
21
+ from ..core.kueue_manager import (
22
22
  MEMORY_SIZE_PER_VM,
23
23
  MIN_MEMORY_LIMIT_SIZE,
24
24
  )
@@ -110,19 +110,16 @@ spec:
110
110
  """
111
111
 
112
112
 
113
- def update_jobset_resources_if_necessary(args):
113
+ def update_jobset_resources_if_necessary():
114
114
  """Update the jobset manifest to increase the resources for the jobset controller manager.
115
115
 
116
- Args:
117
- args: user provided arguments for running the command.
118
-
119
116
  Returns:
120
117
  0 if successful and 1 otherwise.
121
118
  """
122
119
  # Get total number of nodes
123
120
  cmd_total_node_num = 'kubectl get node --no-headers | wc -l'
124
121
  return_code, out = run_command_for_value(
125
- cmd_total_node_num, 'Count total nodes', args
122
+ cmd_total_node_num, 'Count total nodes'
126
123
  )
127
124
  if return_code != 0:
128
125
  xpk_exit(1)
@@ -134,10 +131,10 @@ def update_jobset_resources_if_necessary(args):
134
131
  memory_limit_size=new_memory_limit,
135
132
  )
136
133
  tmp = write_tmp_file(yml_string)
137
- command = f'kubectl apply -f {str(tmp.file.name)}'
134
+ command = f'kubectl apply -f {str(tmp)}'
138
135
 
139
136
  task = 'Updating jobset Controller Manager resources'
140
- return_code = run_command_with_updates_retry(command, task, args)
137
+ return_code = run_command_with_updates_retry(command, task)
141
138
  if return_code != 0:
142
139
  xpk_print(f'{task} returned ERROR {return_code}')
143
140
  return return_code
xpk/core/kjob.py CHANGED
@@ -23,6 +23,7 @@ from kubernetes.client import ApiClient
23
23
  from kubernetes.client.rest import ApiException
24
24
 
25
25
  from ..utils import templates
26
+ from ..utils.execution_context import is_dry_run
26
27
  from ..utils.console import xpk_exit, xpk_print
27
28
  from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
28
29
  from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
@@ -166,8 +167,8 @@ Kueue_TAS_annotation = "kueue.x-k8s.io/podset-preferred-topology=cloud.google.co
166
167
  default_interface_annotation = "networking.gke.io/default-interface=eth0"
167
168
 
168
169
 
169
- def get_a4_pod_template_annotations(args) -> tuple[str, str]:
170
- sub_networks = get_cluster_subnetworks(args)
170
+ def get_a4_pod_template_annotations() -> tuple[str, str]:
171
+ sub_networks = get_cluster_subnetworks()
171
172
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
172
173
  sub_networks
173
174
  )
@@ -178,8 +179,8 @@ def get_a4_pod_template_annotations(args) -> tuple[str, str]:
178
179
  )
179
180
 
180
181
 
181
- def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
182
- sub_networks = get_cluster_subnetworks(args)
182
+ def get_a3ultra_pod_template_annotations() -> tuple[str, str]:
183
+ sub_networks = get_cluster_subnetworks()
183
184
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
184
185
  sub_networks
185
186
  )
@@ -190,11 +191,9 @@ def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
190
191
  )
191
192
 
192
193
 
193
- def get_a3mega_pod_template_annotations(
194
- args: Namespace,
195
- ) -> tuple[str, str, str]:
194
+ def get_a3mega_pod_template_annotations() -> tuple[str, str, str]:
196
195
  """Adds or updates annotations in the Pod template."""
197
- sub_networks = get_cluster_subnetworks(args)
196
+ sub_networks = get_cluster_subnetworks()
198
197
  tcpxo_deamon_key, tcpxo_deamon_paths = get_tcpxo_deamon_entry()
199
198
  interfaces_key, interfaces_value = tcpxo_decorator.get_interfaces_entry(
200
199
  sub_networks
@@ -204,16 +203,14 @@ def get_a3mega_pod_template_annotations(
204
203
  return tcpxo, interfaces, default_interface_annotation
205
204
 
206
205
 
207
- def verify_kjob_installed(args: Namespace) -> int:
206
+ def verify_kjob_installed() -> int:
208
207
  """Check if kjob is installed. If not provide user with proper communicate and exit.
209
- Args:
210
- args - user provided arguments.
211
208
  Returns:
212
209
  error code > if kjob not installed, otherwise 0
213
210
  """
214
211
  command = "kubectl-kjob help"
215
212
  task = "Verify kjob installation "
216
- verify_kjob_installed_code, _ = run_command_for_value(command, task, args)
213
+ verify_kjob_installed_code, _ = run_command_for_value(command, task)
217
214
 
218
215
  if verify_kjob_installed_code == 0:
219
216
  xpk_print("kjob found")
@@ -245,9 +242,7 @@ def get_pod_template_interactive_command() -> str:
245
242
  return pod_command
246
243
 
247
244
 
248
- def create_app_profile_instance(
249
- args: Namespace, volume_bundles: list[str]
250
- ) -> int:
245
+ def create_app_profile_instance(volume_bundles: list[str]) -> int:
251
246
  """Create new AppProfile instance on cluster with default settings.
252
247
 
253
248
  Args:
@@ -263,7 +258,6 @@ def create_app_profile_instance(
263
258
  volume_bundles=volume_bundles,
264
259
  ),
265
260
  task="Creating AppProfile",
266
- args=args,
267
261
  )
268
262
 
269
263
 
@@ -331,15 +325,12 @@ def create_job_template_instance(
331
325
  return run_kubectl_apply(
332
326
  yml_string,
333
327
  task="Creating JobTemplate",
334
- args=args,
335
328
  )
336
329
 
337
330
 
338
- def create_pod_template_instance(args: Namespace, service_account: str) -> int:
331
+ def create_pod_template_instance(service_account: str) -> int:
339
332
  """Create new PodTemplate instance on cluster with default settings.
340
333
 
341
- Args:
342
- args - user provided arguments
343
334
  Returns:
344
335
  exit_code > 0 if creating PodTemplate fails, 0 otherwise
345
336
  """
@@ -361,15 +352,16 @@ def create_pod_template_instance(args: Namespace, service_account: str) -> int:
361
352
  service_account=service_account,
362
353
  ),
363
354
  task="Creating PodTemplate",
364
- args=args,
365
355
  )
366
356
 
367
357
 
368
358
  def prepare_kjob(args: Namespace) -> int:
369
359
  system = get_cluster_system_characteristics(args)
370
360
 
371
- k8s_api_client = setup_k8s_env(args)
372
- storages = get_auto_mount_storages(k8s_api_client)
361
+ storages = []
362
+ if not is_dry_run():
363
+ k8s_api_client = setup_k8s_env(args)
364
+ storages = get_auto_mount_storages(k8s_api_client)
373
365
 
374
366
  service_account = ""
375
367
  if len(storages) > 0:
@@ -378,29 +370,27 @@ def prepare_kjob(args: Namespace) -> int:
378
370
  job_err_code = create_job_template_instance(args, system, service_account)
379
371
  if job_err_code > 0:
380
372
  return job_err_code
381
- pod_err_code = create_pod_template_instance(args, service_account)
373
+ pod_err_code = create_pod_template_instance(service_account)
382
374
  if pod_err_code > 0:
383
375
  return pod_err_code
384
376
 
385
377
  volume_bundles = [item.name for item in storages]
386
378
 
387
- return create_app_profile_instance(args, volume_bundles)
379
+ return create_app_profile_instance(volume_bundles)
388
380
 
389
381
 
390
- def apply_kjob_crds(args: Namespace) -> int:
382
+ def apply_kjob_crds() -> int:
391
383
  """Apply kjob CRDs on cluster.
392
384
 
393
385
  This function install kjob CRDs files from kjobctl printcrds.
394
386
  It creates all neccessary kjob CRDs.
395
387
 
396
- Args:
397
- args - user provided arguments
398
388
  Returns:
399
389
  None
400
390
  """
401
391
  command = "kubectl kjob printcrds | kubectl apply --server-side -f -"
402
392
  task = "Create kjob CRDs on cluster"
403
- return_code = run_command_with_updates(command, task, args)
393
+ return_code = run_command_with_updates(command, task)
404
394
  if return_code != 0:
405
395
  xpk_print(f"{task} returned ERROR {return_code}")
406
396
  return return_code