xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
xpk/core/config.py CHANGED
@@ -17,12 +17,32 @@ limitations under the License.
17
17
  import os
18
18
 
19
19
  import ruamel.yaml
20
-
20
+ from abc import ABC, abstractmethod
21
21
  from ..utils import file
22
22
  from ..utils.console import xpk_print
23
+ from setuptools_scm import get_version as setuptools_get_version
24
+ from importlib.metadata import version, PackageNotFoundError
25
+
26
+
27
+ def _get_version() -> str:
28
+ xpk_version_override = os.getenv('XPK_VERSION_OVERRIDE', '')
29
+ if xpk_version_override != '':
30
+ return xpk_version_override
31
+
32
+ try:
33
+ return setuptools_get_version()
34
+ except LookupError:
35
+ pass
36
+
37
+ try:
38
+ return version('xpk')
39
+ except PackageNotFoundError:
40
+ pass
23
41
 
24
- # This is the version for XPK PyPI package
25
- __version__ = 'v0.14.4'
42
+ raise LookupError('unable to determine version number')
43
+
44
+
45
+ __version__ = _get_version()
26
46
  XPK_CURRENT_VERSION = __version__
27
47
  XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
28
48
 
@@ -30,6 +50,8 @@ CONFIGS_KEY = 'configs'
30
50
  CFG_BUCKET_KEY = 'cluster-state-gcs-bucket'
31
51
  CLUSTER_NAME_KEY = 'cluster-name'
32
52
  PROJECT_KEY = 'project-id'
53
+ CLIENT_ID_KEY = 'client-id'
54
+ SEND_TELEMETRY_KEY = 'send-telemetry'
33
55
  ZONE_KEY = 'zone'
34
56
  KJOB_BATCH_IMAGE = 'batch-image'
35
57
  KJOB_BATCH_WORKING_DIRECTORY = 'batch-working-directory'
@@ -39,12 +61,13 @@ KJOB_SHELL_WORKING_DIRECTORY = 'shell-working-directory'
39
61
  CONFIGS_KEY = 'configs'
40
62
  GKE_ENDPOINT_KEY = 'gke-endpoint'
41
63
  DEPENDENCIES_KEY = 'deps-verified-version'
42
- XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
43
64
 
44
65
  DEFAULT_KEYS = [
45
66
  CFG_BUCKET_KEY,
46
67
  CLUSTER_NAME_KEY,
47
68
  PROJECT_KEY,
69
+ CLIENT_ID_KEY,
70
+ SEND_TELEMETRY_KEY,
48
71
  ZONE_KEY,
49
72
  GKE_ENDPOINT_KEY,
50
73
  DEPENDENCIES_KEY,
@@ -60,8 +83,28 @@ VERTEX_TENSORBOARD_FEATURE_FLAG = XPK_CURRENT_VERSION >= '0.4.0'
60
83
  yaml = ruamel.yaml.YAML()
61
84
 
62
85
 
63
- class XpkConfig:
64
- """XpkConfig is a class for setting and getting values from .yaml config file."""
86
+ class Config(ABC):
87
+ """Stores and manipulates XPK configuration."""
88
+
89
+ @abstractmethod
90
+ def set(self, key: str, value: str | None) -> None:
91
+ """Sets the config value"""
92
+ pass
93
+
94
+ @abstractmethod
95
+ def get(self, key: str) -> str | None:
96
+ """Reads the config value"""
97
+ pass
98
+
99
+ @abstractmethod
100
+ def get_all(
101
+ self,
102
+ ) -> dict[str, str] | None:
103
+ pass
104
+
105
+
106
+ class FileSystemConfig(Config):
107
+ """XPK Configuration manipulation class leveraging the file system."""
65
108
 
66
109
  def __init__(self, custom_config_file: str = XPK_CONFIG_FILE) -> None:
67
110
  self._config = custom_config_file
@@ -82,7 +125,7 @@ class XpkConfig:
82
125
  with open(self._config, encoding='utf-8', mode='w') as stream:
83
126
  yaml.dump(config_yaml, stream)
84
127
 
85
- def set(self, key: str, value: str) -> None:
128
+ def set(self, key: str, value: str | None) -> None:
86
129
  if key not in self._allowed_keys:
87
130
  xpk_print(f'Key {key} is not an allowed xpk config key.')
88
131
  return
@@ -114,3 +157,41 @@ class XpkConfig:
114
157
  return None
115
158
  val: dict[str, str] = config_yaml[CONFIGS_KEY]
116
159
  return val
160
+
161
+
162
+ class InMemoryXpkConfig(Config):
163
+ """XPK Configuration manipulation class in memory."""
164
+
165
+ def __init__(self) -> None:
166
+ self._config: dict[str, str] = {}
167
+ self._allowed_keys = DEFAULT_KEYS
168
+
169
+ def set(self, key: str, value: str | None) -> None:
170
+ if key not in self._allowed_keys:
171
+ return
172
+ if value is None:
173
+ self._config.pop(key, None)
174
+ else:
175
+ self._config[key] = value
176
+
177
+ def get(self, key: str) -> str | None:
178
+ if key not in self._allowed_keys:
179
+ return None
180
+ return self._config.get(key)
181
+
182
+ def get_all(
183
+ self,
184
+ ) -> dict[str, str] | None:
185
+ return None if len(self._config) <= 0 else self._config
186
+
187
+
188
+ _xpk_config: Config = InMemoryXpkConfig()
189
+
190
+
191
+ def set_config(config: Config):
192
+ global _xpk_config
193
+ _xpk_config = config
194
+
195
+
196
+ def get_config() -> Config:
197
+ return _xpk_config
xpk/core/config_test.py CHANGED
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from xpk.core.config import XpkConfig, CFG_BUCKET_KEY, CLUSTER_NAME_KEY, PROJECT_KEY, ZONE_KEY
17
+ from xpk.core.config import FileSystemConfig, InMemoryXpkConfig, CFG_BUCKET_KEY, CLUSTER_NAME_KEY, PROJECT_KEY, ZONE_KEY, _get_version
18
+ from unittest.mock import patch
19
+ from importlib.metadata import PackageNotFoundError
18
20
 
19
21
  import os
20
22
  import pytest
@@ -31,15 +33,60 @@ def _():
31
33
  os.remove(config_tmp_path)
32
34
 
33
35
 
34
- def test_config(_):
35
- cfg = XpkConfig(config_tmp_path)
36
+ @patch('os.getenv', return_value='10.0.0')
37
+ def test_get_version_returns_overriden_value_when_it_is_overriden(_):
38
+ assert _get_version() == '10.0.0'
39
+
40
+
41
+ @patch('os.getenv', return_value='')
42
+ @patch('xpk.core.config.setuptools_get_version', return_value='10.0.0')
43
+ def test_get_version_returns_value_from_setuptools_scm_when_there_is_no_override(
44
+ *_,
45
+ ):
46
+ assert _get_version() == '10.0.0'
47
+
48
+
49
+ @patch('os.getenv', return_value='')
50
+ @patch(
51
+ 'xpk.core.config.setuptools_get_version',
52
+ side_effect=LookupError('unable to find git version'),
53
+ )
54
+ @patch('xpk.core.config.version', return_value='10.0.0')
55
+ def test_get_version_returns_value_from_pip_when_there_is_no_setuptools_could_be_resolved(
56
+ *_,
57
+ ):
58
+ assert _get_version() == '10.0.0'
59
+
60
+
61
+ @patch('os.getenv', return_value='')
62
+ @patch(
63
+ 'xpk.core.config.setuptools_get_version',
64
+ side_effect=LookupError('unable to find git version'),
65
+ )
66
+ @patch(
67
+ 'xpk.core.config.version',
68
+ side_effect=PackageNotFoundError('unable to locate package'),
69
+ )
70
+ def test_get_version_returns_none_when_no_version_could_be_resolved(*_):
71
+ with pytest.raises(LookupError):
72
+ _get_version()
73
+
74
+
75
+ @pytest.mark.parametrize(
76
+ argnames='cfg',
77
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
78
+ )
79
+ def test_config(_, cfg):
36
80
  cfg.set('project-id', 'foo')
37
81
  project_id = cfg.get('project-id')
38
82
  assert project_id == 'foo'
39
83
 
40
84
 
41
- def test_config_get_all(_):
42
- cfg = XpkConfig(config_tmp_path)
85
+ @pytest.mark.parametrize(
86
+ argnames='cfg',
87
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
88
+ )
89
+ def test_config_get_all(_, cfg):
43
90
  cfg.set(PROJECT_KEY, 'foo')
44
91
  cfg.set(CLUSTER_NAME_KEY, 'bar')
45
92
  cfg.set(ZONE_KEY, 'europe-west1-a')
@@ -52,20 +99,29 @@ def test_config_get_all(_):
52
99
  assert cfg_all[CFG_BUCKET_KEY] == 'cfg-bucket'
53
100
 
54
101
 
55
- def test_config_get_empty(_):
56
- cfg = XpkConfig(config_tmp_path)
102
+ @pytest.mark.parametrize(
103
+ argnames='cfg',
104
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
105
+ )
106
+ def test_config_get_empty(_, cfg):
57
107
  val = cfg.get(PROJECT_KEY)
58
108
  assert val is None
59
109
 
60
110
 
61
- def test_config_get_all_empty(_):
62
- cfg = XpkConfig(config_tmp_path)
111
+ @pytest.mark.parametrize(
112
+ argnames='cfg',
113
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
114
+ )
115
+ def test_config_get_all_empty(_, cfg):
63
116
  val = cfg.get_all()
64
117
  assert not val
65
118
 
66
119
 
67
- def test_config_set_incorrect(_):
68
- cfg = XpkConfig(config_tmp_path)
120
+ @pytest.mark.parametrize(
121
+ argnames='cfg',
122
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
123
+ )
124
+ def test_config_set_incorrect(cfg, _):
69
125
  cfg.set('foo', 'bar')
70
126
  cfg_all = cfg.get_all()
71
127
  assert not cfg_all
@@ -182,7 +182,9 @@ def get_user_workload_container(args, system: SystemCharacteristics):
182
182
  debugging_dashboard_id: id of the GKE dashboard
183
183
  """
184
184
 
185
- setup_docker_image_code, docker_image = setup_docker_image(args)
185
+ setup_docker_image_code, docker_image = setup_docker_image(
186
+ args, system.docker_platform
187
+ )
186
188
  if setup_docker_image_code != 0:
187
189
  xpk_exit(setup_docker_image_code)
188
190
 
xpk/core/docker_image.py CHANGED
@@ -19,6 +19,7 @@ import os
19
19
  import random
20
20
  import string
21
21
 
22
+ from .system_characteristics import DockerPlatform
22
23
  from ..utils.console import xpk_exit, xpk_print
23
24
  from ..utils.file import write_tmp_file
24
25
  from ..utils.execution_context import is_dry_run
@@ -26,7 +27,6 @@ from .commands import run_command_with_updates
26
27
 
27
28
  DEFAULT_DOCKER_IMAGE = 'python:3.10'
28
29
  DEFAULT_SCRIPT_DIR = os.getcwd()
29
- PLATFORM = 'linux/amd64'
30
30
 
31
31
 
32
32
  def validate_docker_image(docker_image, args) -> int:
@@ -63,7 +63,9 @@ def validate_docker_image(docker_image, args) -> int:
63
63
  return 0
64
64
 
65
65
 
66
- def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
66
+ def build_docker_image_from_base_image(
67
+ args, docker_platform: DockerPlatform, verbose=True
68
+ ) -> tuple[int, str]:
67
69
  """Adds script dir to the base docker image and uploads the image.
68
70
 
69
71
  Args:
@@ -97,8 +99,8 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
97
99
  )
98
100
  tmp = write_tmp_file(docker_file)
99
101
  docker_build_command = (
100
- f'docker buildx build --platform={PLATFORM} -f {str(tmp)} -t'
101
- f' {docker_name} {args.script_dir}'
102
+ f'docker buildx build --platform={docker_platform.value} -f'
103
+ f' {str(tmp)} -t {docker_name} {args.script_dir}'
102
104
  )
103
105
  xpk_print(f'Building {args.script_dir} into docker image.')
104
106
  return_code = run_command_with_updates(
@@ -158,7 +160,9 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
158
160
  return return_code, cloud_docker_image
159
161
 
160
162
 
161
- def setup_docker_image(args) -> tuple[int, str]:
163
+ def setup_docker_image(
164
+ args, docker_platform: DockerPlatform
165
+ ) -> tuple[int, str]:
162
166
  """Does steps to verify docker args, check image, and build image (if asked).
163
167
 
164
168
  Args:
@@ -177,7 +181,7 @@ def setup_docker_image(args) -> tuple[int, str]:
177
181
  if validate_docker_image_code != 0:
178
182
  xpk_exit(validate_docker_image_code)
179
183
  build_docker_image_code, docker_image = build_docker_image_from_base_image(
180
- args
184
+ args, docker_platform
181
185
  )
182
186
  if build_docker_image_code != 0:
183
187
  xpk_exit(build_docker_image_code)
@@ -16,7 +16,6 @@ limitations under the License.
16
16
 
17
17
  import os
18
18
  import re
19
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
20
19
  from .cluster import setup_k8s_env
21
20
  from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, PARALLELSTORE_TYPE, GCE_PD_TYPE, LUSTRE_TYPE, Storage, get_storages_to_mount
22
21
  from .system_characteristics import AcceleratorType, SystemCharacteristics
@@ -109,14 +108,6 @@ def get_gpu_env(args, system) -> str:
109
108
  value: "{args.command}"
110
109
  {custom_envs}"""
111
110
 
112
- gpu_direct_name = 'fastrak'
113
- if args.device_type == H100_DEVICE_TYPE:
114
- gpu_direct_name = 'tcpx'
115
- elif args.device_type == H100_MEGA_DEVICE_TYPE:
116
- gpu_direct_name = 'tcpxo'
117
- elif args.device_type == H200_DEVICE_TYPE:
118
- gpu_direct_name = 'rdma'
119
-
120
111
  gpu_env_dic = {
121
112
  'JAX_COORDINATOR_PORT': '6002',
122
113
  'JAX_COORDINATOR_ADDRESS': (
@@ -129,7 +120,7 @@ def get_gpu_env(args, system) -> str:
129
120
  return gpu_env_yaml.format(
130
121
  args=args,
131
122
  chips_per_vm=system.chips_per_vm,
132
- gpu_direct_name=gpu_direct_name,
123
+ gpu_direct_name=system.gpu_config.gpu_direct_name,
133
124
  custom_envs=format_env_dict(args.env, system),
134
125
  )
135
126
 
@@ -19,6 +19,7 @@ import sys
19
19
  from dataclasses import dataclass
20
20
 
21
21
  from ..utils.console import xpk_print, xpk_exit
22
+ from ..utils.versions import ReleaseChannel
22
23
  from .commands import run_command_for_value
23
24
  from functools import lru_cache
24
25
 
@@ -117,15 +118,18 @@ def get_cluster_location(project: str, name: str, zone: str) -> str:
117
118
  class GkeServerConfig:
118
119
  """Stores the valid gke versions based on gcloud recommendations."""
119
120
 
120
- default_rapid_gke_version: str
121
+ default_gke_version: str
121
122
  valid_versions: set[str]
122
123
 
123
124
 
124
- def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
125
+ def get_gke_server_config(
126
+ args, release_channel: ReleaseChannel
127
+ ) -> tuple[int, GkeServerConfig | None]:
125
128
  """Determine the GKE versions supported by gcloud currently.
126
129
 
127
130
  Args:
128
131
  args: user provided arguments for running the command.
132
+ release_channel: the release channel to use.
129
133
 
130
134
  Returns:
131
135
  Tuple of
@@ -136,22 +140,24 @@ def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
136
140
  'gcloud container get-server-config'
137
141
  f' --project={args.project} --region={zone_to_region(args.zone)}'
138
142
  )
139
- default_rapid_gke_version_cmd = (
143
+ default_gke_version_cmd = (
140
144
  base_command
141
- + ' --flatten="channels" --filter="channels.channel=RAPID"'
145
+ + ' --flatten="channels"'
146
+ f' --filter="channels.channel={release_channel.value}"'
142
147
  ' --format="value(channels.defaultVersion)"'
143
148
  )
144
149
  valid_versions_cmd = (
145
150
  base_command
146
- + ' --flatten="channels" --filter="channels.channel=RAPID"'
151
+ + ' --flatten="channels"'
152
+ f' --filter="channels.channel={release_channel.value}"'
147
153
  ' --format="value(channels.validVersions)"'
148
154
  )
149
155
  base_command_description = 'Determine server supported GKE versions for '
150
156
 
151
157
  server_config_commands_and_descriptions = [
152
158
  (
153
- default_rapid_gke_version_cmd,
154
- base_command_description + 'default rapid gke version',
159
+ default_gke_version_cmd,
160
+ base_command_description + 'default gke version',
155
161
  ),
156
162
  (
157
163
  valid_versions_cmd,
@@ -172,8 +178,8 @@ def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
172
178
  command_outputs.append(cmd_output)
173
179
 
174
180
  return 0, GkeServerConfig(
175
- default_rapid_gke_version=command_outputs[0].strip(),
176
- valid_versions=set(command_outputs[1].split(';')),
181
+ default_gke_version=command_outputs[0].strip(),
182
+ valid_versions=set([s.strip() for s in command_outputs[1].split(';')]),
177
183
  )
178
184
 
179
185
 
@@ -196,7 +202,7 @@ def get_gke_control_plane_version(
196
202
  if args.gke_version is not None:
197
203
  master_gke_version = args.gke_version
198
204
  else:
199
- master_gke_version = gke_server_config.default_rapid_gke_version
205
+ master_gke_version = gke_server_config.default_gke_version
200
206
 
201
207
  is_valid_version = master_gke_version in gke_server_config.valid_versions
202
208
 
@@ -204,7 +210,7 @@ def get_gke_control_plane_version(
204
210
  xpk_print(
205
211
  f'Planned GKE Version: {master_gke_version}\n Valid Versions:'
206
212
  f'\n{gke_server_config.valid_versions}\nRecommended / Default GKE'
207
- f' Version: {gke_server_config.default_rapid_gke_version}'
213
+ f' Version: {gke_server_config.default_gke_version}'
208
214
  )
209
215
  xpk_print(
210
216
  f'Error: Planned GKE Version {master_gke_version} is not valid.'
@@ -213,7 +219,7 @@ def get_gke_control_plane_version(
213
219
  xpk_print(
214
220
  'Please select a gke version from the above list using --gke-version=x'
215
221
  ' argument or rely on the default gke version:'
216
- f' {gke_server_config.default_rapid_gke_version}'
222
+ f' {gke_server_config.default_gke_version}'
217
223
  )
218
224
  return 1, None
219
225
 
@@ -15,7 +15,20 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import pytest
18
- from .gcloud_context import get_cluster_location, zone_to_region
18
+ from unittest.mock import MagicMock
19
+ from .gcloud_context import (
20
+ get_cluster_location,
21
+ get_gke_control_plane_version,
22
+ get_gke_server_config,
23
+ GkeServerConfig,
24
+ zone_to_region,
25
+ )
26
+ from ..utils.versions import ReleaseChannel
27
+
28
+
29
+ @pytest.fixture(autouse=True)
30
+ def xpk_print(mocker):
31
+ return mocker.patch("xpk.core.gcloud_context.xpk_print")
19
32
 
20
33
 
21
34
  def test_zone_to_region_raises_when_zone_is_invalid():
@@ -94,3 +107,100 @@ def test_get_cluster_location_invokes_command_for_different_input_args(mocker):
94
107
  get_cluster_location(project="project6", name="name6", zone="us-central1-a")
95
108
 
96
109
  assert mock.call_count == 2
110
+
111
+
112
+ def test_get_gke_server_config_success(mocker):
113
+ mock_run_command = mocker.patch(
114
+ "xpk.core.gcloud_context.run_command_for_value",
115
+ side_effect=[
116
+ (0, "1.2.3"),
117
+ (0, "1.2.3;1.2.4;1.3.0"),
118
+ ],
119
+ )
120
+ args = mocker.Mock(project="test-project", zone="us-central1")
121
+
122
+ return_code, config = get_gke_server_config(args, ReleaseChannel.STABLE)
123
+
124
+ assert return_code == 0
125
+ assert isinstance(config, GkeServerConfig)
126
+ assert config.default_gke_version == "1.2.3"
127
+ assert config.valid_versions == {"1.2.3", "1.2.4", "1.3.0"}
128
+ assert mock_run_command.call_count == 2
129
+
130
+
131
+ def test_get_gke_server_config_fails_on_default_version_command(mocker):
132
+ mocker.patch(
133
+ "xpk.core.gcloud_context.run_command_for_value",
134
+ return_value=(1, "error"),
135
+ )
136
+ args = mocker.Mock(project="test-project", zone="us-central1")
137
+
138
+ return_code, config = get_gke_server_config(args, ReleaseChannel.STABLE)
139
+
140
+ assert return_code == 1
141
+ assert config is None
142
+
143
+
144
+ def test_get_gke_server_config_fails_on_valid_versions_command(mocker):
145
+ mocker.patch(
146
+ "xpk.core.gcloud_context.run_command_for_value",
147
+ side_effect=[(0, "1.2.3"), (1, "error")],
148
+ )
149
+ args = mocker.Mock(project="test-project", zone="us-central1")
150
+
151
+ return_code, config = get_gke_server_config(args, ReleaseChannel.STABLE)
152
+
153
+ assert return_code == 1
154
+ assert config is None
155
+
156
+
157
+ def test_get_gke_control_plane_version_uses_default_when_not_specified(mocker):
158
+ args = mocker.Mock(gke_version=None)
159
+ gke_server_config = GkeServerConfig(
160
+ default_gke_version="1.2.3", valid_versions={"1.2.3", "1.2.4"}
161
+ )
162
+
163
+ return_code, version = get_gke_control_plane_version(args, gke_server_config)
164
+
165
+ assert return_code == 0
166
+ assert version == "1.2.3"
167
+
168
+
169
+ def test_get_gke_control_plane_version_uses_user_version_when_valid(mocker):
170
+ args = mocker.Mock(gke_version="1.2.4")
171
+ gke_server_config = GkeServerConfig(
172
+ default_gke_version="1.2.3", valid_versions={"1.2.3", "1.2.4"}
173
+ )
174
+
175
+ return_code, version = get_gke_control_plane_version(args, gke_server_config)
176
+
177
+ assert return_code == 0
178
+ assert version == "1.2.4"
179
+
180
+
181
+ def test_get_gke_control_plane_version_fails_for_invalid_user_version(
182
+ mocker, xpk_print: MagicMock
183
+ ):
184
+ args = mocker.Mock(gke_version="1.2.5")
185
+ gke_server_config = GkeServerConfig(
186
+ default_gke_version="1.2.3", valid_versions={"1.2.3", "1.2.4"}
187
+ )
188
+
189
+ return_code, version = get_gke_control_plane_version(args, gke_server_config)
190
+
191
+ assert return_code == 1
192
+ assert version is None
193
+ assert "Planned GKE Version: 1.2.5" in xpk_print.mock_calls[0].args[0]
194
+ assert (
195
+ "Recommended / Default GKE Version: 1.2.3"
196
+ in xpk_print.mock_calls[0].args[0]
197
+ )
198
+ assert (
199
+ "Error: Planned GKE Version 1.2.5 is not valid."
200
+ in xpk_print.mock_calls[1].args[0]
201
+ )
202
+ assert (
203
+ "Please select a gke version from the above list using --gke-version=x"
204
+ " argument or rely on the default gke version: 1.2.3"
205
+ in xpk_print.mock_calls[2].args[0]
206
+ )
xpk/core/kjob.py CHANGED
@@ -25,7 +25,6 @@ from kubernetes.client.rest import ApiException
25
25
  from ..utils import templates
26
26
  from ..utils.execution_context import is_dry_run
27
27
  from ..utils.console import xpk_exit, xpk_print
28
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
29
28
  from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
30
29
  from .commands import (
31
30
  run_command_for_value,
@@ -38,7 +37,7 @@ from .config import (
38
37
  KJOB_SHELL_IMAGE,
39
38
  KJOB_SHELL_INTERACTIVE_COMMAND,
40
39
  KJOB_SHELL_WORKING_DIRECTORY,
41
- XpkConfig,
40
+ get_config,
42
41
  )
43
42
  from .network import get_cluster_subnetworks
44
43
  from .system_characteristics import AcceleratorType, SystemCharacteristics
@@ -52,7 +51,6 @@ from .storage import (
52
51
  )
53
52
  from .workload_decorators import (
54
53
  rdma_decorator,
55
- tcpx_decorator,
56
54
  tcpxo_decorator,
57
55
  )
58
56
  from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
@@ -234,8 +232,7 @@ def get_pod_template_interactive_command() -> str:
234
232
  Returns:
235
233
  str - PodTemplate's interactive command
236
234
  """
237
- config = XpkConfig()
238
- pod_command = config.get(KJOB_SHELL_INTERACTIVE_COMMAND)
235
+ pod_command = get_config().get(KJOB_SHELL_INTERACTIVE_COMMAND)
239
236
  if pod_command is None or len(pod_command) == 0:
240
237
  pod_command = PodTemplateDefaults.INTERACTIVE_COMMAND.value
241
238
 
@@ -261,14 +258,17 @@ def create_app_profile_instance(volume_bundles: list[str]) -> int:
261
258
  )
262
259
 
263
260
 
264
- def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
261
+ def decorate_job_template_with_gpu(
262
+ yml_string: str, system: SystemCharacteristics
263
+ ) -> str:
265
264
  job_spec = yaml.safe_load(yml_string)["template"]
266
- if gpu_type == H100_DEVICE_TYPE:
267
- job_spec = tcpx_decorator.decorate_kjob_template(job_spec)
268
- if gpu_type == H100_MEGA_DEVICE_TYPE:
269
- job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
270
- if gpu_type == H200_DEVICE_TYPE:
271
- job_spec = rdma_decorator.decorate_kjob_template(job_spec)
265
+ kjob_decorator = (
266
+ system.gpu_config.kjob_decorator_fn
267
+ if system.gpu_config and system.gpu_config.kjob_decorator_fn
268
+ else None
269
+ )
270
+ if kjob_decorator:
271
+ job_spec = kjob_decorator(job_spec)
272
272
  job_template_dict = yaml.safe_load(yml_string)
273
273
  job_template_dict["template"] = job_spec
274
274
  yaml_result: str = yaml.dump(job_template_dict, sort_keys=False)
@@ -287,11 +287,10 @@ def create_job_template_instance(
287
287
  Returns:
288
288
  exit_code > 0 if creating JobTemplate fails, 0 otherwise
289
289
  """
290
- config = XpkConfig()
291
- job_image = config.get(KJOB_BATCH_IMAGE)
290
+ job_image = get_config().get(KJOB_BATCH_IMAGE)
292
291
  if job_image is None or len(job_image) == 0:
293
292
  job_image = JobTemplateDefaults.IMAGE.value
294
- working_directory = config.get(KJOB_BATCH_WORKING_DIRECTORY)
293
+ working_directory = get_config().get(KJOB_BATCH_WORKING_DIRECTORY)
295
294
  if working_directory is None or len(working_directory) == 0:
296
295
  working_directory = JobTemplateDefaults.WORKING_DIRECTORY.value
297
296
  resources = (
@@ -318,7 +317,7 @@ def create_job_template_instance(
318
317
  service_account=service_account,
319
318
  )
320
319
  if system is not None and system.accelerator_type == AcceleratorType.GPU:
321
- yml_string = decorate_job_template_with_gpu(yml_string, system.device_type)
320
+ yml_string = decorate_job_template_with_gpu(yml_string, system)
322
321
 
323
322
  return run_kubectl_apply(
324
323
  yml_string,
@@ -332,11 +331,10 @@ def create_pod_template_instance(service_account: str) -> int:
332
331
  Returns:
333
332
  exit_code > 0 if creating PodTemplate fails, 0 otherwise
334
333
  """
335
- config = XpkConfig()
336
- pod_image = config.get(KJOB_SHELL_IMAGE)
334
+ pod_image = get_config().get(KJOB_SHELL_IMAGE)
337
335
  if pod_image is None or len(pod_image) == 0:
338
336
  pod_image = PodTemplateDefaults.IMAGE.value
339
- working_directory = config.get(KJOB_SHELL_WORKING_DIRECTORY)
337
+ working_directory = get_config().get(KJOB_SHELL_WORKING_DIRECTORY)
340
338
  if working_directory is None or len(working_directory) == 0:
341
339
  working_directory = PodTemplateDefaults.WORKING_DIRECTORY.value
342
340