dstack 0.19.23rc1__py3-none-any.whl → 0.19.25rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of dstack has been flagged as potentially problematic.

Files changed (60)
  1. dstack/_internal/cli/commands/apply.py +14 -2
  2. dstack/_internal/cli/commands/init.py +47 -2
  3. dstack/_internal/cli/commands/offer.py +68 -60
  4. dstack/_internal/cli/services/configurators/run.py +38 -10
  5. dstack/_internal/cli/services/repos.py +6 -24
  6. dstack/_internal/cli/utils/common.py +7 -0
  7. dstack/_internal/cli/utils/gpu.py +210 -0
  8. dstack/_internal/cli/utils/run.py +33 -0
  9. dstack/_internal/core/backends/aws/compute.py +1 -4
  10. dstack/_internal/core/backends/base/compute.py +0 -4
  11. dstack/_internal/core/backends/gcp/compute.py +1 -4
  12. dstack/_internal/core/backends/nebius/compute.py +1 -4
  13. dstack/_internal/core/models/common.py +1 -1
  14. dstack/_internal/core/models/config.py +3 -1
  15. dstack/_internal/core/models/configurations.py +16 -14
  16. dstack/_internal/core/models/fleets.py +2 -2
  17. dstack/_internal/core/models/instances.py +4 -1
  18. dstack/_internal/core/models/profiles.py +2 -2
  19. dstack/_internal/core/models/repos/remote.py +2 -2
  20. dstack/_internal/core/models/resources.py +4 -4
  21. dstack/_internal/core/models/runs.py +13 -9
  22. dstack/_internal/core/services/configs/__init__.py +4 -6
  23. dstack/_internal/proxy/gateway/services/registry.py +2 -0
  24. dstack/_internal/server/app.py +2 -0
  25. dstack/_internal/server/background/tasks/process_fleets.py +10 -2
  26. dstack/_internal/server/background/tasks/process_running_jobs.py +66 -46
  27. dstack/_internal/server/background/tasks/process_runs.py +16 -15
  28. dstack/_internal/server/background/tasks/process_submitted_jobs.py +251 -52
  29. dstack/_internal/server/migrations/versions/3d7f6c2ec000_add_jobmodel_registered.py +28 -0
  30. dstack/_internal/server/migrations/versions/74a1f55209bd_store_enums_as_strings.py +484 -0
  31. dstack/_internal/server/migrations/versions/e2d08cd1b8d9_add_jobmodel_fleet.py +41 -0
  32. dstack/_internal/server/models.py +24 -13
  33. dstack/_internal/server/routers/gpus.py +29 -0
  34. dstack/_internal/server/schemas/gateways.py +1 -1
  35. dstack/_internal/server/schemas/gpus.py +66 -0
  36. dstack/_internal/server/services/docker.py +1 -1
  37. dstack/_internal/server/services/gpus.py +390 -0
  38. dstack/_internal/server/services/jobs/__init__.py +3 -1
  39. dstack/_internal/server/services/offers.py +48 -31
  40. dstack/_internal/server/services/probes.py +5 -1
  41. dstack/_internal/server/services/proxy/repo.py +1 -0
  42. dstack/_internal/server/services/repos.py +1 -1
  43. dstack/_internal/server/services/runs.py +15 -12
  44. dstack/_internal/server/services/secrets.py +1 -1
  45. dstack/_internal/server/services/services/__init__.py +60 -41
  46. dstack/_internal/server/statics/index.html +1 -1
  47. dstack/_internal/server/statics/logo-notext.svg +116 -0
  48. dstack/_internal/server/statics/{main-03e818b110e1d5705378.css → main-aec4762350e34d6fbff9.css} +1 -1
  49. dstack/_internal/server/statics/{main-cc067b7fd1a8f33f97da.js → main-d151b300fcac3933213d.js} +20 -23
  50. dstack/_internal/server/statics/{main-cc067b7fd1a8f33f97da.js.map → main-d151b300fcac3933213d.js.map} +1 -1
  51. dstack/_internal/server/testing/common.py +7 -2
  52. dstack/api/_public/repos.py +8 -7
  53. dstack/api/server/__init__.py +6 -0
  54. dstack/api/server/_gpus.py +22 -0
  55. dstack/version.py +1 -1
  56. {dstack-0.19.23rc1.dist-info → dstack-0.19.25rc1.dist-info}/METADATA +1 -1
  57. {dstack-0.19.23rc1.dist-info → dstack-0.19.25rc1.dist-info}/RECORD +60 -51
  58. {dstack-0.19.23rc1.dist-info → dstack-0.19.25rc1.dist-info}/WHEEL +0 -0
  59. {dstack-0.19.23rc1.dist-info → dstack-0.19.25rc1.dist-info}/entry_points.txt +0 -0
  60. {dstack-0.19.23rc1.dist-info → dstack-0.19.25rc1.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/cli/utils/run.py

@@ -28,6 +28,39 @@ from dstack._internal.utils.common import (
 from dstack.api import Run
 
 
+def print_offers_json(run_plan: RunPlan, run_spec):
+    """Print offers information in JSON format."""
+    job_plan = run_plan.job_plans[0]
+
+    output = {
+        "project": run_plan.project_name,
+        "user": run_plan.user,
+        "resources": job_plan.job_spec.requirements.resources.dict(),
+        "max_price": (job_plan.job_spec.requirements.max_price),
+        "spot": run_spec.configuration.spot_policy,
+        "reservation": run_plan.run_spec.configuration.reservation,
+        "offers": [],
+        "total_offers": job_plan.total_offers,
+    }
+
+    for offer in job_plan.offers:
+        output["offers"].append(
+            {
+                "backend": ("ssh" if offer.backend.value == "remote" else offer.backend.value),
+                "region": offer.region,
+                "instance_type": offer.instance.name,
+                "resources": offer.instance.resources.dict(),
+                "spot": offer.instance.resources.spot,
+                "price": float(offer.price),
+                "availability": offer.availability.value,
+            }
+        )
+
+    import json
+
+    print(json.dumps(output, indent=2))
+
+
 def print_run_plan(
     run_plan: RunPlan, max_offers: Optional[int] = None, include_run_properties: bool = True
 ):
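For reference, the helper above serializes the plan into a structure like the following. This is a sketch with illustrative values (project, region, prices, and the `resources` sub-dicts are made up); the key names are exactly those built in `print_offers_json`:

```json
{
  "project": "main",
  "user": "admin",
  "resources": {"cpu": {"min": 2, "max": null}, "memory": {"min": 8.0, "max": null}},
  "max_price": null,
  "spot": "auto",
  "reservation": null,
  "offers": [
    {
      "backend": "aws",
      "region": "us-east-1",
      "instance_type": "m5.xlarge",
      "resources": {"cpu": {"min": 4, "max": 4}, "memory": {"min": 16.0, "max": 16.0}},
      "spot": false,
      "price": 0.192,
      "availability": "available"
    }
  ],
  "total_offers": 1
}
```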
dstack/_internal/core/backends/aws/compute.py

@@ -383,10 +383,7 @@ class AWSCompute(
     ) -> bool:
         if not _offer_supports_placement_group(instance_offer, placement_group):
             return False
-        return (
-            placement_group.configuration.backend == BackendType.AWS
-            and placement_group.configuration.region == instance_offer.region
-        )
+        return placement_group.configuration.region == instance_offer.region
 
     def create_gateway(
         self,
dstack/_internal/core/backends/base/compute.py

@@ -263,10 +263,6 @@ class ComputeWithPlacementGroupSupport(ABC):
         Checks if the instance offer can be provisioned in the placement group.
 
         Should return immediately, without performing API calls.
-
-        Can be called with an offer originating from a different backend, because some backends
-        (BackendType.DSTACK) produce offers on behalf of other backends. Should return `False`
-        in that case.
         """
         pass
 
dstack/_internal/core/backends/gcp/compute.py

@@ -448,10 +448,7 @@ class GCPCompute(
         placement_group: PlacementGroup,
         instance_offer: InstanceOffer,
     ) -> bool:
-        return (
-            placement_group.configuration.backend == BackendType.GCP
-            and placement_group.configuration.region == instance_offer.region
-        )
+        return placement_group.configuration.region == instance_offer.region
 
     def create_gateway(
         self,
dstack/_internal/core/backends/nebius/compute.py

@@ -298,10 +298,7 @@ class NebiusCompute(
         placement_group: PlacementGroup,
         instance_offer: InstanceOffer,
     ) -> bool:
-        if not (
-            placement_group.configuration.backend == BackendType.NEBIUS
-            and placement_group.configuration.region == instance_offer.region
-        ):
+        if placement_group.configuration.region != instance_offer.region:
             return False
         assert placement_group.provisioning_data is not None
         backend_data = NebiusPlacementGroupBackendData.load(
dstack/_internal/core/models/common.py

@@ -102,7 +102,7 @@ class RegistryAuth(CoreModel):
         password (str): The password or access token
     """
 
-    class Config:
+    class Config(CoreModel.Config):
         frozen = True
 
     username: Annotated[str, Field(description="The username")]
dstack/_internal/core/models/config.py

@@ -16,7 +16,9 @@ class RepoConfig(CoreModel):
     path: str
     repo_id: str
     repo_type: RepoType
-    ssh_key_path: str
+    # Deprecated since 0.19.25, not used. Can be removed when most users update their `config.yml`
+    # (it's updated each time a project or repo is added)
+    ssh_key_path: Annotated[Optional[str], Field(exclude=True)] = None
 
 
 class GlobalConfig(CoreModel):
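A minimal standalone sketch (pydantic v1, model name hypothetical) of what `Field(exclude=True)` does here: an old `config.yml` that still contains `ssh_key_path` parses fine, but the key is dropped the next time the config is serialized, so it disappears on the next save.

```python
from typing import Optional

from pydantic import BaseModel, Field  # pydantic v1, as used by dstack


class RepoConfigSketch(BaseModel):  # hypothetical stand-in for RepoConfig
    path: str
    # Accepted on input for backward compatibility, excluded from output
    ssh_key_path: Optional[str] = Field(default=None, exclude=True)


cfg = RepoConfigSketch.parse_obj({"path": "/home/me/proj", "ssh_key_path": "~/.ssh/id_rsa"})
print(cfg.dict())  # {'path': '/home/me/proj'} -- the deprecated key is not written back
```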
dstack/_internal/core/models/configurations.py

@@ -20,6 +20,7 @@ from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
 from dstack._internal.core.models.unix import UnixUser
 from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
 from dstack._internal.utils.common import has_duplicates
+from dstack._internal.utils.json_schema import add_extra_schema_types
 from dstack._internal.utils.json_utils import (
     pydantic_orjson_dumps_with_indent,
 )
@@ -561,7 +562,7 @@ class ServiceConfigurationParams(CoreModel):
     )
     auth: Annotated[bool, Field(description="Enable the authorization")] = True
     replicas: Annotated[
-        Union[conint(ge=1), constr(regex=r"^[0-9]+..[1-9][0-9]*$"), Range[int]],
+        Range[int],
         Field(
             description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
             "If it's a range, the `scaling` property is required"
@@ -592,20 +593,13 @@ class ServiceConfigurationParams(CoreModel):
         return v
 
     @validator("replicas")
-    def convert_replicas(cls, v: Any) -> Range[int]:
-        if isinstance(v, str) and ".." in v:
-            min, max = v.replace(" ", "").split("..")
-            v = Range(min=min or 0, max=max or None)
-        elif isinstance(v, (int, float)):
-            v = Range(min=v, max=v)
+    def convert_replicas(cls, v: Range[int]) -> Range[int]:
         if v.max is None:
             raise ValueError("The maximum number of replicas is required")
+        if v.min is None:
+            v.min = 0
         if v.min < 0:
             raise ValueError("The minimum number of replicas must be greater than or equal to 0")
-        if v.max < v.min:
-            raise ValueError(
-                "The maximum number of replicas must be greater than or equal to the minimum number of replicas"
-            )
         return v
 
     @validator("gateway")
@@ -622,9 +616,9 @@ class ServiceConfigurationParams(CoreModel):
     def validate_scaling(cls, values):
         scaling = values.get("scaling")
         replicas = values.get("replicas")
-        if replicas.min != replicas.max and not scaling:
+        if replicas and replicas.min != replicas.max and not scaling:
             raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
-        if replicas.min == replicas.max and scaling:
+        if replicas and replicas.min == replicas.max and scaling:
             raise ValueError("To use `scaling`, `replicas` must be set to a range.")
         return values
 
@@ -655,6 +649,14 @@ class ServiceConfiguration(
 ):
     type: Literal["service"] = "service"
 
+    class Config(CoreModel.Config):
+        @staticmethod
+        def schema_extra(schema: Dict[str, Any]):
+            add_extra_schema_types(
+                schema["properties"]["replicas"],
+                extra_types=[{"type": "integer"}, {"type": "string"}],
+            )
+
 
 AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
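Taken together with the validator changes above, `replicas` still accepts the same YAML spellings described in the field's docstring; only the parsing now happens inside `Range[int]` rather than in `convert_replicas`, with the schema extras re-adding the integer and string forms to the JSON schema. Illustrative snippets (all other required service fields omitted):

```yaml
# Fixed count: parses to Range(min=2, max=2)
replicas: 2

# Autoscaling range: parses to Range(min=0, max=4) and requires `scaling`
replicas: 0..4
```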
 
@@ -715,7 +717,7 @@ class DstackConfiguration(CoreModel):
         Field(discriminator="type"),
     ]
 
-    class Config:
+    class Config(CoreModel.Config):
         json_loads = orjson.loads
         json_dumps = pydantic_orjson_dumps_with_indent
 
dstack/_internal/core/models/fleets.py

@@ -234,7 +234,7 @@ class InstanceGroupParams(CoreModel):
     termination_policy: Annotated[Optional[TerminationPolicy], Field(exclude=True)] = None
     termination_idle_time: Annotated[Optional[Union[str, int]], Field(exclude=True)] = None
 
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any], model: Type):
             del schema["properties"]["termination_policy"]
@@ -279,7 +279,7 @@ class FleetSpec(CoreModel):
     # TODO: make merged_profile a computed field after migrating to pydanticV2
     merged_profile: Annotated[Profile, Field(exclude=True)] = None
 
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any], model: Type) -> None:
             prop = schema.get("properties", {})
dstack/_internal/core/models/instances.py

@@ -122,7 +122,7 @@ class SSHConnectionParams(CoreModel):
     username: str
     port: int
 
-    class Config:
+    class Config(CoreModel.Config):
         frozen = True
 
 
@@ -165,6 +165,9 @@ class InstanceAvailability(Enum):
     AVAILABLE = "available"
     NOT_AVAILABLE = "not_available"
     NO_QUOTA = "no_quota"
+    NO_BALANCE = (
+        "no_balance"  # Introduced in 0.19.24, may be used after a short compatibility period
+    )
     IDLE = "idle"
     BUSY = "busy"
 
dstack/_internal/core/models/profiles.py

@@ -339,7 +339,7 @@ class ProfileParams(CoreModel):
     termination_policy: Annotated[Optional[TerminationPolicy], Field(exclude=True)] = None
     termination_idle_time: Annotated[Optional[Union[str, int]], Field(exclude=True)] = None
 
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]) -> None:
             del schema["properties"]["pool_name"]
@@ -379,7 +379,7 @@ class Profile(ProfileProps, ProfileParams):
 class ProfilesConfig(CoreModel):
     profiles: List[Profile]
 
-    class Config:
+    class Config(CoreModel.Config):
         json_loads = orjson.loads
         json_dumps = pydantic_orjson_dumps_with_indent
 
dstack/_internal/core/models/repos/remote.py

@@ -32,7 +32,7 @@ class RemoteRepoCreds(CoreModel):
     # TODO: remove in 0.20. Left for compatibility with CLI <=0.18.44
     protocol: Annotated[Optional[str], Field(exclude=True)] = None
 
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]) -> None:
             del schema["properties"]["protocol"]
@@ -47,7 +47,7 @@ class RemoteRepoInfo(BaseRepoInfo):
     repo_port: Annotated[Optional[int], Field(exclude=True)] = None
     repo_user_name: Annotated[Optional[str], Field(exclude=True)] = None
 
-    class Config:
+    class Config(BaseRepoInfo.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]) -> None:
             del schema["properties"]["repo_host_name"]
dstack/_internal/core/models/resources.py

@@ -130,7 +130,7 @@ DEFAULT_GPU_COUNT = Range[int](min=1)
 
 
 class CPUSpec(CoreModel):
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]):
             add_extra_schema_types(
@@ -191,7 +191,7 @@ class CPUSpec(CoreModel):
 
 
 class GPUSpec(CoreModel):
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]):
             add_extra_schema_types(
@@ -314,7 +314,7 @@ class GPUSpec(CoreModel):
 
 
 class DiskSpec(CoreModel):
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]):
             add_extra_schema_types(
@@ -340,7 +340,7 @@ DEFAULT_DISK = DiskSpec(size=Range[Memory](min=Memory.parse("100GB"), max=None))
 
 
 class ResourcesSpec(CoreModel):
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any]):
             add_extra_schema_types(
dstack/_internal/core/models/runs.py

@@ -350,15 +350,17 @@ class JobSubmission(CoreModel):
     deployment_num: int = 0  # default for compatibility with pre-0.19.14 servers
     submitted_at: datetime
     last_processed_at: datetime
-    finished_at: Optional[datetime]
-    inactivity_secs: Optional[int]
+    finished_at: Optional[datetime] = None
+    inactivity_secs: Optional[int] = None
     status: JobStatus
     status_message: str = ""  # default for backward compatibility
-    termination_reason: Optional[JobTerminationReason]
-    termination_reason_message: Optional[str]
-    exit_status: Optional[int]
-    job_provisioning_data: Optional[JobProvisioningData]
-    job_runtime_data: Optional[JobRuntimeData]
+    # termination_reason stores JobTerminationReason.
+    # str allows adding new enum members without breaking compatibility with old clients.
+    termination_reason: Optional[str] = None
+    termination_reason_message: Optional[str] = None
+    exit_status: Optional[int] = None
+    job_provisioning_data: Optional[JobProvisioningData] = None
+    job_runtime_data: Optional[JobRuntimeData] = None
     error: Optional[str] = None
     probes: list[Probe] = []
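The comment above explains the motivation: a newer server can introduce a `JobTerminationReason` member that an older client's enum doesn't know, so the wire format is now a plain string. A sketch of the tolerant pattern this enables on the client side (the enum members and values here are stand-ins, not the real ones):

```python
from enum import Enum
from typing import Optional


class JobTerminationReasonSketch(str, Enum):  # stand-in for the real enum
    GATEWAY_ERROR = "gateway_error"
    WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded"


def parse_termination_reason(raw: Optional[str]) -> Optional[JobTerminationReasonSketch]:
    """Map the wire string back to the enum, tolerating unknown values."""
    if raw is None:
        return None
    try:
        return JobTerminationReasonSketch(raw)
    except ValueError:
        return None  # value added by a newer server; don't fail validation
```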
 
@@ -442,7 +444,7 @@ class RunSpec(CoreModel):
     # TODO: make merged_profile a computed field after migrating to pydanticV2
     merged_profile: Annotated[Profile, Field(exclude=True)] = None
 
-    class Config:
+    class Config(CoreModel.Config):
         @staticmethod
         def schema_extra(schema: Dict[str, Any], model: Type) -> None:
             prop = schema.get("properties", {})
@@ -508,7 +510,9 @@ class Run(CoreModel):
     last_processed_at: datetime
     status: RunStatus
     status_message: str = ""  # default for backward compatibility
-    termination_reason: Optional[RunTerminationReason] = None
+    # termination_reason stores RunTerminationReason.
+    # str allows adding new enum members without breaking compatibility with old clients.
+    termination_reason: Optional[str] = None
     run_spec: RunSpec
     jobs: List[Job]
     latest_job_submission: Optional[JobSubmission] = None
dstack/_internal/core/services/configs/__init__.py

@@ -71,19 +71,15 @@ class ConfigManager:
     def delete_project(self, name: str):
         self.config.projects = [p for p in self.config.projects if p.name != name]
 
-    def save_repo_config(
-        self, repo_path: PathLike, repo_id: str, repo_type: RepoType, ssh_key_path: PathLike
-    ):
+    def save_repo_config(self, repo_path: PathLike, repo_id: str, repo_type: RepoType):
         self.config_filepath.parent.mkdir(parents=True, exist_ok=True)
         with filelock.FileLock(str(self.config_filepath) + ".lock"):
             self.load()
             repo_path = os.path.abspath(repo_path)
-            ssh_key_path = os.path.abspath(ssh_key_path)
             for repo in self.config.repos:
                 if repo.path == repo_path:
                     repo.repo_id = repo_id
                     repo.repo_type = repo_type
-                    repo.ssh_key_path = ssh_key_path
                     break
             else:
                 self.config.repos.append(
@@ -91,7 +87,6 @@ class ConfigManager:
                         path=repo_path,
                         repo_id=repo_id,
                         repo_type=repo_type,
-                        ssh_key_path=ssh_key_path,
                     )
                 )
         self.save()
@@ -110,6 +105,9 @@ class ConfigManager:
                 return repo_config
         raise DstackError("No repo config found")
 
+    def delete_repo_config(self, repo_id: str):
+        self.config.repos = [p for p in self.config.repos if p.repo_id != repo_id]
+
     @property
     def dstack_ssh_dir(self) -> Path:
         return self.dstack_dir / "ssh"
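Hypothetical usage of the new helper (a sketch; the no-argument constructor and the explicit `save()` call are assumptions based on the surrounding methods):

```python
from dstack._internal.core.services.configs import ConfigManager

manager = ConfigManager()
manager.delete_repo_config("1b2c3d...")  # repo_id as stored in config.yml
manager.save()
```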
dstack/_internal/proxy/gateway/services/registry.py

@@ -152,6 +152,8 @@ async def register_replica(
     )
 
     if old_service.find_replica(replica_id) is not None:
+        # NOTE: as of 0.19.25, the dstack server relies on the exact text of this error.
+        # See dstack._internal.server.services.services.register_replica
        raise ProxyError(f"Replica {replica_id} already exists in service {old_service.fmt()}")
 
     service = old_service.with_replicas(old_service.replicas + (replica,))
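The NOTE documents a string-level contract: per the comment, the server-side `register_replica` (in `dstack._internal.server.services.services`) inspects the error text to tell "replica already registered" apart from other gateway failures. A hypothetical sketch of such a check (the actual server-side code is not shown in this diff):

```python
def is_already_registered_error(message: str) -> bool:
    # Deliberately fragile: must match the exact ProxyError text raised above.
    return "already exists in service" in message
```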
dstack/_internal/server/app.py

@@ -29,6 +29,7 @@ from dstack._internal.server.routers import (
     files,
     fleets,
     gateways,
+    gpus,
     instances,
     logs,
     metrics,
@@ -204,6 +205,7 @@ def register_routes(app: FastAPI, ui: bool = True):
     app.include_router(repos.router)
     app.include_router(runs.root_router)
     app.include_router(runs.project_router)
+    app.include_router(gpus.project_router)
     app.include_router(metrics.router)
     app.include_router(logs.router)
     app.include_router(secrets.router)
dstack/_internal/server/background/tasks/process_fleets.py

@@ -15,6 +15,7 @@ from dstack._internal.server.models import (
     RunModel,
 )
 from dstack._internal.server.services.fleets import (
+    get_fleet_spec,
     is_fleet_empty,
     is_fleet_in_use,
 )
@@ -92,11 +93,18 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel])
 
 
 def _autodelete_fleet(fleet_model: FleetModel) -> bool:
-    # Currently all empty fleets are autodeleted.
-    # TODO: If fleets with `nodes: 0..` are supported, their deletion should be skipped.
     if is_fleet_in_use(fleet_model) or not is_fleet_empty(fleet_model):
         return False
 
+    fleet_spec = get_fleet_spec(fleet_model)
+    if (
+        fleet_model.status != FleetStatus.TERMINATING
+        and fleet_spec.configuration.nodes is not None
+        and (fleet_spec.configuration.nodes.min is None or fleet_spec.configuration.nodes.min == 0)
+    ):
+        # Empty fleets that allow 0 nodes should not be auto-deleted
+        return False
+
     logger.info("Automatic cleanup of an empty fleet %s", fleet_model.name)
     fleet_model.status = FleetStatus.TERMINATED
     fleet_model.deleted = True
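This gives fleets declared with a zero minimum a way to scale down to empty without being garbage-collected. An illustrative fleet configuration (other fields omitted):

```yaml
type: fleet
# min is 0, so an empty fleet is kept alive rather than auto-deleted
nodes: 0..2
```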
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -32,6 +32,7 @@ from dstack._internal.core.models.runs import (
     JobSpec,
     JobStatus,
     JobTerminationReason,
+    ProbeSpec,
     Run,
     RunSpec,
     RunStatus,
@@ -70,6 +71,7 @@ from dstack._internal.server.services.repos import (
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
 from dstack._internal.server.services.runs import (
+    is_job_ready,
     run_model_to_run,
 )
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
@@ -140,6 +142,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         select(JobModel)
         .where(JobModel.id == job_model.id)
         .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
+        .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
         .execution_options(populate_existing=True)
     )
     job_model = res.unique().scalar_one()
@@ -382,52 +385,21 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             job_submission.age,
         )
 
-    if (
-        initial_status != job_model.status
-        and job_model.status == JobStatus.RUNNING
-        and job_model.job_num == 0  # gateway connects only to the first node
-        and run.run_spec.configuration.type == "service"
-    ):
-        ssh_head_proxy: Optional[SSHConnectionParams] = None
-        ssh_head_proxy_private_key: Optional[str] = None
-        instance = common_utils.get_or_error(job_model.instance)
-        if instance.remote_connection_info is not None:
-            rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
-            if rci.ssh_proxy is not None:
-                ssh_head_proxy = rci.ssh_proxy
-                ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
-                ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
-        try:
-            await services.register_replica(
-                session,
-                run_model.gateway_id,
-                run,
-                job_model,
-                ssh_head_proxy,
-                ssh_head_proxy_private_key,
-            )
-        except GatewayError as e:
-            logger.warning(
-                "%s: failed to register service replica: %s, age=%s",
-                fmt(job_model),
-                e,
-                job_submission.age,
-            )
-            job_model.status = JobStatus.TERMINATING
-            job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
-        else:
-            for probe_num in range(len(job.job_spec.probes)):
-                session.add(
-                    ProbeModel(
-                        name=f"{job_model.job_name}-{probe_num}",
-                        job=job_model,
-                        probe_num=probe_num,
-                        due=common_utils.get_current_datetime(),
-                        success_streak=0,
-                        active=True,
-                    )
+    if initial_status != job_model.status and job_model.status == JobStatus.RUNNING:
+        job_model.probes = []
+        for probe_num in range(len(job.job_spec.probes)):
+            job_model.probes.append(
+                ProbeModel(
+                    name=f"{job_model.job_name}-{probe_num}",
+                    probe_num=probe_num,
+                    due=common_utils.get_current_datetime(),
+                    success_streak=0,
+                    active=True,
                 )
+            )
 
+    if job_model.status == JobStatus.RUNNING:
+        await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes)
     if job_model.status == JobStatus.RUNNING:
         await _check_gpu_utilization(session, job_model, job)
 
@@ -455,8 +427,7 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
 
     if job_model.instance.status == InstanceStatus.TERMINATED:
         job_model.status = JobStatus.TERMINATING
-        # TODO use WAITING_INSTANCE_LIMIT_EXCEEDED after 0.19.x
-        job_model.termination_reason = JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+        job_model.termination_reason = JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED
         return
 
     job_model.job_provisioning_data = job_model.instance.job_provisioning_data
@@ -823,6 +794,55 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
     )
 
 
+async def _maybe_register_replica(
+    session: AsyncSession,
+    run_model: RunModel,
+    run: Run,
+    job_model: JobModel,
+    probe_specs: Iterable[ProbeSpec],
+) -> None:
+    """
+    Register the replica represented by this job to receive service requests if it is ready.
+    """
+
+    if (
+        run.run_spec.configuration.type != "service"
+        or job_model.registered
+        or job_model.job_num != 0  # only the first job in the replica receives service requests
+        or not is_job_ready(job_model.probes, probe_specs)
+    ):
+        return
+
+    ssh_head_proxy: Optional[SSHConnectionParams] = None
+    ssh_head_proxy_private_key: Optional[str] = None
+    instance = common_utils.get_or_error(job_model.instance)
+    if instance.remote_connection_info is not None:
+        rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
+            instance.remote_connection_info
+        )
+        if rci.ssh_proxy is not None:
+            ssh_head_proxy = rci.ssh_proxy
+            ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
+            ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
+    try:
+        await services.register_replica(
+            session,
+            run_model.gateway_id,
+            run,
+            job_model,
+            ssh_head_proxy,
+            ssh_head_proxy_private_key,
+        )
+    except GatewayError as e:
+        logger.warning(
+            "%s: failed to register service replica: %s",
+            fmt(job_model),
+            e,
+        )
+        job_model.status = JobStatus.TERMINATING
+        job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
+
+
 async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None:
     policy = job.job_spec.utilization_policy
     if policy is None: