dstack 0.19.30rc1__py3-none-any.whl → 0.19.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.
Files changed (47)
  1. dstack/_internal/cli/commands/__init__.py +8 -0
  2. dstack/_internal/cli/commands/project.py +27 -20
  3. dstack/_internal/cli/commands/server.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +20 -6
  5. dstack/_internal/cli/utils/gpu.py +2 -2
  6. dstack/_internal/core/backends/aws/compute.py +13 -5
  7. dstack/_internal/core/backends/aws/resources.py +11 -6
  8. dstack/_internal/core/backends/azure/compute.py +17 -6
  9. dstack/_internal/core/backends/base/compute.py +57 -9
  10. dstack/_internal/core/backends/base/offers.py +1 -0
  11. dstack/_internal/core/backends/cloudrift/compute.py +2 -0
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
  15. dstack/_internal/core/backends/features.py +5 -0
  16. dstack/_internal/core/backends/gcp/compute.py +87 -38
  17. dstack/_internal/core/backends/gcp/configurator.py +1 -1
  18. dstack/_internal/core/backends/gcp/models.py +14 -1
  19. dstack/_internal/core/backends/gcp/resources.py +35 -12
  20. dstack/_internal/core/backends/hotaisle/compute.py +2 -0
  21. dstack/_internal/core/backends/kubernetes/compute.py +466 -213
  22. dstack/_internal/core/backends/kubernetes/models.py +13 -16
  23. dstack/_internal/core/backends/kubernetes/utils.py +145 -8
  24. dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
  25. dstack/_internal/core/backends/local/compute.py +2 -0
  26. dstack/_internal/core/backends/nebius/compute.py +2 -0
  27. dstack/_internal/core/backends/oci/compute.py +7 -1
  28. dstack/_internal/core/backends/oci/resources.py +8 -3
  29. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  30. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  31. dstack/_internal/core/backends/vultr/compute.py +2 -0
  32. dstack/_internal/core/consts.py +2 -0
  33. dstack/_internal/core/services/repos.py +101 -11
  34. dstack/_internal/server/background/tasks/common.py +2 -0
  35. dstack/_internal/server/background/tasks/process_instances.py +2 -2
  36. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
  37. dstack/_internal/server/background/tasks/process_submitted_jobs.py +51 -41
  38. dstack/_internal/server/services/offers.py +7 -1
  39. dstack/_internal/server/testing/common.py +2 -0
  40. dstack/_internal/server/utils/provisioning.py +3 -10
  41. dstack/_internal/utils/ssh.py +22 -2
  42. dstack/version.py +2 -2
  43. {dstack-0.19.30rc1.dist-info → dstack-0.19.31.dist-info}/METADATA +17 -13
  44. {dstack-0.19.30rc1.dist-info → dstack-0.19.31.dist-info}/RECORD +47 -47
  45. {dstack-0.19.30rc1.dist-info → dstack-0.19.31.dist-info}/WHEEL +0 -0
  46. {dstack-0.19.30rc1.dist-info → dstack-0.19.31.dist-info}/entry_points.txt +0 -0
  47. {dstack-0.19.30rc1.dist-info → dstack-0.19.31.dist-info}/licenses/LICENSE.md +0 -0
@@ -5,12 +5,14 @@ from pydantic import Field, root_validator
 from dstack._internal.core.backends.base.models import fill_data
 from dstack._internal.core.models.common import CoreModel
 
+DEFAULT_NAMESPACE = "default"
 
-class KubernetesNetworkingConfig(CoreModel):
-    ssh_host: Annotated[
-        Optional[str], Field(description="The external IP address of any node")
+
+class KubernetesProxyJumpConfig(CoreModel):
+    hostname: Annotated[
+        Optional[str], Field(description="The external IP address or hostname of any node")
     ] = None
-    ssh_port: Annotated[
+    port: Annotated[
         Optional[int], Field(description="Any port accessible outside of the cluster")
     ] = None
 
@@ -22,16 +24,15 @@ class KubeconfigConfig(CoreModel):
 
 class KubernetesBackendConfig(CoreModel):
     type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes"
-    networking: Annotated[
-        Optional[KubernetesNetworkingConfig], Field(description="The networking configuration")
+    proxy_jump: Annotated[
+        Optional[KubernetesProxyJumpConfig], Field(description="The SSH proxy jump configuration")
     ] = None
+    namespace: Annotated[
+        str, Field(description="The namespace for resources managed by `dstack`")
+    ] = DEFAULT_NAMESPACE
 
 
-class KubernetesBackendConfigWithCreds(CoreModel):
-    type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes"
-    networking: Annotated[
-        Optional[KubernetesNetworkingConfig], Field(description="The networking configuration")
-    ] = None
+class KubernetesBackendConfigWithCreds(KubernetesBackendConfig):
     kubeconfig: Annotated[KubeconfigConfig, Field(description="The kubeconfig configuration")]
 
 
@@ -53,11 +54,7 @@ class KubeconfigFileConfig(CoreModel):
         return fill_data(values)
 
 
-class KubernetesBackendFileConfigWithCreds(CoreModel):
-    type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes"
-    networking: Annotated[
-        Optional[KubernetesNetworkingConfig], Field(description="The networking configuration")
-    ] = None
+class KubernetesBackendFileConfigWithCreds(KubernetesBackendConfig):
     kubeconfig: Annotated[KubeconfigFileConfig, Field(description="The kubeconfig configuration")]
 
 
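The hunks above rename the Kubernetes networking settings to a `proxy_jump` section and add a configurable `namespace`. A minimal sketch of how the reworked models compose, assuming they are imported from dstack/_internal/core/backends/kubernetes/models.py as listed above (the hostname, port, and namespace values are made up):

from dstack._internal.core.backends.kubernetes.models import (
    KubernetesBackendConfig,
    KubernetesProxyJumpConfig,
)

# proxy_jump replaces the old networking.ssh_host/ssh_port fields;
# namespace falls back to DEFAULT_NAMESPACE ("default") when omitted.
config = KubernetesBackendConfig(
    proxy_jump=KubernetesProxyJumpConfig(hostname="203.0.113.10", port=32000),
    namespace="dstack",
)
print(config.namespace)  # dstack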
@@ -1,20 +1,157 @@
-from typing import Dict, List, Optional
+import ast
+from typing import Any, Callable, List, Literal, Optional, TypeVar, Union, get_origin, overload
 
-import kubernetes
 import yaml
+from kubernetes import client as kubernetes_client
+from kubernetes import config as kubernetes_config
+from typing_extensions import ParamSpec
 
+T = TypeVar("T")
+P = ParamSpec("P")
 
-def get_api_from_config_data(kubeconfig_data: str) -> kubernetes.client.CoreV1Api:
+
+def get_api_from_config_data(kubeconfig_data: str) -> kubernetes_client.CoreV1Api:
     config_dict = yaml.load(kubeconfig_data, yaml.FullLoader)
     return get_api_from_config_dict(config_dict)
 
 
-def get_api_from_config_dict(kubeconfig: Dict) -> kubernetes.client.CoreV1Api:
-    api_client = kubernetes.config.new_client_from_config_dict(config_dict=kubeconfig)
-    return kubernetes.client.CoreV1Api(api_client=api_client)
+def get_api_from_config_dict(kubeconfig: dict) -> kubernetes_client.CoreV1Api:
+    api_client = kubernetes_config.new_client_from_config_dict(config_dict=kubeconfig)
+    return kubernetes_client.CoreV1Api(api_client=api_client)
+
+
+@overload
+def call_api_method(
+    method: Callable[P, Any],
+    type_: type[T],
+    expected: None = None,
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> T: ...
+
+
+@overload
+def call_api_method(
+    method: Callable[P, Any],
+    type_: type[T],
+    expected: Union[int, tuple[int, ...], list[int]],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> Optional[T]: ...
+
+
+def call_api_method(
+    method: Callable[P, Any],
+    type_: type[T],
+    expected: Optional[Union[int, tuple[int, ...], list[int]]] = None,
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> Optional[T]:
+    """
+    Returns the result of the API method call, optionally ignoring specified HTTP status codes.
+
+    Args:
+        method: the `CoreV1Api` bound method.
+        type_: The expected type of the return value, used for runtime type checking and
+            as a type hint for a static type checker (as kubernetes package is not type-annotated).
+            NB: For composite types, only "origin" type is checked, e.g., list, not list[Node]
+        expected: Expected error statuses, e.g., 404.
+        args: positional arguments of the method.
+        kwargs: keyword arguments of the method.
+    Returns:
+        The return value or `None` in case of the expected error.
+    """
+    if isinstance(expected, int):
+        expected = (expected,)
+    result: T
+    try:
+        result = method(*args, **kwargs)
+    except kubernetes_client.ApiException as e:
+        if expected is None or e.status not in expected:
+            raise
+        return None
+    if not isinstance(result, get_origin(type_) or type_):
+        raise TypeError(
+            f"{method.__name__} returned {type(result).__name__}, expected {type_.__name__}"
+        )
+    return result
+
+
+@overload
+def get_value(
+    obj: object, path: str, type_: type[T], *, required: Literal[False] = False
+) -> Optional[T]: ...
+
+
+@overload
+def get_value(obj: object, path: str, type_: type[T], *, required: Literal[True]) -> T: ...
+
+
+def get_value(obj: object, path: str, type_: type[T], *, required: bool = False) -> Optional[T]:
+    """
+    Returns the value at a given path.
+    Supports object attributes, sequence indices, and mapping keys.
+
+    Args:
+        obj: The object to traverse.
+        path: The path to the value, regular Python syntax. The leading dot is optional, all the
+            following are correct: `.attr`, `attr`, `.[0]`, `[0]`, `.['key']`, `['key']`.
+        type_: The expected type of the value, used for runtime type checking and as a type hint
+            for a static type checker (as kubernetes package is not type-annotated).
+            NB: For composite types, only "origin" type is checked, e.g., list, not list[Node]
+        required: If `True`, the value must exist and must not be `None`. If `False` (safe
+            navigation mode), the value may not exist and may be `None`.
+
+    Returns:
+        The requested value or `None` in case of a failed traverse when required=False.
+    """
+    _path = path.removeprefix(".")
+    if _path.startswith("["):
+        src = f"obj{_path}"
+    else:
+        src = f"obj.{_path}"
+    module = ast.parse(src)
+    assert len(module.body) == 1, ast.dump(module, indent=4)
+    root_expr = module.body[0]
+    assert isinstance(root_expr, ast.Expr), ast.dump(module, indent=4)
+    varname: Optional[str] = None
+    expr = root_expr.value
+    while True:
+        if isinstance(expr, ast.Name):
+            varname = expr.id
+            break
+        if __debug__:
+            if isinstance(expr, ast.Subscript):
+                if isinstance(expr.slice, ast.UnaryOp):
+                    # .items[-1]
+                    assert isinstance(expr.slice.op, ast.USub), ast.dump(expr, indent=4)
+                    assert isinstance(expr.slice.operand, ast.Constant), ast.dump(expr, indent=4)
+                    assert isinstance(expr.slice.operand.value, int), ast.dump(expr, indent=4)
+                else:
+                    # .items[0], .labels["name"]
+                    assert isinstance(expr.slice, ast.Constant), ast.dump(expr, indent=4)
+            else:
+                assert isinstance(expr, ast.Attribute), ast.dump(expr, indent=4)
+        else:
+            assert isinstance(expr, (ast.Attribute, ast.Subscript))
+        expr = expr.value
+    assert varname is not None, ast.dump(module)
+    try:
+        value = eval(src, {"__builtins__": {}}, {"obj": obj})
+    except (AttributeError, KeyError, IndexError, TypeError) as e:
+        if required:
+            raise type(e)(f"Failed to traverse {path}: {e}") from e
+        return None
+    if value is None:
+        if required:
+            raise TypeError(f"Required {path} is None")
+        return value
+    if not isinstance(value, get_origin(type_) or type_):
+        raise TypeError(f"{path} value is {type(value).__name__}, expected {type_.__name__}")
+    return value
 
 
-def get_cluster_public_ip(api_client: kubernetes.client.CoreV1Api) -> Optional[str]:
+def get_cluster_public_ip(api_client: kubernetes_client.CoreV1Api) -> Optional[str]:
     """
     Returns public IP of any cluster node.
     """
@@ -24,7 +161,7 @@ def get_cluster_public_ip(api_client: kubernetes.client.CoreV1Api) -> Optional[s
     return public_ips[0]
 
 
-def get_cluster_public_ips(api_client: kubernetes.client.CoreV1Api) -> List[str]:
+def get_cluster_public_ips(api_client: kubernetes_client.CoreV1Api) -> List[str]:
     """
     Returns public IPs of all cluster nodes.
     """
@@ -9,6 +9,7 @@ from dstack._internal.core.backends.base.compute import (
     Compute,
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_shim_commands,
 )
@@ -31,6 +32,7 @@ MAX_INSTANCE_NAME_LEN = 60
 class LambdaCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: LambdaConfig):
@@ -3,6 +3,7 @@ from typing import List, Optional
 from dstack._internal.core.backends.base.compute import (
     Compute,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithVolumeSupport,
 )
 from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
@@ -25,6 +26,7 @@ logger = get_logger(__name__)
 
 class LocalCompute(
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithVolumeSupport,
     Compute,
 ):
@@ -16,6 +16,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_user_data,
 )
@@ -79,6 +80,7 @@ SUPPORTED_PLATFORMS = [
 class NebiusCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     Compute,
@@ -9,6 +9,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_user_data,
 )
@@ -50,6 +51,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("50GB"), max=Memory.pars
 class OCICompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     Compute,
 ):
@@ -118,7 +120,11 @@ class OCICompute(
         availability_domain = instance_offer.availability_zones[0]
 
         listing, package = resources.get_marketplace_listing_and_package(
-            cuda=len(instance_offer.instance.resources.gpus) > 0,
+            gpu_name=(
+                instance_offer.instance.resources.gpus[0].name
+                if len(instance_offer.instance.resources.gpus) > 0
+                else None
+            ),
             client=region.marketplace_client,
         )
         resources.accept_marketplace_listing_agreements(
@@ -23,7 +23,9 @@ import oci
 from oci.object_storage.models import CreatePreauthenticatedRequestDetails
 
 from dstack import version
+from dstack._internal.core.backends.base.compute import requires_nvidia_proprietary_kernel_modules
 from dstack._internal.core.backends.oci.region import OCIRegionClient
+from dstack._internal.core.consts import DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES
 from dstack._internal.core.errors import BackendError
 from dstack._internal.core.models.instances import InstanceOffer
 from dstack._internal.utils.common import batched
@@ -352,11 +354,14 @@ def terminate_instance_if_exists(client: oci.core.ComputeClient, instance_id: st
 
 
 def get_marketplace_listing_and_package(
-    cuda: bool, client: oci.marketplace.MarketplaceClient
+    gpu_name: Optional[str], client: oci.marketplace.MarketplaceClient
 ) -> Tuple[oci.marketplace.models.Listing, oci.marketplace.models.ImageListingPackage]:
     listing_name = f"dstack-{version.base_image}"
-    if cuda:
-        listing_name = f"dstack-cuda-{version.base_image}"
+    if gpu_name is not None:
+        if not requires_nvidia_proprietary_kernel_modules(gpu_name):
+            listing_name = f"dstack-cuda-{version.base_image}"
+        else:
+            listing_name = f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
 
     listing_summaries = list_marketplace_listings(listing_name, client)
     if len(listing_summaries) != 1:
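To make the new branching concrete, here is a small standalone sketch of the listing-name selection above; the base image value and GPU names are made up, and the predicate is passed in instead of calling the real requires_nvidia_proprietary_kernel_modules:

DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES = "0.10"
BASE_IMAGE = "0.19"  # stand-in for version.base_image


def pick_listing_name(gpu_name, requires_proprietary_modules):
    # Mirrors the branch in get_marketplace_listing_and_package above.
    listing_name = f"dstack-{BASE_IMAGE}"
    if gpu_name is not None:
        if not requires_proprietary_modules(gpu_name):
            listing_name = f"dstack-cuda-{BASE_IMAGE}"
        else:
            listing_name = f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
    return listing_name


print(pick_listing_name(None, lambda name: False))    # dstack-0.19
print(pick_listing_name("H100", lambda name: False))  # dstack-cuda-0.19
print(pick_listing_name("V100", lambda name: True))   # dstack-cuda-0.10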
@@ -8,6 +8,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     ComputeWithPrivateGatewaySupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithReservationSupport,
     ComputeWithVolumeSupport,
 )
@@ -31,6 +32,7 @@ class {{ backend_name }}Compute(
     # TODO: Choose ComputeWith* classes to extend and implement
     # ComputeWithAllOffersCached,
     # ComputeWithCreateInstanceSupport,
+    # ComputeWithPrivilegedSupport,
     # ComputeWithMultinodeSupport,
     # ComputeWithReservationSupport,
     # ComputeWithPlacementGroupSupport,
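The template hunk above is the starting point backend authors copy. A hedged sketch of what a filled-in subclass for a hypothetical backend might look like (ExampleCompute is not from the diff; ComputeWithPrivilegedSupport presumably marks backends that can run privileged containers):

from dstack._internal.core.backends.base.compute import (
    Compute,
    ComputeWithCreateInstanceSupport,
    ComputeWithPrivilegedSupport,
)


class ExampleCompute(
    ComputeWithCreateInstanceSupport,
    ComputeWithPrivilegedSupport,
    Compute,
):
    # TODO: implement the abstract Compute methods here
    ...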
@@ -6,6 +6,7 @@ import requests
 from dstack._internal.core.backends.base.backend import Compute
 from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_shim_commands,
 )
@@ -32,6 +33,7 @@ MAX_INSTANCE_NAME_LEN = 60
 
 class TensorDockCompute(
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: TensorDockConfig):
@@ -9,6 +9,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_user_data,
 )
@@ -35,6 +36,7 @@ MAX_INSTANCE_NAME_LEN = 64
 
 class VultrCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     Compute,
 ):
@@ -4,3 +4,5 @@ DSTACK_SHIM_HTTP_PORT = 10998
 DSTACK_RUNNER_HTTP_PORT = 10999
 # ssh server (runs alongside the runner inside a container) listen port
 DSTACK_RUNNER_SSH_PORT = 10022
+# legacy AWS, Azure, GCP, and OCI image for older GPUs
+DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES = "0.10"
@@ -36,24 +36,59 @@ def get_repo_creds_and_default_branch(
 
     # no auth
     with suppress(InvalidRepoCredentialsError):
-        return _get_repo_creds_and_default_branch_https(url)
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url)
+        logger.debug(
+            "Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
+        )
+        return creds, default_branch
 
     # ssh key provided by the user or pulled from the server
     if identity_file is not None or private_key is not None:
         if identity_file is not None:
             private_key = _read_private_key(identity_file)
-            return _get_repo_creds_and_default_branch_ssh(url, identity_file, private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, identity_file, private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using identity file: %s. Default branch: %s",
+                repo_url,
+                identity_file,
+                default_branch,
+            )
+            return creds, default_branch
         elif private_key is not None:
             with NamedTemporaryFile("w+", 0o600) as f:
                 f.write(private_key)
                 f.flush()
-                return _get_repo_creds_and_default_branch_ssh(url, f.name, private_key)
+                creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                    url, f.name, private_key
+                )
+                masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
+                logger.debug(
+                    "Git repo %s is private. Using private key: %s. Default branch: %s",
+                    repo_url,
+                    masked_key,
+                    default_branch,
+                )
+                return creds, default_branch
         else:
             assert False, "should not reach here"
 
     # oauth token provided by the user or pulled from the server
     if oauth_token is not None:
-        return _get_repo_creds_and_default_branch_https(url, oauth_token)
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
+        masked_token = (
+            len(oauth_token[:-4]) * "*" + oauth_token[-4:]
+            if len(oauth_token) > 4
+            else "***MASKED***"
+        )
+        logger.debug(
+            "Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
+            repo_url,
+            masked_token,
+            default_branch,
+        )
+        return creds, default_branch
 
     # key from ssh config
     identities = get_host_config(url.original_host).get("identityfile")
@@ -61,7 +96,16 @@ def get_repo_creds_and_default_branch(
         _identity_file = identities[0]
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(_identity_file)
-            return _get_repo_creds_and_default_branch_ssh(url, _identity_file, _private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, _identity_file, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
+                repo_url,
+                _identity_file,
+                default_branch,
+            )
+            return creds, default_branch
 
     # token from gh config
     if os.path.exists(gh_config_path):
@@ -70,13 +114,35 @@ def get_repo_creds_and_default_branch(
         _oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
         if _oauth_token is not None:
             with suppress(InvalidRepoCredentialsError):
-                return _get_repo_creds_and_default_branch_https(url, _oauth_token)
+                creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
+                masked_token = (
+                    len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
+                    if len(_oauth_token) > 4
+                    else "***MASKED***"
+                )
+                logger.debug(
+                    "Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
+                    repo_url,
+                    masked_token,
+                    gh_config_path,
+                    default_branch,
+                )
+                return creds, default_branch
 
     # default user key
     if os.path.exists(default_ssh_key):
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(default_ssh_key)
-            return _get_repo_creds_and_default_branch_ssh(url, default_ssh_key, _private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, default_ssh_key, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using default identity file: %s. Default branch: %s",
+                repo_url,
+                default_ssh_key,
+                default_branch,
+            )
+            return creds, default_branch
 
     raise InvalidRepoCredentialsError(
         "No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
@@ -87,8 +153,9 @@ def _get_repo_creds_and_default_branch_ssh(
     url: GitRepoURL, identity_file: PathLike, private_key: str
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_ssh()
+    env = _make_git_env_for_creds_check(identity_file=identity_file)
     try:
-        default_branch = _get_repo_default_branch(_url, make_git_env(identity_file=identity_file))
+        default_branch = _get_repo_default_branch(_url, env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
         raise InvalidRepoCredentialsError(message) from e
@@ -104,8 +171,9 @@ def _get_repo_creds_and_default_branch_https(
     url: GitRepoURL, oauth_token: Optional[str] = None
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_https()
+    env = _make_git_env_for_creds_check()
     try:
-        default_branch = _get_repo_default_branch(url.as_https(oauth_token), make_git_env())
+        default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}`"
         if oauth_token is not None:
@@ -120,10 +188,32 @@
     return creds, default_branch
 
 
+def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:
+    # Our goal is to check if _provided_ creds (if any) are correct, so we need to be sure that
+    # only the provided creds are used, without falling back to any additional mechanisms.
+    # To do this, we:
+    # 1. Disable all configs to ignore any stored creds
+    # 2. Disable askpass to avoid asking for creds interactively or fetching stored creds from
+    #    a non-interactive askpass helper (for example, VS Code sets GIT_ASKPASS to its own helper,
+    #    which silently provides creds to Git).
+    return make_git_env(disable_config=True, disable_askpass=True, identity_file=identity_file)
+
+
 def _get_repo_default_branch(url: str, env: dict[str, str]) -> Optional[str]:
+    # Git shipped by Apple with XCode is patched to support an additional config scope
+    # above "system" called "xcode". There is no option in `git config list` to show this config,
+    # but you can list the merged config (`git config list` without options) and then exclude
+    # all settings listed in `git config list --{system,global,local,worktree}`.
+    # As of time of writing, there are only two settings in the "xcode" config, one of which breaks
+    # our "is repo public?" check, namely "credential.helper=osxkeychain".
+    # As there is no way to disable "xcode" config (no env variable, no CLI option, etc.),
+    # the only way to disable credential helper is to override this specific setting with an empty
+    # string via command line argument: `git -c credential.helper= COMMAND [ARGS ...]`.
+    # See: https://github.com/git/git/commit/3d4355712b9fe77a96ad4ad877d92dc7ff6e0874
+    # See: https://gist.github.com/ChrisTollefson/ab9c0a5d1dd4dd615217345c6936a307
+    _git = git.cmd.Git()(c="credential.helper=")
     # output example: "ref: refs/heads/dev\tHEAD\n545344f77c0df78367085952a97fc3a058eb4c65\tHEAD"
-    # Disable credential helpers to exclude any default credentials from being used
-    output: str = git.cmd.Git()(c="credential.helper=").ls_remote("--symref", url, "HEAD", env=env)
+    output: str = _git.ls_remote("--symref", url, "HEAD", env=env)
     for line in output.splitlines():
         # line format: `<oid> TAB <ref> LF`
         oid, _, ref = line.partition("\t")
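The new debug logging masks secrets before they reach the logs. A quick illustration of the two masking expressions from the hunks above, evaluated with made-up values:

oauth_token = "ghp_abcdefgh1234"  # fake 16-character token
masked_token = (
    len(oauth_token[:-4]) * "*" + oauth_token[-4:] if len(oauth_token) > 4 else "***MASKED***"
)
print(masked_token)  # ************1234 (everything but the last 4 characters hidden)

private_key = "-----BEGIN OPENSSH PRIVATE KEY-----..."  # fake key material
masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
print(masked_key)  # "***" plus only the last 10 characters of the key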
@@ -19,4 +19,6 @@ def get_provisioning_timeout(backend_type: BackendType, instance_type_name: str)
         return timedelta(minutes=20)
     if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
         return timedelta(minutes=55)
+    if backend_type == BackendType.GCP and instance_type_name == "a4-highgpu-8g":
+        return timedelta(minutes=16)
     return timedelta(minutes=10)
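A minimal usage sketch of the updated timeout rule; the get_provisioning_timeout import path follows the file list above, while the BackendType import path is assumed and not verified against this release:

from datetime import timedelta

from dstack._internal.core.models.backends.base import BackendType  # assumed path
from dstack._internal.server.utils.provisioning import get_provisioning_timeout

# The new branch: GCP a4-highgpu-8g instances now get a 16-minute provisioning timeout.
assert get_provisioning_timeout(BackendType.GCP, "a4-highgpu-8g") == timedelta(minutes=16)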
@@ -307,7 +307,7 @@ async def _add_remote(instance: InstanceModel) -> None:
         )
         deploy_timeout = 20 * 60  # 20 minutes
         result = await asyncio.wait_for(future, timeout=deploy_timeout)
-        health, host_info, cpu_arch = result
+        health, host_info, arch = result
     except (asyncio.TimeoutError, TimeoutError) as e:
         raise ProvisioningError(f"Deploy timeout: {e}") from e
     except Exception as e:
@@ -327,7 +327,7 @@ async def _add_remote(instance: InstanceModel) -> None:
         instance.status = InstanceStatus.PENDING
         return
 
-    instance_type = host_info_to_instance_type(host_info, cpu_arch)
+    instance_type = host_info_to_instance_type(host_info, arch)
     instance_network = None
     internal_ip = None
     try:
@@ -1139,7 +1139,7 @@ def _patch_base_image_for_aws_efa(
     efa_enabled_patterns = [
         # TODO: p6-b200 isn't supported yet in gpuhunt
        r"^p6-b200\.(48xlarge)$",
-        r"^p5\.(48xlarge)$",
+        r"^p5\.(4xlarge|48xlarge)$",
        r"^p5e\.(48xlarge)$",
        r"^p5en\.(48xlarge)$",
        r"^p4d\.(24xlarge)$",