dstack 0.19.17__py3-none-any.whl → 0.19.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (86)
  1. dstack/_internal/cli/services/configurators/fleet.py +111 -1
  2. dstack/_internal/cli/services/profile.py +1 -1
  3. dstack/_internal/core/backends/aws/compute.py +237 -18
  4. dstack/_internal/core/backends/base/compute.py +20 -2
  5. dstack/_internal/core/backends/cudo/compute.py +23 -9
  6. dstack/_internal/core/backends/gcp/compute.py +13 -7
  7. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  8. dstack/_internal/core/compatibility/fleets.py +12 -11
  9. dstack/_internal/core/compatibility/gateways.py +9 -8
  10. dstack/_internal/core/compatibility/logs.py +4 -3
  11. dstack/_internal/core/compatibility/runs.py +29 -21
  12. dstack/_internal/core/compatibility/volumes.py +11 -8
  13. dstack/_internal/core/errors.py +4 -0
  14. dstack/_internal/core/models/common.py +45 -2
  15. dstack/_internal/core/models/configurations.py +9 -1
  16. dstack/_internal/core/models/fleets.py +2 -1
  17. dstack/_internal/core/models/profiles.py +8 -5
  18. dstack/_internal/core/models/resources.py +15 -8
  19. dstack/_internal/core/models/runs.py +41 -138
  20. dstack/_internal/core/models/volumes.py +14 -0
  21. dstack/_internal/core/services/diff.py +56 -3
  22. dstack/_internal/core/services/ssh/attach.py +2 -0
  23. dstack/_internal/server/app.py +37 -9
  24. dstack/_internal/server/background/__init__.py +66 -40
  25. dstack/_internal/server/background/tasks/process_fleets.py +19 -3
  26. dstack/_internal/server/background/tasks/process_gateways.py +47 -29
  27. dstack/_internal/server/background/tasks/process_idle_volumes.py +139 -0
  28. dstack/_internal/server/background/tasks/process_instances.py +13 -2
  29. dstack/_internal/server/background/tasks/process_placement_groups.py +4 -2
  30. dstack/_internal/server/background/tasks/process_running_jobs.py +14 -3
  31. dstack/_internal/server/background/tasks/process_runs.py +8 -4
  32. dstack/_internal/server/background/tasks/process_submitted_jobs.py +38 -7
  33. dstack/_internal/server/background/tasks/process_terminating_jobs.py +5 -3
  34. dstack/_internal/server/background/tasks/process_volumes.py +2 -2
  35. dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +6 -6
  36. dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py +40 -0
  37. dstack/_internal/server/models.py +1 -0
  38. dstack/_internal/server/routers/backends.py +23 -16
  39. dstack/_internal/server/routers/files.py +7 -6
  40. dstack/_internal/server/routers/fleets.py +47 -36
  41. dstack/_internal/server/routers/gateways.py +27 -18
  42. dstack/_internal/server/routers/instances.py +18 -13
  43. dstack/_internal/server/routers/logs.py +7 -3
  44. dstack/_internal/server/routers/metrics.py +14 -8
  45. dstack/_internal/server/routers/projects.py +33 -22
  46. dstack/_internal/server/routers/repos.py +7 -6
  47. dstack/_internal/server/routers/runs.py +49 -28
  48. dstack/_internal/server/routers/secrets.py +20 -15
  49. dstack/_internal/server/routers/server.py +7 -4
  50. dstack/_internal/server/routers/users.py +22 -19
  51. dstack/_internal/server/routers/volumes.py +34 -25
  52. dstack/_internal/server/schemas/logs.py +2 -2
  53. dstack/_internal/server/schemas/runs.py +17 -5
  54. dstack/_internal/server/services/fleets.py +358 -75
  55. dstack/_internal/server/services/gateways/__init__.py +17 -6
  56. dstack/_internal/server/services/gateways/client.py +5 -3
  57. dstack/_internal/server/services/instances.py +8 -0
  58. dstack/_internal/server/services/jobs/__init__.py +45 -0
  59. dstack/_internal/server/services/jobs/configurators/base.py +12 -1
  60. dstack/_internal/server/services/locking.py +104 -13
  61. dstack/_internal/server/services/logging.py +4 -2
  62. dstack/_internal/server/services/logs/__init__.py +15 -2
  63. dstack/_internal/server/services/logs/aws.py +2 -4
  64. dstack/_internal/server/services/logs/filelog.py +33 -27
  65. dstack/_internal/server/services/logs/gcp.py +3 -5
  66. dstack/_internal/server/services/proxy/repo.py +4 -1
  67. dstack/_internal/server/services/runs.py +139 -72
  68. dstack/_internal/server/services/services/__init__.py +2 -1
  69. dstack/_internal/server/services/users.py +3 -1
  70. dstack/_internal/server/services/volumes.py +15 -2
  71. dstack/_internal/server/settings.py +25 -6
  72. dstack/_internal/server/statics/index.html +1 -1
  73. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js → main-64f8273740c4b52c18f5.js} +71 -67
  74. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js.map → main-64f8273740c4b52c18f5.js.map} +1 -1
  75. dstack/_internal/server/statics/{main-d48635d8fe670d53961c.css → main-d58fc0460cb0eae7cb5c.css} +1 -1
  76. dstack/_internal/server/testing/common.py +48 -8
  77. dstack/_internal/server/utils/routers.py +31 -8
  78. dstack/_internal/utils/json_utils.py +54 -0
  79. dstack/api/_public/runs.py +13 -2
  80. dstack/api/server/_runs.py +12 -2
  81. dstack/version.py +1 -1
  82. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/METADATA +17 -14
  83. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/RECORD +86 -83
  84. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/WHEEL +0 -0
  85. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/entry_points.txt +0 -0
  86. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/licenses/LICENSE.md +0 -0
@@ -9,12 +9,24 @@ from dstack._internal.core.models.runs import ApplyRunPlanInput, RunSpec
9
9
 
10
10
 
11
11
  class ListRunsRequest(CoreModel):
12
- project_name: Optional[str]
13
- repo_id: Optional[str]
14
- username: Optional[str]
12
+ project_name: Optional[str] = None
13
+ repo_id: Optional[str] = None
14
+ username: Optional[str] = None
15
15
  only_active: bool = False
16
- prev_submitted_at: Optional[datetime]
17
- prev_run_id: Optional[UUID]
16
+ include_jobs: bool = Field(
17
+ True,
18
+ description=("Whether to include `jobs` in the response"),
19
+ )
20
+ job_submissions_limit: Optional[int] = Field(
21
+ None,
22
+ ge=0,
23
+ description=(
24
+ "Limit number of job submissions returned per job to avoid large responses. "
25
+ "Drops older job submissions. No effect with `include_jobs: false`"
26
+ ),
27
+ )
28
+ prev_submitted_at: Optional[datetime] = None
29
+ prev_run_id: Optional[UUID] = None
18
30
  limit: int = Field(100, ge=0, le=100)
19
31
  ascending: bool = False
20
32
 
@@ -1,6 +1,8 @@
1
1
  import uuid
2
+ from collections.abc import Callable
2
3
  from datetime import datetime, timezone
3
- from typing import List, Literal, Optional, Tuple, Union, cast
4
+ from functools import wraps
5
+ from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast
4
6
 
5
7
  from sqlalchemy import and_, func, or_, select
6
8
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,10 +15,12 @@ from dstack._internal.core.errors import (
13
15
  ResourceExistsError,
14
16
  ServerClientError,
15
17
  )
18
+ from dstack._internal.core.models.common import ApplyAction, CoreModel
16
19
  from dstack._internal.core.models.envs import Env
17
20
  from dstack._internal.core.models.fleets import (
18
21
  ApplyFleetPlanInput,
19
22
  Fleet,
23
+ FleetConfiguration,
20
24
  FleetPlan,
21
25
  FleetSpec,
22
26
  FleetStatus,
@@ -40,6 +44,7 @@ from dstack._internal.core.models.resources import ResourcesSpec
40
44
  from dstack._internal.core.models.runs import Requirements, get_policy_map
41
45
  from dstack._internal.core.models.users import GlobalRole
42
46
  from dstack._internal.core.services import validate_dstack_resource_name
47
+ from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models
43
48
  from dstack._internal.server.db import get_db
44
49
  from dstack._internal.server.models import (
45
50
  FleetModel,
@@ -49,7 +54,10 @@ from dstack._internal.server.models import (
49
54
  )
50
55
  from dstack._internal.server.services import instances as instances_services
51
56
  from dstack._internal.server.services import offers as offers_services
52
- from dstack._internal.server.services.instances import list_active_remote_instances
57
+ from dstack._internal.server.services.instances import (
58
+ get_instance_remote_connection_info,
59
+ list_active_remote_instances,
60
+ )
53
61
  from dstack._internal.server.services.locking import (
54
62
  get_locker,
55
63
  string_to_lock_id,
@@ -178,8 +186,9 @@ async def list_project_fleet_models(
178
186
  async def get_fleet(
179
187
  session: AsyncSession,
180
188
  project: ProjectModel,
181
- name: Optional[str],
182
- fleet_id: Optional[uuid.UUID],
189
+ name: Optional[str] = None,
190
+ fleet_id: Optional[uuid.UUID] = None,
191
+ include_sensitive: bool = False,
183
192
  ) -> Optional[Fleet]:
184
193
  if fleet_id is not None:
185
194
  fleet_model = await get_project_fleet_model_by_id(
@@ -193,7 +202,7 @@ async def get_fleet(
193
202
  raise ServerClientError("name or id must be specified")
194
203
  if fleet_model is None:
195
204
  return None
196
- return fleet_model_to_fleet(fleet_model)
205
+ return fleet_model_to_fleet(fleet_model, include_sensitive=include_sensitive)
197
206
 
198
207
 
199
208
  async def get_project_fleet_model_by_id(
@@ -236,23 +245,32 @@ async def get_plan(
236
245
  spec: FleetSpec,
237
246
  ) -> FleetPlan:
238
247
  # Spec must be copied by parsing to calculate merged_profile
239
- effective_spec = FleetSpec.parse_obj(spec.dict())
248
+ effective_spec = copy_model(spec)
240
249
  effective_spec = await apply_plugin_policies(
241
250
  user=user.name,
242
251
  project=project.name,
243
252
  spec=effective_spec,
244
253
  )
245
- effective_spec = FleetSpec.parse_obj(effective_spec.dict())
246
- _validate_fleet_spec_and_set_defaults(spec)
254
+ # Spec must be copied by parsing to calculate merged_profile
255
+ effective_spec = copy_model(effective_spec)
256
+ _validate_fleet_spec_and_set_defaults(effective_spec)
257
+
258
+ action = ApplyAction.CREATE
247
259
  current_fleet: Optional[Fleet] = None
248
260
  current_fleet_id: Optional[uuid.UUID] = None
261
+
249
262
  if effective_spec.configuration.name is not None:
250
- current_fleet_model = await get_project_fleet_model_by_name(
251
- session=session, project=project, name=effective_spec.configuration.name
263
+ current_fleet = await get_fleet(
264
+ session=session,
265
+ project=project,
266
+ name=effective_spec.configuration.name,
267
+ include_sensitive=True,
252
268
  )
253
- if current_fleet_model is not None:
254
- current_fleet = fleet_model_to_fleet(current_fleet_model)
255
- current_fleet_id = current_fleet_model.id
269
+ if current_fleet is not None:
270
+ _set_fleet_spec_defaults(current_fleet.spec)
271
+ if _can_update_fleet_spec(current_fleet.spec, effective_spec):
272
+ action = ApplyAction.UPDATE
273
+ current_fleet_id = current_fleet.id
256
274
  await _check_ssh_hosts_not_yet_added(session, effective_spec, current_fleet_id)
257
275
 
258
276
  offers = []
@@ -265,7 +283,10 @@ async def get_plan(
265
283
  blocks=effective_spec.configuration.blocks,
266
284
  )
267
285
  offers = [offer for _, offer in offers_with_backends]
286
+
268
287
  _remove_fleet_spec_sensitive_info(effective_spec)
288
+ if current_fleet is not None:
289
+ _remove_fleet_spec_sensitive_info(current_fleet.spec)
269
290
  plan = FleetPlan(
270
291
  project_name=project.name,
271
292
  user=user.name,
@@ -275,6 +296,7 @@ async def get_plan(
275
296
  offers=offers[:50],
276
297
  total_offers=len(offers),
277
298
  max_offer_price=max((offer.price for offer in offers), default=None),
299
+ action=action,
278
300
  )
279
301
  return plan
280
302
 
@@ -327,11 +349,77 @@ async def apply_plan(
327
349
  plan: ApplyFleetPlanInput,
328
350
  force: bool,
329
351
  ) -> Fleet:
330
- return await create_fleet(
352
+ spec = await apply_plugin_policies(
353
+ user=user.name,
354
+ project=project.name,
355
+ spec=plan.spec,
356
+ )
357
+ # Spec must be copied by parsing to calculate merged_profile
358
+ spec = copy_model(spec)
359
+ _validate_fleet_spec_and_set_defaults(spec)
360
+
361
+ if spec.configuration.ssh_config is not None:
362
+ _check_can_manage_ssh_fleets(user=user, project=project)
363
+
364
+ configuration = spec.configuration
365
+ if configuration.name is None:
366
+ return await _create_fleet(
367
+ session=session,
368
+ project=project,
369
+ user=user,
370
+ spec=spec,
371
+ )
372
+
373
+ fleet_model = await get_project_fleet_model_by_name(
374
+ session=session,
375
+ project=project,
376
+ name=configuration.name,
377
+ )
378
+ if fleet_model is None:
379
+ return await _create_fleet(
380
+ session=session,
381
+ project=project,
382
+ user=user,
383
+ spec=spec,
384
+ )
385
+
386
+ instances_ids = sorted(i.id for i in fleet_model.instances if not i.deleted)
387
+ await session.commit()
388
+ async with (
389
+ get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, [fleet_model.id]),
390
+ get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids),
391
+ ):
392
+ # Refetch after lock
393
+ # TODO: Lock instances with FOR UPDATE?
394
+ res = await session.execute(
395
+ select(FleetModel)
396
+ .where(
397
+ FleetModel.project_id == project.id,
398
+ FleetModel.id == fleet_model.id,
399
+ FleetModel.deleted == False,
400
+ )
401
+ .options(selectinload(FleetModel.instances))
402
+ .options(selectinload(FleetModel.runs))
403
+ .execution_options(populate_existing=True)
404
+ .order_by(FleetModel.id) # take locks in order
405
+ .with_for_update(key_share=True)
406
+ )
407
+ fleet_model = res.scalars().unique().one_or_none()
408
+ if fleet_model is not None:
409
+ return await _update_fleet(
410
+ session=session,
411
+ project=project,
412
+ spec=spec,
413
+ current_resource=plan.current_resource,
414
+ force=force,
415
+ fleet_model=fleet_model,
416
+ )
417
+
418
+ return await _create_fleet(
331
419
  session=session,
332
420
  project=project,
333
421
  user=user,
334
- spec=plan.spec,
422
+ spec=spec,
335
423
  )
336
424
 
337
425
 
@@ -341,73 +429,19 @@ async def create_fleet(
341
429
  user: UserModel,
342
430
  spec: FleetSpec,
343
431
  ) -> Fleet:
344
- # Spec must be copied by parsing to calculate merged_profile
345
432
  spec = await apply_plugin_policies(
346
433
  user=user.name,
347
434
  project=project.name,
348
435
  spec=spec,
349
436
  )
350
- spec = FleetSpec.parse_obj(spec.dict())
437
+ # Spec must be copied by parsing to calculate merged_profile
438
+ spec = copy_model(spec)
351
439
  _validate_fleet_spec_and_set_defaults(spec)
352
440
 
353
441
  if spec.configuration.ssh_config is not None:
354
442
  _check_can_manage_ssh_fleets(user=user, project=project)
355
443
 
356
- lock_namespace = f"fleet_names_{project.name}"
357
- if get_db().dialect_name == "sqlite":
358
- # Start new transaction to see committed changes after lock
359
- await session.commit()
360
- elif get_db().dialect_name == "postgresql":
361
- await session.execute(
362
- select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
363
- )
364
-
365
- lock, _ = get_locker().get_lockset(lock_namespace)
366
- async with lock:
367
- if spec.configuration.name is not None:
368
- fleet_model = await get_project_fleet_model_by_name(
369
- session=session,
370
- project=project,
371
- name=spec.configuration.name,
372
- )
373
- if fleet_model is not None:
374
- raise ResourceExistsError()
375
- else:
376
- spec.configuration.name = await generate_fleet_name(session=session, project=project)
377
-
378
- fleet_model = FleetModel(
379
- id=uuid.uuid4(),
380
- name=spec.configuration.name,
381
- project=project,
382
- status=FleetStatus.ACTIVE,
383
- spec=spec.json(),
384
- instances=[],
385
- )
386
- session.add(fleet_model)
387
- if spec.configuration.ssh_config is not None:
388
- for i, host in enumerate(spec.configuration.ssh_config.hosts):
389
- instances_model = await create_fleet_ssh_instance_model(
390
- project=project,
391
- spec=spec,
392
- ssh_params=spec.configuration.ssh_config,
393
- env=spec.configuration.env,
394
- instance_num=i,
395
- host=host,
396
- )
397
- fleet_model.instances.append(instances_model)
398
- else:
399
- for i in range(_get_fleet_nodes_to_provision(spec)):
400
- instance_model = await create_fleet_instance_model(
401
- session=session,
402
- project=project,
403
- user=user,
404
- spec=spec,
405
- reservation=spec.configuration.reservation,
406
- instance_num=i,
407
- )
408
- fleet_model.instances.append(instance_model)
409
- await session.commit()
410
- return fleet_model_to_fleet(fleet_model)
444
+ return await _create_fleet(session=session, project=project, user=user, spec=spec)
411
445
 
412
446
 
413
447
  async def create_fleet_instance_model(
@@ -516,11 +550,12 @@ async def delete_fleets(
516
550
  await session.commit()
517
551
  logger.info("Deleting fleets: %s", [v.name for v in fleet_models])
518
552
  async with (
519
- get_locker().lock_ctx(FleetModel.__tablename__, fleets_ids),
520
- get_locker().lock_ctx(InstanceModel.__tablename__, instances_ids),
553
+ get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids),
554
+ get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids),
521
555
  ):
522
556
  # Refetch after lock
523
- # TODO lock instances with FOR UPDATE?
557
+ # TODO: Lock instances with FOR UPDATE?
558
+ # TODO: Do not lock fleet when deleting only instances
524
559
  res = await session.execute(
525
560
  select(FleetModel)
526
561
  .where(
@@ -599,6 +634,235 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool:
599
634
  return len(active_instances) == 0
600
635
 
601
636
 
637
+ async def _create_fleet(
638
+ session: AsyncSession,
639
+ project: ProjectModel,
640
+ user: UserModel,
641
+ spec: FleetSpec,
642
+ ) -> Fleet:
643
+ lock_namespace = f"fleet_names_{project.name}"
644
+ if get_db().dialect_name == "sqlite":
645
+ # Start new transaction to see committed changes after lock
646
+ await session.commit()
647
+ elif get_db().dialect_name == "postgresql":
648
+ await session.execute(
649
+ select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
650
+ )
651
+
652
+ lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
653
+ async with lock:
654
+ if spec.configuration.name is not None:
655
+ fleet_model = await get_project_fleet_model_by_name(
656
+ session=session,
657
+ project=project,
658
+ name=spec.configuration.name,
659
+ )
660
+ if fleet_model is not None:
661
+ raise ResourceExistsError()
662
+ else:
663
+ spec.configuration.name = await generate_fleet_name(session=session, project=project)
664
+
665
+ fleet_model = FleetModel(
666
+ id=uuid.uuid4(),
667
+ name=spec.configuration.name,
668
+ project=project,
669
+ status=FleetStatus.ACTIVE,
670
+ spec=spec.json(),
671
+ instances=[],
672
+ )
673
+ session.add(fleet_model)
674
+ if spec.configuration.ssh_config is not None:
675
+ for i, host in enumerate(spec.configuration.ssh_config.hosts):
676
+ instances_model = await create_fleet_ssh_instance_model(
677
+ project=project,
678
+ spec=spec,
679
+ ssh_params=spec.configuration.ssh_config,
680
+ env=spec.configuration.env,
681
+ instance_num=i,
682
+ host=host,
683
+ )
684
+ fleet_model.instances.append(instances_model)
685
+ else:
686
+ for i in range(_get_fleet_nodes_to_provision(spec)):
687
+ instance_model = await create_fleet_instance_model(
688
+ session=session,
689
+ project=project,
690
+ user=user,
691
+ spec=spec,
692
+ reservation=spec.configuration.reservation,
693
+ instance_num=i,
694
+ )
695
+ fleet_model.instances.append(instance_model)
696
+ await session.commit()
697
+ return fleet_model_to_fleet(fleet_model)
698
+
699
+
700
+ async def _update_fleet(
701
+ session: AsyncSession,
702
+ project: ProjectModel,
703
+ spec: FleetSpec,
704
+ current_resource: Optional[Fleet],
705
+ force: bool,
706
+ fleet_model: FleetModel,
707
+ ) -> Fleet:
708
+ fleet = fleet_model_to_fleet(fleet_model)
709
+ _set_fleet_spec_defaults(fleet.spec)
710
+ fleet_sensitive = fleet_model_to_fleet(fleet_model, include_sensitive=True)
711
+ _set_fleet_spec_defaults(fleet_sensitive.spec)
712
+
713
+ if not force:
714
+ if current_resource is not None:
715
+ _set_fleet_spec_defaults(current_resource.spec)
716
+ if (
717
+ current_resource is None
718
+ or current_resource.id != fleet.id
719
+ or current_resource.spec != fleet.spec
720
+ ):
721
+ raise ServerClientError(
722
+ "Failed to apply plan. Resource has been changed. Try again or use force apply."
723
+ )
724
+
725
+ _check_can_update_fleet_spec(fleet_sensitive.spec, spec)
726
+
727
+ spec_json = spec.json()
728
+ fleet_model.spec = spec_json
729
+
730
+ if (
731
+ fleet_sensitive.spec.configuration.ssh_config is not None
732
+ and spec.configuration.ssh_config is not None
733
+ ):
734
+ added_hosts, removed_hosts, changed_hosts = _calculate_ssh_hosts_changes(
735
+ current=fleet_sensitive.spec.configuration.ssh_config.hosts,
736
+ new=spec.configuration.ssh_config.hosts,
737
+ )
738
+ # `_check_can_update_fleet_spec` ensures hosts are not changed
739
+ assert not changed_hosts, changed_hosts
740
+ active_instance_nums: set[int] = set()
741
+ removed_instance_nums: list[int] = []
742
+ if removed_hosts or added_hosts:
743
+ for instance_model in fleet_model.instances:
744
+ if instance_model.deleted:
745
+ continue
746
+ active_instance_nums.add(instance_model.instance_num)
747
+ rci = get_instance_remote_connection_info(instance_model)
748
+ if rci is None:
749
+ logger.error(
750
+ "Cloud instance %s in SSH fleet %s",
751
+ instance_model.id,
752
+ fleet_model.id,
753
+ )
754
+ continue
755
+ if rci.host in removed_hosts:
756
+ removed_instance_nums.append(instance_model.instance_num)
757
+ if added_hosts:
758
+ await _check_ssh_hosts_not_yet_added(session, spec, fleet.id)
759
+ for host in added_hosts.values():
760
+ instance_num = _get_next_instance_num(active_instance_nums)
761
+ instance_model = await create_fleet_ssh_instance_model(
762
+ project=project,
763
+ spec=spec,
764
+ ssh_params=spec.configuration.ssh_config,
765
+ env=spec.configuration.env,
766
+ instance_num=instance_num,
767
+ host=host,
768
+ )
769
+ fleet_model.instances.append(instance_model)
770
+ active_instance_nums.add(instance_num)
771
+ if removed_instance_nums:
772
+ _terminate_fleet_instances(fleet_model, removed_instance_nums)
773
+
774
+ await session.commit()
775
+ return fleet_model_to_fleet(fleet_model)
776
+
777
+
778
+ def _can_update_fleet_spec(current_fleet_spec: FleetSpec, new_fleet_spec: FleetSpec) -> bool:
779
+ try:
780
+ _check_can_update_fleet_spec(current_fleet_spec, new_fleet_spec)
781
+ except ServerClientError as e:
782
+ logger.debug("Run cannot be updated: %s", repr(e))
783
+ return False
784
+ return True
785
+
786
+
787
+ M = TypeVar("M", bound=CoreModel)
788
+
789
+
790
+ def _check_can_update(*updatable_fields: str):
791
+ def decorator(fn: Callable[[M, M, ModelDiff], None]) -> Callable[[M, M], None]:
792
+ @wraps(fn)
793
+ def inner(current: M, new: M):
794
+ diff = _check_can_update_inner(current, new, updatable_fields)
795
+ fn(current, new, diff)
796
+
797
+ return inner
798
+
799
+ return decorator
800
+
801
+
802
+ def _check_can_update_inner(current: M, new: M, updatable_fields: tuple[str, ...]) -> ModelDiff:
803
+ diff = diff_models(current, new)
804
+ changed_fields = diff.keys()
805
+ if not (changed_fields <= set(updatable_fields)):
806
+ raise ServerClientError(
807
+ f"Failed to update fields {list(changed_fields)}."
808
+ f" Can only update {list(updatable_fields)}."
809
+ )
810
+ return diff
811
+
812
+
813
+ @_check_can_update("configuration", "configuration_path")
814
+ def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: ModelDiff):
815
+ if "configuration" in diff:
816
+ _check_can_update_fleet_configuration(current.configuration, new.configuration)
817
+
818
+
819
+ @_check_can_update("ssh_config")
820
+ def _check_can_update_fleet_configuration(
821
+ current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff
822
+ ):
823
+ if "ssh_config" in diff:
824
+ current_ssh_config = current.ssh_config
825
+ new_ssh_config = new.ssh_config
826
+ if current_ssh_config is None:
827
+ if new_ssh_config is not None:
828
+ raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update")
829
+ elif new_ssh_config is None:
830
+ raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update")
831
+ else:
832
+ _check_can_update_ssh_config(current_ssh_config, new_ssh_config)
833
+
834
+
835
+ @_check_can_update("hosts")
836
+ def _check_can_update_ssh_config(current: SSHParams, new: SSHParams, diff: ModelDiff):
837
+ if "hosts" in diff:
838
+ _, _, changed_hosts = _calculate_ssh_hosts_changes(current.hosts, new.hosts)
839
+ if changed_hosts:
840
+ raise ServerClientError(
841
+ f"Hosts configuration changed, cannot update: {list(changed_hosts)}"
842
+ )
843
+
844
+
845
+ def _calculate_ssh_hosts_changes(
846
+ current: list[Union[SSHHostParams, str]], new: list[Union[SSHHostParams, str]]
847
+ ) -> tuple[dict[str, Union[SSHHostParams, str]], set[str], set[str]]:
848
+ current_hosts = {h if isinstance(h, str) else h.hostname: h for h in current}
849
+ new_hosts = {h if isinstance(h, str) else h.hostname: h for h in new}
850
+ added_hosts = {h: new_hosts[h] for h in new_hosts.keys() - current_hosts}
851
+ removed_hosts = current_hosts.keys() - new_hosts
852
+ changed_hosts: set[str] = set()
853
+ for host in current_hosts.keys() & new_hosts:
854
+ current_host = current_hosts[host]
855
+ new_host = new_hosts[host]
856
+ if isinstance(current_host, str) or isinstance(new_host, str):
857
+ if current_host != new_host:
858
+ changed_hosts.add(host)
859
+ elif diff_models(
860
+ current_host, new_host, reset={"identity_file": True, "proxy_jump": {"identity_file"}}
861
+ ):
862
+ changed_hosts.add(host)
863
+ return added_hosts, removed_hosts, changed_hosts
864
+
865
+
602
866
  def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel):
603
867
  if user.global_role == GlobalRole.ADMIN:
604
868
  return
@@ -653,6 +917,8 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec):
653
917
  validate_dstack_resource_name(spec.configuration.name)
654
918
  if spec.configuration.ssh_config is None and spec.configuration.nodes is None:
655
919
  raise ServerClientError("No ssh_config or nodes specified")
920
+ if spec.configuration.ssh_config is not None and spec.configuration.nodes is not None:
921
+ raise ServerClientError("ssh_config and nodes are mutually exclusive")
656
922
  if spec.configuration.ssh_config is not None:
657
923
  _validate_all_ssh_params_specified(spec.configuration.ssh_config)
658
924
  if spec.configuration.ssh_config.ssh_key is not None:
@@ -661,6 +927,10 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec):
661
927
  if isinstance(host, SSHHostParams) and host.ssh_key is not None:
662
928
  _validate_ssh_key(host.ssh_key)
663
929
  _validate_internal_ips(spec.configuration.ssh_config)
930
+ _set_fleet_spec_defaults(spec)
931
+
932
+
933
+ def _set_fleet_spec_defaults(spec: FleetSpec):
664
934
  if spec.configuration.resources is not None:
665
935
  set_resources_defaults(spec.configuration.resources)
666
936
 
@@ -733,3 +1003,16 @@ def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
733
1003
  reservation=fleet_spec.configuration.reservation,
734
1004
  )
735
1005
  return requirements
1006
+
1007
+
1008
+ def _get_next_instance_num(instance_nums: set[int]) -> int:
1009
+ if not instance_nums:
1010
+ return 0
1011
+ min_instance_num = min(instance_nums)
1012
+ if min_instance_num > 0:
1013
+ return 0
1014
+ instance_num = min_instance_num + 1
1015
+ while True:
1016
+ if instance_num not in instance_nums:
1017
+ return instance_num
1018
+ instance_num += 1
@@ -2,6 +2,7 @@ import asyncio
2
2
  import datetime
3
3
  import uuid
4
4
  from datetime import timedelta, timezone
5
+ from functools import partial
5
6
  from typing import List, Optional, Sequence
6
7
 
7
8
  import httpx
@@ -162,7 +163,7 @@ async def create_gateway(
162
163
  select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
163
164
  )
164
165
 
165
- lock, _ = get_locker().get_lockset(lock_namespace)
166
+ lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
166
167
  async with lock:
167
168
  if configuration.name is None:
168
169
  configuration.name = await generate_gateway_name(session=session, project=project)
@@ -186,6 +187,7 @@ async def create_gateway(
186
187
  return gateway_model_to_gateway(gateway)
187
188
 
188
189
 
190
+ # NOTE: dstack Sky imports and uses this function
189
191
  async def connect_to_gateway_with_retry(
190
192
  gateway_compute: GatewayComputeModel,
191
193
  ) -> Optional[GatewayConnection]:
@@ -229,7 +231,9 @@ async def delete_gateways(
229
231
  gateways_ids = sorted([g.id for g in gateway_models])
230
232
  await session.commit()
231
233
  logger.info("Deleting gateways: %s", [g.name for g in gateway_models])
232
- async with get_locker().lock_ctx(GatewayModel.__tablename__, gateways_ids):
234
+ async with get_locker(get_db().dialect_name).lock_ctx(
235
+ GatewayModel.__tablename__, gateways_ids
236
+ ):
233
237
  # Refetch after lock
234
238
  res = await session.execute(
235
239
  select(GatewayModel)
@@ -378,6 +382,8 @@ async def get_or_add_gateway_connection(
378
382
  async def init_gateways(session: AsyncSession):
379
383
  res = await session.execute(
380
384
  select(GatewayComputeModel).where(
385
+ # FIXME: should not include computes related to gateways in the `provisioning` status.
386
+ # Causes warnings and delays when restarting the server during gateway provisioning.
381
387
  GatewayComputeModel.active == True,
382
388
  GatewayComputeModel.deleted == False,
383
389
  )
@@ -419,7 +425,8 @@ async def init_gateways(session: AsyncSession):
419
425
 
420
426
  for gateway_compute, error in await gather_map_async(
421
427
  await gateway_connections_pool.all(),
422
- configure_gateway,
428
+ # Need several attempts to handle short gateway downtime after update
429
+ partial(configure_gateway, attempts=7),
423
430
  return_exceptions=True,
424
431
  ):
425
432
  if isinstance(error, Exception):
@@ -459,7 +466,11 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool:
459
466
  ) > get_current_datetime() - timedelta(seconds=60)
460
467
 
461
468
 
462
- async def configure_gateway(connection: GatewayConnection) -> None:
469
+ # NOTE: dstack Sky imports and uses this function
470
+ async def configure_gateway(
471
+ connection: GatewayConnection,
472
+ attempts: int = GATEWAY_CONFIGURE_ATTEMPTS,
473
+ ) -> None:
463
474
  """
464
475
  Try submitting gateway config several times in case gateway's HTTP server is not
465
476
  running yet
@@ -467,7 +478,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
467
478
 
468
479
  logger.debug("Configuring gateway %s", connection.ip_address)
469
480
 
470
- for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1):
481
+ for attempt in range(attempts - 1):
471
482
  try:
472
483
  async with connection.client() as client:
473
484
  await client.submit_gateway_config()
@@ -476,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
476
487
  logger.debug(
477
488
  "Failed attempt %s/%s at configuring gateway %s: %r",
478
489
  attempt + 1,
479
- GATEWAY_CONFIGURE_ATTEMPTS,
490
+ attempts,
480
491
  connection.ip_address,
481
492
  e,
482
493
  )