dstack 0.19.9__py3-none-any.whl → 0.19.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.

Files changed (53)
  1. dstack/_internal/cli/commands/config.py +1 -1
  2. dstack/_internal/cli/commands/metrics.py +25 -10
  3. dstack/_internal/cli/commands/offer.py +2 -0
  4. dstack/_internal/cli/commands/project.py +161 -0
  5. dstack/_internal/cli/commands/ps.py +9 -2
  6. dstack/_internal/cli/main.py +2 -0
  7. dstack/_internal/cli/services/configurators/run.py +1 -1
  8. dstack/_internal/cli/utils/updates.py +13 -1
  9. dstack/_internal/core/backends/aws/compute.py +21 -9
  10. dstack/_internal/core/backends/azure/compute.py +8 -3
  11. dstack/_internal/core/backends/base/compute.py +9 -4
  12. dstack/_internal/core/backends/gcp/compute.py +43 -20
  13. dstack/_internal/core/backends/gcp/resources.py +18 -2
  14. dstack/_internal/core/backends/local/compute.py +4 -2
  15. dstack/_internal/core/models/configurations.py +21 -4
  16. dstack/_internal/core/models/runs.py +2 -1
  17. dstack/_internal/proxy/gateway/resources/nginx/00-log-format.conf +11 -1
  18. dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +12 -6
  19. dstack/_internal/proxy/gateway/services/stats.py +17 -3
  20. dstack/_internal/server/background/tasks/process_metrics.py +23 -21
  21. dstack/_internal/server/background/tasks/process_submitted_jobs.py +24 -15
  22. dstack/_internal/server/migrations/versions/bca2fdf130bf_add_runmodel_priority.py +34 -0
  23. dstack/_internal/server/models.py +1 -0
  24. dstack/_internal/server/routers/repos.py +13 -4
  25. dstack/_internal/server/services/fleets.py +2 -2
  26. dstack/_internal/server/services/gateways/__init__.py +1 -1
  27. dstack/_internal/server/services/instances.py +6 -2
  28. dstack/_internal/server/services/jobs/__init__.py +4 -4
  29. dstack/_internal/server/services/jobs/configurators/base.py +18 -4
  30. dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +3 -1
  31. dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +3 -1
  32. dstack/_internal/server/services/plugins.py +64 -32
  33. dstack/_internal/server/services/runs.py +33 -20
  34. dstack/_internal/server/services/volumes.py +1 -1
  35. dstack/_internal/server/settings.py +1 -0
  36. dstack/_internal/server/statics/index.html +1 -1
  37. dstack/_internal/server/statics/{main-b4f65323f5df007e1664.js → main-5b9786c955b42bf93581.js} +8 -8
  38. dstack/_internal/server/statics/{main-b4f65323f5df007e1664.js.map → main-5b9786c955b42bf93581.js.map} +1 -1
  39. dstack/_internal/server/testing/common.py +2 -0
  40. dstack/_internal/server/utils/routers.py +3 -6
  41. dstack/_internal/settings.py +4 -0
  42. dstack/api/_public/runs.py +6 -3
  43. dstack/api/server/_runs.py +2 -0
  44. dstack/plugins/builtin/__init__.py +0 -0
  45. dstack/plugins/builtin/rest_plugin/__init__.py +18 -0
  46. dstack/plugins/builtin/rest_plugin/_models.py +48 -0
  47. dstack/plugins/builtin/rest_plugin/_plugin.py +127 -0
  48. dstack/version.py +2 -2
  49. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/METADATA +10 -6
  50. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/RECORD +53 -47
  51. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/WHEEL +0 -0
  52. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/entry_points.txt +0 -0
  53. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/core/models/configurations.py
@@ -23,6 +23,10 @@ ValidPort = conint(gt=0, le=65536)
 MAX_INT64 = 2**63 - 1
 SERVICE_HTTPS_DEFAULT = True
 STRIP_PREFIX_DEFAULT = True
+RUN_PRIOTIRY_MIN = 0
+RUN_PRIOTIRY_MAX = 100
+RUN_PRIORITY_DEFAULT = 0
+DEFAULT_REPO_DIR = "/workflow"


 class RunConfigurationType(str, Enum):
@@ -77,7 +81,8 @@ class ScalingSpec(CoreModel):
         Field(
             description="The target value of the metric. "
            "The number of replicas is calculated based on this number and automatically adjusts "
-            "(scales up or down) as this metric changes"
+            "(scales up or down) as this metric changes",
+            gt=0,
         ),
     ]
     scale_up_delay: Annotated[
@@ -177,7 +182,7 @@ class BaseRunConfiguration(CoreModel):
         Field(
             description=(
                 "The path to the working directory inside the container."
-                " It's specified relative to the repository directory (`/workflow`) and should be inside it."
+                f" It's specified relative to the repository directory (`{DEFAULT_REPO_DIR}`) and should be inside it."
                 ' Defaults to `"."` '
             )
         ),
@@ -221,14 +226,26 @@ class BaseRunConfiguration(CoreModel):
             )
         ),
     ] = None
-    # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
-    setup: CommandsList = []
     resources: Annotated[
         ResourcesSpec, Field(description="The resources requirements to run the configuration")
     ] = ResourcesSpec()
+    priority: Annotated[
+        Optional[int],
+        Field(
+            ge=RUN_PRIOTIRY_MIN,
+            le=RUN_PRIOTIRY_MAX,
+            description=(
+                f"The priority of the run, an integer between `{RUN_PRIOTIRY_MIN}` and `{RUN_PRIOTIRY_MAX}`."
+                " `dstack` tries to provision runs with higher priority first."
+                f" Defaults to `{RUN_PRIORITY_DEFAULT}`"
+            ),
+        ),
+    ] = None
     volumes: Annotated[
         List[Union[MountPoint, str]], Field(description="The volumes mount points")
     ] = []
+    # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
+    setup: CommandsList = []

     @validator("python", pre=True, always=True)
     def convert_python(cls, v, values) -> Optional[PythonVersion]:
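
As an aside on the new `priority` field: below is a minimal standalone pydantic sketch (not dstack's actual model) showing how the `ge`/`le` bounds above behave. The misspelled `RUN_PRIOTIRY_*` constant names are kept exactly as they appear in this release.

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError

# Constant names as they appear in this release (typo included).
RUN_PRIOTIRY_MIN = 0
RUN_PRIOTIRY_MAX = 100


class RunConfigSketch(BaseModel):
    # Mirrors BaseRunConfiguration.priority: optional, bounded 0..100.
    priority: Optional[int] = Field(None, ge=RUN_PRIOTIRY_MIN, le=RUN_PRIOTIRY_MAX)


RunConfigSketch(priority=50)  # accepted
try:
    RunConfigSketch(priority=101)  # rejected: above RUN_PRIOTIRY_MAX
except ValidationError as e:
    print(e)
```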
dstack/_internal/core/models/runs.py
@@ -8,6 +8,7 @@ from typing_extensions import Annotated
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import ApplyAction, CoreModel, NetworkMode, RegistryAuth
 from dstack._internal.core.models.configurations import (
+    DEFAULT_REPO_DIR,
     AnyRunConfiguration,
     RunConfiguration,
 )
@@ -338,7 +339,7 @@ class RunSpec(CoreModel):
         Field(
             description=(
                 "The path to the working directory inside the container."
-                " It's specified relative to the repository directory (`/workflow`) and should be inside it."
+                f" It's specified relative to the repository directory (`{DEFAULT_REPO_DIR}`) and should be inside it."
                 ' Defaults to `"."`.'
             )
         ),
dstack/_internal/proxy/gateway/resources/nginx/00-log-format.conf
@@ -1 +1,11 @@
-log_format dstack_stat '$time_iso8601 $host $status $request_time';
+log_format dstack_stat '$time_iso8601 $host $status $request_time $dstack_replica_hit';
+
+
+# A hack to avoid this Nginx reload error when no services are registered:
+# nginx: [emerg] unknown "dstack_replica_hit" variable
+server {
+    listen unix:/tmp/dstack-dummy-nginx.sock;
+    server_name placeholder.local;
+    deny all;
+    set $dstack_replica_hit 0;
+}
dstack/_internal/proxy/gateway/resources/nginx/service.jinja2
@@ -14,6 +14,7 @@ upstream {{ domain }}.upstream {
 server {
     server_name {{ domain }};
     limit_req_status 429;
+    set $dstack_replica_hit 0;
     access_log {{ access_log_path }} dstack_stat;
     client_max_body_size {{ client_max_body_size }};

@@ -23,11 +24,7 @@ server {
         auth_request /_dstack_auth;
         {% endif %}

-        {% if replicas %}
         try_files /nonexistent @$http_upgrade;
-        {% else %}
-        return 503;
-        {% endif %}

         {% if location.limit_req %}
         limit_req zone={{ location.limit_req.zone }}{% if location.limit_req.burst %} burst={{ location.limit_req.burst }} nodelay{% endif %};
@@ -35,8 +32,9 @@ server {
     }
     {% endfor %}

-    {% if replicas %}
     location @websocket {
+        set $dstack_replica_hit 1;
+        {% if replicas %}
         proxy_pass http://{{ domain }}.upstream;
         proxy_set_header X-Real-IP $remote_addr;
         proxy_set_header Host $host;
@@ -44,19 +42,27 @@ server {
         proxy_set_header Upgrade $http_upgrade;
         proxy_set_header Connection "Upgrade";
         proxy_read_timeout 300s;
+        {% else %}
+        return 503;
+        {% endif %}
     }
     location @ {
+        set $dstack_replica_hit 1;
+        {% if replicas %}
         proxy_pass http://{{ domain }}.upstream;
         proxy_set_header X-Real-IP $remote_addr;
         proxy_set_header Host $host;
         proxy_read_timeout 300s;
+        {% else %}
+        return 503;
+        {% endif %}
     }
-    {% endif %}

     {% if auth %}
     location = /_dstack_auth {
         internal;
         if ($remote_addr = 127.0.0.1) {
+            # for requests from the gateway app, e.g. from the OpenAI-compatible API
             return 200;
         }
         proxy_pass http://localhost:{{ proxy_port }}/api/auth/{{ project_name }};
dstack/_internal/proxy/gateway/services/stats.py
@@ -11,10 +11,10 @@ from pydantic import BaseModel

 from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
 from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, ServiceStats, Stat
+from dstack._internal.proxy.lib.errors import UnexpectedProxyError
 from dstack._internal.utils.common import run_async

 logger = logging.getLogger(__name__)
-IGNORE_STATUSES = {403, 404}
 WINDOWS = (30, 60, 300)
 TTL = WINDOWS[-1]
 EMPTY_STATS = {window: Stat(requests=0, request_time=0.0) for window in WINDOWS}
@@ -35,6 +35,7 @@ class LogEntry(BaseModel):
     host: str
     status: int
     request_time: float
+    is_replica_hit: bool


 class StatsCollector:
@@ -87,7 +88,8 @@ class StatsCollector:
         now = datetime.datetime.now(tz=datetime.timezone.utc)

         for entry in self._read_access_log(now - datetime.timedelta(seconds=TTL)):
-            if entry.status in IGNORE_STATUSES:
+            # only include requests that hit or should hit a service replica
+            if not entry.is_replica_hit:
                 continue

             frame_timestamp = int(entry.timestamp.timestamp())
@@ -119,7 +121,10 @@ class StatsCollector:
             line = self._file.readline()
             if not line:
                 break
-            timestamp_str, host, status, request_time = line.split()
+            cells = line.split()
+            if len(cells) == 4:  # compatibility with pre-0.19.11 logs
+                cells.append("0" if cells[2] in ["403", "404"] else "1")
+            timestamp_str, host, status, request_time, dstack_replica_hit = cells
             timestamp = datetime.datetime.fromisoformat(timestamp_str)
             if timestamp < after:
                 continue
@@ -128,6 +133,7 @@ class StatsCollector:
                 host=host,
                 status=int(status),
                 request_time=float(request_time),
+                is_replica_hit=_parse_nginx_bool(dstack_replica_hit),
             )
             if os.fstat(self._file.fileno()).st_ino != st_ino:
                 # file was rotated
@@ -154,3 +160,11 @@ async def get_service_stats(
         )
         for service in services
     ]
+
+
+def _parse_nginx_bool(v: str) -> bool:
+    if v == "0":
+        return False
+    if v == "1":
+        return True
+    raise UnexpectedProxyError(f"Cannot parse boolean value: expected '0' or '1', got {v!r}")
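
To illustrate the backward-compatible parsing above, here is a standalone sketch (not dstack code) of how a `dstack_stat` line is split, including the pre-0.19.11 four-field fallback that infers the replica-hit flag from the status code:

```python
import datetime


def parse_stat_line(line: str) -> dict:
    """Parse a dstack_stat access log line.

    New format (0.19.11): '<iso8601> <host> <status> <request_time> <replica_hit>'.
    Old format lacks the trailing flag; emulate the server's fallback:
    403/404 never reached a replica, everything else did.
    """
    cells = line.split()
    if len(cells) == 4:  # compatibility with pre-0.19.11 logs
        cells.append("0" if cells[2] in ["403", "404"] else "1")
    timestamp_str, host, status, request_time, replica_hit = cells
    return {
        "timestamp": datetime.datetime.fromisoformat(timestamp_str),
        "host": host,
        "status": int(status),
        "request_time": float(request_time),
        "is_replica_hit": replica_hit == "1",
    }


# New-format line: the request hit a replica.
print(parse_stat_line("2025-05-14T15:24:21+00:00 svc.example.com 200 0.123 1"))
# Old-format 404: treated as a non-replica hit and excluded from stats.
print(parse_stat_line("2025-05-14T15:24:21+00:00 svc.example.com 404 0.001"))
```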
dstack/_internal/server/background/tasks/process_metrics.py
@@ -2,7 +2,7 @@ import asyncio
 import json
 from typing import Dict, List, Optional

-from sqlalchemy import delete, select
+from sqlalchemy import Delete, delete, select
 from sqlalchemy.orm import joinedload

 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
@@ -49,27 +49,29 @@ async def delete_metrics():
     finished_timestamp_micro_cutoff = (
         now_timestamp_micro - settings.SERVER_METRICS_FINISHED_TTL_SECONDS * 1_000_000
     )
+    await asyncio.gather(
+        _execute_delete_statement(
+            delete(JobMetricsPoint).where(
+                JobMetricsPoint.job_id.in_(
+                    select(JobModel.id).where(JobModel.status.in_([JobStatus.RUNNING]))
+                ),
+                JobMetricsPoint.timestamp_micro < running_timestamp_micro_cutoff,
+            )
+        ),
+        _execute_delete_statement(
+            delete(JobMetricsPoint).where(
+                JobMetricsPoint.job_id.in_(
+                    select(JobModel.id).where(JobModel.status.in_(JobStatus.finished_statuses()))
+                ),
+                JobMetricsPoint.timestamp_micro < finished_timestamp_micro_cutoff,
+            )
+        ),
+    )
+
+
+async def _execute_delete_statement(stmt: Delete) -> None:
     async with get_session_ctx() as session:
-        await asyncio.gather(
-            session.execute(
-                delete(JobMetricsPoint).where(
-                    JobMetricsPoint.job_id.in_(
-                        select(JobModel.id).where(JobModel.status.in_([JobStatus.RUNNING]))
-                    ),
-                    JobMetricsPoint.timestamp_micro < running_timestamp_micro_cutoff,
-                )
-            ),
-            session.execute(
-                delete(JobMetricsPoint).where(
-                    JobMetricsPoint.job_id.in_(
-                        select(JobModel.id).where(
-                            JobModel.status.in_(JobStatus.finished_statuses())
-                        )
-                    ),
-                    JobMetricsPoint.timestamp_micro < finished_timestamp_micro_cutoff,
-                )
-            ),
-        )
+        await session.execute(stmt)
         await session.commit()

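The refactor above stops sharing a single session between the two gathered deletes: a SQLAlchemy `AsyncSession` is not safe for concurrent use, so each statement now runs in its own session. A minimal sketch of the pattern, assuming an in-memory `aiosqlite` engine purely for self-containment:

```python
import asyncio

from sqlalchemy import Delete
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# Hypothetical in-memory engine just to make the sketch runnable.
engine = create_async_engine("sqlite+aiosqlite://")
session_factory = async_sessionmaker(engine)


async def execute_delete_statement(stmt: Delete) -> None:
    # One session per concurrent statement, as in _execute_delete_statement.
    async with session_factory() as session:
        await session.execute(stmt)
        await session.commit()

# Usage: await asyncio.gather(execute_delete_statement(a), execute_delete_statement(b))
```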
dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -93,11 +93,20 @@ async def _process_next_submitted_job():
     async with lock:
         res = await session.execute(
             select(JobModel)
+            .join(JobModel.run)
             .where(
                 JobModel.status == JobStatus.SUBMITTED,
                 JobModel.id.not_in(lockset),
             )
-            .order_by(JobModel.last_processed_at.asc())
+            # Jobs are processed in FIFO order sorted by priority globally,
+            # thus runs from different projects can "overtake" each other by using higher priorities.
+            # That's not a big problem as long as projects do not compete for the same compute resources.
+            # Jobs with lower priorities from other projects will be processed without major lag
+            # as long as new higher-priority runs are not constantly submitted.
+            # TODO: Consider processing jobs from different projects fairly/round-robin.
+            # Fully fair processing can be tricky to implement via the current DB queue as
+            # there can be many projects and we are limited by the max DB connections.
+            .order_by(RunModel.priority.desc(), JobModel.last_processed_at.asc())
            .limit(1)
            .with_for_update(skip_locked=True)
         )
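
The effect of the new ordering can be sketched with a plain in-memory sort (hypothetical tuples, not dstack code): higher `priority` wins, and `last_processed_at` keeps FIFO order within a priority level.

```python
# (priority, last_processed_at) pairs standing in for queued jobs
jobs = [(0, 3), (100, 5), (0, 1), (50, 2)]

# Mirrors .order_by(RunModel.priority.desc(), JobModel.last_processed_at.asc())
queue = sorted(jobs, key=lambda job: (-job[0], job[1]))
print(queue)  # [(100, 5), (50, 2), (0, 1), (0, 3)]
```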
@@ -360,16 +369,16 @@ async def _assign_job_to_pool_instance(
         (instance, common_utils.get_or_error(get_instance_offer(instance)))
         for instance in nonshared_instances
     ]
-    if not multinode:
-        shared_instances_with_offers = get_shared_pool_instances_with_offers(
-            pool_instances=pool_instances,
-            profile=profile,
-            requirements=job.job_spec.requirements,
-            idle_only=True,
-            fleet_model=fleet_model,
-            volumes=volumes,
-        )
-        instances_with_offers.extend(shared_instances_with_offers)
+    shared_instances_with_offers = get_shared_pool_instances_with_offers(
+        pool_instances=pool_instances,
+        profile=profile,
+        requirements=job.job_spec.requirements,
+        idle_only=True,
+        fleet_model=fleet_model,
+        multinode=multinode,
+        volumes=volumes,
+    )
+    instances_with_offers.extend(shared_instances_with_offers)

     if len(instances_with_offers) == 0:
         return None
@@ -572,7 +581,7 @@ def _create_instance_model_for_job(


 def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
-    if offer.total_blocks == 1:
+    if offer.blocks == offer.total_blocks:
         if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
             network_mode = NetworkMode.BRIDGE
         else:
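
A standalone sketch of what this condition change means, assuming (as the surrounding untouched code suggests) that the non-bridge branch selects host networking: previously only unsplit instances (`total_blocks == 1`) could get host networking; now any job that occupies all blocks of an instance can.

```python
from enum import Enum


class NetworkMode(str, Enum):
    HOST = "host"
    BRIDGE = "bridge"


def pick_network_mode(blocks: int, total_blocks: int, force_bridge: bool = False) -> NetworkMode:
    # 0.19.9 checked `total_blocks == 1`; 0.19.11 checks `blocks == total_blocks`.
    if blocks == total_blocks and not force_bridge:
        return NetworkMode.HOST
    return NetworkMode.BRIDGE


print(pick_network_mode(blocks=4, total_blocks=4))  # NetworkMode.HOST
print(pick_network_mode(blocks=2, total_blocks=4))  # NetworkMode.BRIDGE
```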
@@ -650,7 +659,7 @@ async def _attach_volumes(
                     backend=backend,
                     volume_model=volume_model,
                     instance=instance,
-                    instance_id=job_provisioning_data.instance_id,
+                    jpd=job_provisioning_data,
                 )
                 job_runtime_data.volume_names.append(volume.name)
                 break  # attach next mount point
@@ -676,7 +685,7 @@ async def _attach_volume(
     backend: Backend,
     volume_model: VolumeModel,
     instance: InstanceModel,
-    instance_id: str,
+    jpd: JobProvisioningData,
 ):
     compute = backend.compute()
     assert isinstance(compute, ComputeWithVolumeSupport)
@@ -688,7 +697,7 @@ async def _attach_volume(
     attachment_data = await common_utils.run_async(
         compute.attach_volume,
         volume=volume,
-        instance_id=instance_id,
+        provisioning_data=jpd,
     )
     volume_attachment_model = VolumeAttachmentModel(
         volume=volume_model,
dstack/_internal/server/migrations/versions/bca2fdf130bf_add_runmodel_priority.py
@@ -0,0 +1,34 @@
+"""Add RunModel.priority
+
+Revision ID: bca2fdf130bf
+Revises: 20166748b60c
+Create Date: 2025-05-14 15:24:21.269775
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "bca2fdf130bf"
+down_revision = "20166748b60c"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("priority", sa.Integer(), nullable=True))
+        batch_op.execute("UPDATE runs SET priority = 0")
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.alter_column("priority", nullable=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("priority")
+
+    # ### end Alembic commands ###
dstack/_internal/server/models.py
@@ -348,6 +348,7 @@ class RunModel(BaseModel):
     resubmission_attempt: Mapped[int] = mapped_column(Integer, default=0)
     run_spec: Mapped[str] = mapped_column(Text)
     service_spec: Mapped[Optional[str]] = mapped_column(Text)
+    priority: Mapped[int] = mapped_column(Integer, default=0)

     jobs: Mapped[List["JobModel"]] = relationship(
         back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"
dstack/_internal/server/routers/repos.py
@@ -14,10 +14,12 @@ from dstack._internal.server.schemas.repos import (
 )
 from dstack._internal.server.security.permissions import ProjectMember
 from dstack._internal.server.services import repos
+from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT
 from dstack._internal.server.utils.routers import (
     get_base_api_additional_responses,
-    request_size_exceeded,
+    get_request_size,
 )
+from dstack._internal.utils.common import sizeof_fmt

 router = APIRouter(
     prefix="/api/project/{project_name}/repos",
@@ -94,10 +96,17 @@ async def upload_code(
     session: AsyncSession = Depends(get_session),
     user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
 ):
-    if request_size_exceeded(request, limit=2 * 2**20):
+    request_size = get_request_size(request)
+    if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT:
+        diff_size_fmt = sizeof_fmt(request_size)
+        limit_fmt = sizeof_fmt(SERVER_CODE_UPLOAD_LIMIT)
+        if diff_size_fmt == limit_fmt:
+            diff_size_fmt = f"{request_size}B"
+            limit_fmt = f"{SERVER_CODE_UPLOAD_LIMIT}B"
         raise ServerClientError(
-            "Repo diff size exceeds the limit of 2MB. "
-            "Use .gitignore to exclude large files from the repo."
+            f"Repo diff size is {diff_size_fmt}, which exceeds the limit of {limit_fmt}."
+            " Use .gitignore to exclude large files from the repo."
+            " This limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT environment variable."
         )
     _, project = user_project
     await repos.upload_code(
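
The `diff_size_fmt == limit_fmt` branch above avoids a confusing message like "2.0MB exceeds the limit of 2.0MB" when both values round to the same human-readable string. A sketch with a stand-in `sizeof_fmt` (the real helper lives in `dstack._internal.utils.common`; this approximation is assumed):

```python
def sizeof_fmt(num: float, suffix: str = "B") -> str:
    # Stand-in for dstack's sizeof_fmt helper (assumed behavior).
    for unit in ("", "K", "M", "G", "T"):
        if abs(num) < 1024.0:
            return f"{num:.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}P{suffix}"


limit = 2 * 2**20           # the 2MB default
request_size = limit + 1    # just over the limit

size_fmt, limit_fmt = sizeof_fmt(request_size), sizeof_fmt(limit)
if size_fmt == limit_fmt:
    # Both round to "2.0MB"; fall back to exact byte counts so the
    # error message is not self-contradictory.
    size_fmt, limit_fmt = f"{request_size}B", f"{limit}B"

print(f"Repo diff size is {size_fmt}, which exceeds the limit of {limit_fmt}.")
```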
dstack/_internal/server/services/fleets.py
@@ -237,7 +237,7 @@ async def get_plan(
 ) -> FleetPlan:
     # Spec must be copied by parsing to calculate merged_profile
     effective_spec = FleetSpec.parse_obj(spec.dict())
-    effective_spec = apply_plugin_policies(
+    effective_spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=effective_spec,
@@ -342,7 +342,7 @@ async def create_fleet(
     spec: FleetSpec,
 ) -> Fleet:
     # Spec must be copied by parsing to calculate merged_profile
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=spec,
dstack/_internal/server/services/gateways/__init__.py
@@ -140,7 +140,7 @@ async def create_gateway(
     project: ProjectModel,
     configuration: GatewayConfiguration,
 ) -> Gateway:
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         # Create pseudo spec until the gateway API is updated to accept spec
dstack/_internal/server/services/instances.py
@@ -235,6 +235,7 @@ def get_shared_pool_instances_with_offers(
     *,
     idle_only: bool = False,
     fleet_model: Optional[FleetModel] = None,
+    multinode: bool = False,
     volumes: Optional[List[List[Volume]]] = None,
 ) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
     instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
@@ -243,19 +244,22 @@ def get_shared_pool_instances_with_offers(
         pool_instances=pool_instances,
         profile=profile,
         fleet_model=fleet_model,
-        multinode=False,
+        multinode=multinode,
         volumes=volumes,
         shared=True,
     )
     for instance in filtered_instances:
         if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
             continue
+        if multinode and instance.busy_blocks > 0:
+            continue
         offer = get_instance_offer(instance)
         if offer is None:
             continue
         total_blocks = common_utils.get_or_error(instance.total_blocks)
         idle_blocks = total_blocks - instance.busy_blocks
-        for blocks in range(1, total_blocks + 1):
+        min_blocks = total_blocks if multinode else 1
+        for blocks in range(min_blocks, total_blocks + 1):
             shared_offer = generate_shared_offer(offer, blocks, total_blocks)
             catalog_item = offer_to_catalog_item(shared_offer)
             if gpuhunt.matches(catalog_item, query_filter):
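
A standalone sketch of the new multinode handling above: multinode jobs may only take a whole, fully idle shared instance, while single-node jobs may take any number of blocks.

```python
def candidate_block_counts(total_blocks: int, busy_blocks: int, multinode: bool) -> list:
    # Mirrors the loop above: multinode requires the whole instance,
    # and only if no blocks are currently busy.
    if multinode and busy_blocks > 0:
        return []
    min_blocks = total_blocks if multinode else 1
    return list(range(min_blocks, total_blocks + 1))


print(candidate_block_counts(4, 0, multinode=False))  # [1, 2, 3, 4]
print(candidate_block_counts(4, 0, multinode=True))   # [4]
print(candidate_block_counts(4, 1, multinode=True))   # []
```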
dstack/_internal/server/services/jobs/__init__.py
@@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=False,
         )
         # For some backends, the volume may be detached immediately
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     else:
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
         logger.info(
@@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=True,
         )
         # Let the next iteration check if force detach worked
dstack/_internal/server/services/jobs/configurators/base.py
@@ -6,10 +6,11 @@ from typing import Dict, List, Optional, Union

 from cachetools import TTLCache, cached

-import dstack.version as version
+from dstack._internal import settings
 from dstack._internal.core.errors import DockerRegistryError, ServerClientError
 from dstack._internal.core.models.common import RegistryAuth
 from dstack._internal.core.models.configurations import (
+    DEFAULT_REPO_DIR,
     PortMapping,
     PythonVersion,
     RunConfigurationType,
@@ -53,14 +54,14 @@ def get_default_image(python_version: str, nvcc: bool = False) -> str:
     suffix = ""
     if nvcc:
         suffix = "-devel"
-    return f"dstackai/base:py{python_version}-{version.base_image}-cuda-12.1{suffix}"
+    return f"{settings.DSTACK_BASE_IMAGE}:py{python_version}-{settings.DSTACK_BASE_IMAGE_VERSION}-cuda-12.1{suffix}"


 class JobConfigurator(ABC):
     TYPE: RunConfigurationType

     _image_config: Optional[ImageConfig] = None
-    # JobSSHKey should be shared for all jobs in a replica for inter-node communitation.
+    # JobSSHKey should be shared for all jobs in a replica for inter-node communication.
     _job_ssh_key: Optional[JobSSHKey] = None

     def __init__(self, run_spec: RunSpec):
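
For illustration, a sketch of the new image-name templating with assumed setting values (`DSTACK_BASE_IMAGE_VERSION` below is a made-up tag; the real defaults live in `dstack._internal.settings`):

```python
DSTACK_BASE_IMAGE = "dstackai/base"  # assumed default
DSTACK_BASE_IMAGE_VERSION = "0.7"    # hypothetical tag for illustration


def get_default_image(python_version: str, nvcc: bool = False) -> str:
    suffix = "-devel" if nvcc else ""
    return f"{DSTACK_BASE_IMAGE}:py{python_version}-{DSTACK_BASE_IMAGE_VERSION}-cuda-12.1{suffix}"


print(get_default_image("3.12"))             # dstackai/base:py3.12-0.7-cuda-12.1
print(get_default_image("3.12", nvcc=True))  # dstackai/base:py3.12-0.7-cuda-12.1-devel
```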
@@ -149,7 +150,8 @@ class JobConfigurator(ABC):
             commands = self.run_spec.configuration.commands
         elif shell_commands := self._shell_commands():
             entrypoint = [self._shell(), "-i", "-c"]
-            commands = [_join_shell_commands(shell_commands)]
+            dstack_image_commands = self._dstack_image_commands()
+            commands = [_join_shell_commands(dstack_image_commands + shell_commands)]
         else:  # custom docker image without commands
             image_config = await self._get_image_config()
             entrypoint = image_config.entrypoint or []
@@ -164,6 +166,18 @@ class JobConfigurator(ABC):

         return result

+    def _dstack_image_commands(self) -> List[str]:
+        if (
+            self.run_spec.configuration.image is not None
+            or self.run_spec.configuration.entrypoint is not None
+        ):
+            return []
+        return [
+            f"uv venv --prompt workflow --seed {DEFAULT_REPO_DIR}/.venv > /dev/null 2>&1",
+            f"echo 'source {DEFAULT_REPO_DIR}/.venv/bin/activate' >> ~/.bashrc",
+            f"source {DEFAULT_REPO_DIR}/.venv/bin/activate",
+        ]
+
     def _app_specs(self) -> List[AppSpec]:
         specs = []
         for i, pm in enumerate(filter_reserved_ports(self._ports())):
dstack/_internal/server/services/jobs/configurators/extensions/cursor.py
@@ -1,5 +1,7 @@
 from typing import List

+from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
+

 class CursorDesktop:
     def __init__(
@@ -37,6 +39,6 @@ class CursorDesktop:
         return [
             "echo To open in Cursor, use link below:",
             "echo ''",
-            f"echo ' cursor://vscode-remote/ssh-remote+{self.run_name}/workflow'",  # TODO use $REPO_DIR
+            f"echo ' cursor://vscode-remote/ssh-remote+{self.run_name}{DEFAULT_REPO_DIR}'",  # TODO use $REPO_DIR
             "echo ''",
         ]
dstack/_internal/server/services/jobs/configurators/extensions/vscode.py
@@ -1,5 +1,7 @@
 from typing import List

+from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
+

 class VSCodeDesktop:
     def __init__(
@@ -37,6 +39,6 @@ class VSCodeDesktop:
         return [
             "echo To open in VS Code Desktop, use link below:",
             "echo ''",
-            f"echo ' vscode://vscode-remote/ssh-remote+{self.run_name}/workflow'",  # TODO use $REPO_DIR
+            f"echo ' vscode://vscode-remote/ssh-remote+{self.run_name}{DEFAULT_REPO_DIR}'",  # TODO use $REPO_DIR
             "echo ''",
         ]