dstack 0.18.40rc1__py3-none-any.whl → 0.18.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +21 -9
  11. dstack/_internal/core/backends/aws/compute.py +92 -52
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +30 -23
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +20 -0
  28. dstack/_internal/core/models/volumes.py +3 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +72 -33
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
  42. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  43. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  44. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  45. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  46. dstack/_internal/server/models.py +10 -4
  47. dstack/_internal/server/routers/runs.py +1 -0
  48. dstack/_internal/server/schemas/runner.py +1 -0
  49. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  50. dstack/_internal/server/services/config.py +9 -0
  51. dstack/_internal/server/services/fleets.py +27 -2
  52. dstack/_internal/server/services/gateways/client.py +9 -1
  53. dstack/_internal/server/services/jobs/__init__.py +215 -43
  54. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  55. dstack/_internal/server/services/offers.py +91 -5
  56. dstack/_internal/server/services/pools.py +95 -11
  57. dstack/_internal/server/services/proxy/repo.py +17 -3
  58. dstack/_internal/server/services/runner/client.py +1 -1
  59. dstack/_internal/server/services/runner/ssh.py +33 -5
  60. dstack/_internal/server/services/runs.py +48 -179
  61. dstack/_internal/server/services/services/__init__.py +9 -1
  62. dstack/_internal/server/statics/index.html +1 -1
  63. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  64. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  65. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  66. dstack/_internal/server/testing/common.py +117 -52
  67. dstack/_internal/utils/common.py +22 -8
  68. dstack/_internal/utils/env.py +14 -0
  69. dstack/_internal/utils/ssh.py +1 -1
  70. dstack/api/server/_fleets.py +25 -1
  71. dstack/api/server/_runs.py +23 -2
  72. dstack/api/server/_volumes.py +12 -1
  73. dstack/version.py +1 -1
  74. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
  75. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
  76. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  77. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  78. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  79. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  80. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  81. tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
  82. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
  85. tests/_internal/server/routers/test_fleets.py +15 -2
  86. tests/_internal/server/routers/test_pools.py +6 -0
  87. tests/_internal/server/routers/test_runs.py +27 -0
  88. tests/_internal/server/services/jobs/__init__.py +0 -0
  89. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  90. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  91. tests/_internal/server/services/test_pools.py +4 -0
  92. tests/_internal/server/services/test_runs.py +5 -41
  93. tests/_internal/utils/test_common.py +21 -0
  94. tests/_internal/utils/test_env.py +38 -0
  95. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
  96. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
  97. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
  98. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,32 @@
+"""Add JobModel.inactivity_secs
+
+Revision ID: 1e76fb0dde87
+Revises: 63c3f19cb184
+Create Date: 2025-02-11 23:37:58.823710
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "1e76fb0dde87"
+down_revision = "63c3f19cb184"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("inactivity_secs", sa.Integer(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.drop_column("inactivity_secs")
+
+    # ### end Alembic commands ###
@@ -0,0 +1,43 @@
+"""Add InstanceModel blocks fields
+
+Revision ID: 51d45659d574
+Revises: da574e93fee0
+Create Date: 2025-02-04 11:10:41.626273
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "51d45659d574"
+down_revision = "da574e93fee0"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("total_blocks", sa.Integer(), nullable=True))
+        batch_op.add_column(sa.Column("busy_blocks", sa.Integer(), nullable=True))
+
+    op.execute("""
+        UPDATE instances
+        SET total_blocks = 1
+    """)
+    op.execute("""
+        UPDATE instances
+        SET busy_blocks = CASE
+            WHEN job_id IS NOT NULL THEN 1
+            ELSE 0
+        END
+    """)
+
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.alter_column("busy_blocks", existing_type=sa.INTEGER(), nullable=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.drop_column("busy_blocks")
+        batch_op.drop_column("total_blocks")
@@ -0,0 +1,83 @@
+"""Add JobTerminationReason.INACTIVITY_DURATION_EXCEEDED
+
+Revision ID: 63c3f19cb184
+Revises: 1338b788b612
+Create Date: 2025-02-11 22:30:47.289393
+
+"""
+
+from alembic import op
+from alembic_postgresql_enum import TableReference
+
+# revision identifiers, used by Alembic.
+revision = "63c3f19cb184"
+down_revision = "1338b788b612"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.sync_enum_values(
+        enum_schema="public",
+        enum_name="jobterminationreason",
+        new_values=[
+            "FAILED_TO_START_DUE_TO_NO_CAPACITY",
+            "INTERRUPTED_BY_NO_CAPACITY",
+            "WAITING_INSTANCE_LIMIT_EXCEEDED",
+            "WAITING_RUNNER_LIMIT_EXCEEDED",
+            "TERMINATED_BY_USER",
+            "VOLUME_ERROR",
+            "GATEWAY_ERROR",
+            "SCALED_DOWN",
+            "DONE_BY_RUNNER",
+            "ABORTED_BY_USER",
+            "TERMINATED_BY_SERVER",
+            "INACTIVITY_DURATION_EXCEEDED",
+            "CONTAINER_EXITED_WITH_ERROR",
+            "PORTS_BINDING_FAILED",
+            "CREATING_CONTAINER_ERROR",
+            "EXECUTOR_ERROR",
+            "MAX_DURATION_EXCEEDED",
+        ],
+        affected_columns=[
+            TableReference(
+                table_schema="public", table_name="jobs", column_name="termination_reason"
+            )
+        ],
+        enum_values_to_rename=[],
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.sync_enum_values(
+        enum_schema="public",
+        enum_name="jobterminationreason",
+        new_values=[
+            "FAILED_TO_START_DUE_TO_NO_CAPACITY",
+            "INTERRUPTED_BY_NO_CAPACITY",
+            "WAITING_INSTANCE_LIMIT_EXCEEDED",
+            "WAITING_RUNNER_LIMIT_EXCEEDED",
+            "TERMINATED_BY_USER",
+            "VOLUME_ERROR",
+            "GATEWAY_ERROR",
+            "SCALED_DOWN",
+            "DONE_BY_RUNNER",
+            "ABORTED_BY_USER",
+            "TERMINATED_BY_SERVER",
+            "CONTAINER_EXITED_WITH_ERROR",
+            "PORTS_BINDING_FAILED",
+            "CREATING_CONTAINER_ERROR",
+            "EXECUTOR_ERROR",
+            "MAX_DURATION_EXCEEDED",
+        ],
+        affected_columns=[
+            TableReference(
+                table_schema="public", table_name="jobs", column_name="termination_reason"
+            )
+        ],
+        enum_values_to_rename=[],
+    )
+    # ### end Alembic commands ###
@@ -351,13 +351,17 @@ class JobModel(BaseModel):
     job_spec_data: Mapped[str] = mapped_column(Text)
     job_provisioning_data: Mapped[Optional[str]] = mapped_column(Text)
     runner_timestamp: Mapped[Optional[int]] = mapped_column(BigInteger)
+    inactivity_secs: Mapped[Optional[int]] = mapped_column(Integer)  # 0 - active, None - N/A
     # `removed` is used to ensure that the instance is killed after the job is finished
     remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
     volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
     # `instance_assigned` means instance assignment was done.
     # if `instance_assigned` is True and `instance` is None, no instance was assiged.
     instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False)
-    instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="job")
+    instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(
+        ForeignKey("instances.id", ondelete="CASCADE")
+    )
+    instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs")
     used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False))
     replica_num: Mapped[int] = mapped_column(Integer)
     job_runtime_data: Mapped[Optional[str]] = mapped_column(Text)
@@ -543,9 +547,11 @@ class InstanceModel(BaseModel):
 
     remote_connection_info: Mapped[Optional[str]] = mapped_column(Text)
 
-    # current job
-    job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id"))
-    job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="joined")
+    # NULL means `auto` (only during provisioning, when ready it's not NULL)
+    total_blocks: Mapped[Optional[int]] = mapped_column(Integer)
+    busy_blocks: Mapped[int] = mapped_column(Integer, default=0)
+
+    jobs: Mapped[list["JobModel"]] = relationship(back_populates="instance", lazy="joined")
     last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
 
     # volumes attached to the instance
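The JobModel and InstanceModel hunks above reverse the job-instance relationship (a job now carries an instance_id foreign key and an instance exposes a jobs list) and add total_blocks/busy_blocks so several jobs can share one instance. A minimal, self-contained sketch of the accounting these columns enable; the dataclasses and helper below are illustrative stand-ins, not dstack code:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Job:
    name: str
    blocks: int = 1  # how many blocks of an instance this job occupies


@dataclass
class Instance:
    name: str
    total_blocks: Optional[int] = None  # None while provisioning ("auto")
    busy_blocks: int = 0
    jobs: List[Job] = field(default_factory=list)

    def can_fit(self, job: Job) -> bool:
        # An instance whose capacity is still unknown cannot accept jobs yet.
        if self.total_blocks is None:
            return False
        return self.busy_blocks + job.blocks <= self.total_blocks

    def assign(self, job: Job) -> None:
        assert self.can_fit(job)
        self.jobs.append(job)
        self.busy_blocks += job.blocks


instance = Instance(name="gpu-1", total_blocks=4)
instance.assign(Job(name="job-a", blocks=1))
instance.assign(Job(name="job-b", blocks=2))
print(instance.busy_blocks)  # 3, so one block is still free

An instance is fully busy when busy_blocks equals total_blocks and idle when busy_blocks is 0, which matches the migration above that backfills busy_blocks from the old one-job-per-instance job_id column.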
@@ -47,6 +47,7 @@ async def list_runs(
     """
     Returns all runs visible to user sorted by descending `submitted_at`.
     `project_name`, `repo_id`, `username`, and `only_active` can be specified as filters.
+    Setting `only_active` to `true` excludes finished runs and deleted runs.
     Specifying `repo_id` without `project_name` returns no runs.
 
     The results are paginated. To get the next page, pass `submitted_at` and `id` of
@@ -34,6 +34,7 @@ class PullResponse(CoreModel):
     job_logs: List[LogEvent]
     runner_logs: List[LogEvent]
     last_updated: int
+    no_connections_secs: Optional[int] = None  # Optional for compatibility with old runners
 
 
 class SubmitBody(CoreModel):
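The PullResponse hunk above adds no_connections_secs with a None default because older runners do not report it. A small sketch of consuming the field in a backward-compatible way; the trimmed model and the threshold check are illustrative, not the server's actual inactivity handling:

from typing import Optional

from pydantic import BaseModel


class PullResponse(BaseModel):
    # Trimmed stand-in for the schema above; only the fields used here.
    last_updated: int
    no_connections_secs: Optional[int] = None  # absent in payloads from old runners


def inactivity_exceeded(resp: PullResponse, limit_secs: int) -> bool:
    # Old runners never report the counter, so inactivity tracking is skipped for them.
    if resp.no_connections_secs is None:
        return False
    return resp.no_connections_secs >= limit_secs


old_runner = PullResponse(last_updated=1)
new_runner = PullResponse(last_updated=1, no_connections_secs=700)
print(inactivity_exceeded(old_runner, limit_secs=600))  # False
print(inactivity_exceeded(new_runner, limit_secs=600))  # True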
@@ -2,6 +2,7 @@ import json
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import List, Optional, Tuple
 
+import azure.core.exceptions
 from azure.core.credentials import TokenCredential
 from azure.mgmt import network as network_mgmt
 from azure.mgmt import resource as resource_mgmt
@@ -154,16 +155,17 @@ class AzureConfigurator(Configurator):
         if is_core_model_instance(config.creds, AzureClientCreds):
             self._set_client_creds_tenant_id(config.creds, config.tenant_id)
         credential, _ = auth.authenticate(config.creds)
-        resource_group = self._create_resource_group(
-            credential=credential,
-            subscription_id=config.subscription_id,
-            location=MAIN_LOCATION,
-            project_name=project.name,
-        )
+        if config.resource_group is None:
+            config.resource_group = self._create_resource_group(
+                credential=credential,
+                subscription_id=config.subscription_id,
+                location=MAIN_LOCATION,
+                project_name=project.name,
+            )
         self._create_network_resources(
             credential=credential,
             subscription_id=config.subscription_id,
-            resource_group=resource_group,
+            resource_group=config.resource_group,
             locations=config.locations,
             create_default_network=config.vpc_ids is None,
         )
@@ -172,7 +174,6 @@ class AzureConfigurator(Configurator):
             type=self.TYPE.value,
             config=AzureStoredConfig(
                 **AzureConfigInfo.__response__.parse_obj(config).dict(),
-                resource_group=resource_group,
             ).json(),
             auth=DecryptedString(plaintext=AzureCreds.parse_obj(config.creds).__root__.json()),
         )
@@ -322,6 +323,7 @@ class AzureConfigurator(Configurator):
         self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
     ):
         self._check_tags_config(config)
+        self._check_resource_group(config=config, credential=credential)
         self._check_vpc_config(config=config, credential=credential)
 
     def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial):
@@ -336,6 +338,18 @@ class AzureConfigurator(Configurator):
         except BackendError as e:
             raise ServerClientError(e.args[0])
 
+    def _check_resource_group(
+        self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
+    ):
+        if config.resource_group is None:
+            return
+        resource_manager = ResourceManager(
+            credential=credential,
+            subscription_id=config.subscription_id,
+        )
+        if not resource_manager.resource_group_exists(config.resource_group):
+            raise ServerClientError(f"Resource group {config.resource_group} not found")
+
     def _check_vpc_config(
         self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential
     ):
@@ -406,6 +420,18 @@ class ResourceManager:
         )
         return resource_group.name
 
+    def resource_group_exists(
+        self,
+        name: str,
+    ) -> bool:
+        try:
+            self.resource_client.resource_groups.get(
+                resource_group_name=name,
+            )
+        except azure.core.exceptions.ResourceNotFoundError:
+            return False
+        return True
+
 
 class NetworkManager:
     def __init__(self, credential: TokenCredential, subscription_id: str):
@@ -124,6 +124,15 @@ class AzureConfig(CoreModel):
     type: Annotated[Literal["azure"], Field(description="The type of the backend")] = "azure"
     tenant_id: Annotated[str, Field(description="The tenant ID")]
     subscription_id: Annotated[str, Field(description="The subscription ID")]
+    resource_group: Annotated[
+        Optional[str],
+        Field(
+            description=(
+                "The resource group for resources created by `dstack`."
+                " If not specified, `dstack` will create a new resource group"
+            )
+        ),
+    ] = None
     regions: Annotated[
         Optional[List[str]],
         Field(description="The list of Azure regions (locations). Omit to use all regions"),
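The resource_group field added above is optional: when it is omitted, the configurator hunks earlier in this diff keep the old behavior and create a resource group for the project; when it is set, the backend validates that the group exists. A minimal sketch of the two configurations using a trimmed stand-in for the model; the values are placeholders:

from typing import Optional

from pydantic import BaseModel


class AzureConfig(BaseModel):
    # Trimmed stand-in for the model above; only the fields used here.
    type: str = "azure"
    tenant_id: str
    subscription_id: str
    resource_group: Optional[str] = None  # None -> dstack creates a resource group


# Old behavior: no resource group given, one is created per project.
auto_rg = AzureConfig(tenant_id="00000000-placeholder", subscription_id="00000000-placeholder")

# New behavior: reuse an existing resource group; the backend checks that it exists.
own_rg = AzureConfig(
    tenant_id="00000000-placeholder",
    subscription_id="00000000-placeholder",
    resource_group="my-existing-rg",
)
print(auto_rg.resource_group, own_rg.resource_group)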
@@ -2,7 +2,7 @@ import random
 import string
 import uuid
 from datetime import datetime, timezone
-from typing import List, Optional, Tuple, Union, cast
+from typing import List, Literal, Optional, Tuple, Union, cast
 
 from sqlalchemy import and_, func, or_, select
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -31,6 +31,7 @@ from dstack._internal.core.models.instances import (
     InstanceOfferWithAvailability,
     InstanceStatus,
     RemoteConnectionInfo,
+    SSHConnectionParams,
     SSHKey,
 )
 from dstack._internal.core.models.pools import Instance
@@ -256,6 +257,7 @@ async def get_plan(
         project=project,
         profile=spec.merged_profile,
         requirements=_get_fleet_requirements(spec),
+        blocks=spec.configuration.blocks,
     )
     offers = [offer for _, offer in offers_with_backends]
     _remove_fleet_spec_sensitive_info(spec)
@@ -277,6 +279,7 @@ async def get_create_instance_offers(
     requirements: Requirements,
     exclude_not_available=False,
     fleet_model: Optional[FleetModel] = None,
+    blocks: Union[int, Literal["auto"]] = 1,
 ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
     multinode = False
     master_job_provisioning_data = None
@@ -296,6 +299,7 @@ async def get_create_instance_offers(
         exclude_not_available=exclude_not_available,
         multinode=multinode,
         master_job_provisioning_data=master_job_provisioning_data,
+        blocks=blocks,
     )
     offers = [
         (backend, offer)
@@ -406,6 +410,7 @@ async def create_fleet_instance_model(
         instance_num=instance_num,
         placement_group_name=placement_group_name,
         reservation=reservation,
+        blocks=spec.configuration.blocks,
     )
     return instance_model
 
@@ -424,18 +429,33 @@ async def create_fleet_ssh_instance_model(
         ssh_user = ssh_params.user
         ssh_key = ssh_params.ssh_key
         port = ssh_params.port
+        proxy_jump = ssh_params.proxy_jump
         internal_ip = None
+        blocks = 1
     else:
         hostname = host.hostname
         ssh_user = host.user or ssh_params.user
         ssh_key = host.ssh_key or ssh_params.ssh_key
         port = host.port or ssh_params.port
+        proxy_jump = host.proxy_jump or ssh_params.proxy_jump
         internal_ip = host.internal_ip
+        blocks = host.blocks
 
     if ssh_user is None or ssh_key is None:
         # This should not be reachable but checked by fleet spec validation
         raise ServerClientError("ssh key or user not specified")
 
+    if proxy_jump is not None:
+        ssh_proxy = SSHConnectionParams(
+            hostname=proxy_jump.hostname,
+            port=proxy_jump.port or 22,
+            username=proxy_jump.user,
+        )
+        ssh_proxy_keys = [proxy_jump.ssh_key]
+    else:
+        ssh_proxy = None
+        ssh_proxy_keys = None
+
     instance_model = await pools_services.create_ssh_instance_model(
         project=project,
         pool=pool,
@@ -445,10 +465,13 @@
         host=hostname,
         ssh_user=ssh_user,
         ssh_keys=[ssh_key],
+        ssh_proxy=ssh_proxy,
+        ssh_proxy_keys=ssh_proxy_keys,
         env=env,
         internal_ip=internal_ip,
         instance_network=ssh_params.network,
         port=port or 22,
+        blocks=blocks,
     )
     return instance_model
 
@@ -544,7 +567,7 @@ async def generate_fleet_name(session: AsyncSession, project: ProjectModel) -> s
 
 
 def is_fleet_in_use(fleet_model: FleetModel, instance_nums: Optional[List[int]] = None) -> bool:
-    instances_in_use = [i for i in fleet_model.instances if i.job_id is not None and not i.deleted]
+    instances_in_use = [i for i in fleet_model.instances if i.jobs and not i.deleted]
     selected_instance_in_use = instances_in_use
     if instance_nums is not None:
         selected_instance_in_use = [i for i in instances_in_use if i.instance_num in instance_nums]
@@ -606,6 +629,8 @@ async def create_instance(
         instance_configuration=None,
         termination_policy=termination_policy,
         termination_idle_time=termination_idle_time,
+        total_blocks=1,
+        busy_blocks=0,
     )
     logger.info(
         "Added a new instance %s",
@@ -74,10 +74,18 @@ class GatewayClient:
         resp.raise_for_status()
         self.is_server_ready = True
 
-    async def register_replica(self, run: Run, job_submission: JobSubmission):
+    async def register_replica(
+        self,
+        run: Run,
+        job_submission: JobSubmission,
+        ssh_head_proxy: Optional[SSHConnectionParams],
+        ssh_head_proxy_private_key: Optional[str],
+    ):
         payload = {
             "job_id": job_submission.id.hex,
             "app_port": run.run_spec.configuration.port.container_port,
+            "ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None,
+            "ssh_head_proxy_private_key": ssh_head_proxy_private_key,
         }
         jpd = job_submission.job_provisioning_data
         if not jpd.dockerized: