skypilot-nightly 1.0.0.dev20251014__py3-none-any.whl → 1.0.0.dev20251016__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of skypilot-nightly has been flagged as potentially problematic.

Files changed (51)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +29 -15
  3. sky/backends/cloud_vm_ray_backend.py +30 -13
  4. sky/dashboard/out/404.html +1 -1
  5. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  6. sky/dashboard/out/clusters/[cluster].html +1 -1
  7. sky/dashboard/out/clusters.html +1 -1
  8. sky/dashboard/out/config.html +1 -1
  9. sky/dashboard/out/index.html +1 -1
  10. sky/dashboard/out/infra/[context].html +1 -1
  11. sky/dashboard/out/infra.html +1 -1
  12. sky/dashboard/out/jobs/[job].html +1 -1
  13. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  14. sky/dashboard/out/jobs.html +1 -1
  15. sky/dashboard/out/users.html +1 -1
  16. sky/dashboard/out/volumes.html +1 -1
  17. sky/dashboard/out/workspace/new.html +1 -1
  18. sky/dashboard/out/workspaces/[name].html +1 -1
  19. sky/dashboard/out/workspaces.html +1 -1
  20. sky/exceptions.py +13 -1
  21. sky/jobs/constants.py +1 -1
  22. sky/jobs/scheduler.py +2 -4
  23. sky/jobs/server/core.py +2 -1
  24. sky/jobs/server/server.py +5 -3
  25. sky/jobs/state.py +12 -6
  26. sky/jobs/utils.py +8 -2
  27. sky/provision/common.py +2 -0
  28. sky/provision/instance_setup.py +10 -2
  29. sky/provision/kubernetes/instance.py +34 -10
  30. sky/provision/kubernetes/utils.py +9 -0
  31. sky/schemas/generated/jobsv1_pb2.py +52 -52
  32. sky/schemas/generated/jobsv1_pb2.pyi +4 -2
  33. sky/serve/server/server.py +1 -0
  34. sky/server/requests/executor.py +51 -15
  35. sky/server/requests/preconditions.py +2 -4
  36. sky/server/requests/requests.py +14 -23
  37. sky/server/requests/threads.py +106 -0
  38. sky/server/rest.py +36 -18
  39. sky/server/server.py +24 -0
  40. sky/skylet/constants.py +1 -1
  41. sky/skylet/services.py +3 -1
  42. sky/utils/asyncio_utils.py +18 -0
  43. sky/utils/context_utils.py +2 -0
  44. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/METADATA +37 -36
  45. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/RECORD +51 -49
  46. /sky/dashboard/out/_next/static/{9Fek73R28lDp1A5J4N7g7 → pbgtEUoCUdmJyLHjgln5A}/_buildManifest.js +0 -0
  47. /sky/dashboard/out/_next/static/{9Fek73R28lDp1A5J4N7g7 → pbgtEUoCUdmJyLHjgln5A}/_ssgManifest.js +0 -0
  48. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/WHEEL +0 -0
  49. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/entry_points.txt +0 -0
  50. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/licenses/LICENSE +0 -0
  51. {skypilot_nightly-1.0.0.dev20251014.dist-info → skypilot_nightly-1.0.0.dev20251016.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -2148,8 +2148,12 @@ class ManagedJobCodeGen:
         return cls._build(code)

     @classmethod
-    def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag',
-                    workspace: str, entrypoint: str) -> str:
+    def set_pending(cls,
+                    job_id: int,
+                    managed_job_dag: 'dag_lib.Dag',
+                    workspace: str,
+                    entrypoint: str,
+                    user_hash: Optional[str] = None) -> str:
         dag_name = managed_job_dag.name
         pool = managed_job_dag.pool
         # Add the managed job to queue table.
@@ -2166,6 +2170,8 @@ class ManagedJobCodeGen:
                 pool_hash = serve_state.get_service_hash({pool!r})
                 set_job_info_kwargs['pool'] = {pool!r}
                 set_job_info_kwargs['pool_hash'] = pool_hash
+            if managed_job_version >= 11:
+                set_job_info_kwargs['user_hash'] = {user_hash!r}
             managed_job_state.set_job_info(
                 {job_id}, {dag_name!r}, **set_job_info_kwargs)
             """)
sky/provision/common.py CHANGED
@@ -97,6 +97,8 @@ class InstanceInfo:
     external_ip: Optional[str]
     tags: Dict[str, str]
     ssh_port: int = 22
+    # The internal service address of the instance on Kubernetes.
+    internal_svc: Optional[str] = None

     def get_feasible_ip(self) -> str:
         """Get the most feasible IPs of the instance. This function returns
sky/provision/instance_setup.py CHANGED
@@ -434,8 +434,16 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
     # use the external IP of the head node.
     use_external_ip = cluster_info.custom_ray_options.pop(
         'use_external_ip', False)
-    head_ip = (head_instance.internal_ip
-               if not use_external_ip else head_instance.external_ip)
+
+    if use_external_ip:
+        head_ip = head_instance.external_ip
+    else:
+        # For Kubernetes, use the internal service address of the head node.
+        # Keep this consistent with the logic in kubernetes-ray.yml.j2
+        if head_instance.internal_svc:
+            head_ip = head_instance.internal_svc
+        else:
+            head_ip = head_instance.internal_ip

     ray_cmd = ray_worker_start_command(custom_resource,
                                        cluster_info.custom_ray_options,
sky/provision/kubernetes/instance.py CHANGED
@@ -959,12 +959,19 @@ def _create_pods(region: str, cluster_name: str, cluster_name_on_cloud: str,

     def _create_resource_thread(i: int):
         pod_spec_copy = copy.deepcopy(pod_spec)
-        if head_pod_name is None and i == 0:
-            # First pod should be head if no head exists
-            pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS)
-            head_selector = _head_service_selector(cluster_name_on_cloud)
-            pod_spec_copy['metadata']['labels'].update(head_selector)
-            pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head'
+        # 0 is for head pod, while 1+ is for worker pods.
+        if i == 0:
+            if head_pod_name is None:
+                # First pod should be head if no head exists
+                pod_spec_copy['metadata']['labels'].update(
+                    constants.HEAD_NODE_TAGS)
+                head_selector = _head_service_selector(cluster_name_on_cloud)
+                pod_spec_copy['metadata']['labels'].update(head_selector)
+                pod_spec_copy['metadata'][
+                    'name'] = f'{cluster_name_on_cloud}-head'
+            else:
+                # If head pod already exists, we skip creating it.
+                return
         else:
             # Worker pods
             pod_spec_copy['metadata']['labels'].update(
@@ -1105,9 +1112,16 @@ def _create_pods(region: str, cluster_name: str, cluster_name_on_cloud: str,
                    'and then up the cluster again.')
         raise exceptions.InconsistentHighAvailabilityError(message)

-    # Create pods in parallel
-    created_resources = subprocess_utils.run_in_parallel(
-        _create_resource_thread, list(range(to_start_count)), _NUM_THREADS)
+    created_resources = []
+    if to_start_count > 0:
+        # Create pods in parallel.
+        # Use `config.count` instead of `to_start_count` to keep the index of
+        # the Pods consistent especially for the case where some Pods are down
+        # due to node failure or manual termination, etc. and then launch
+        # again to create the Pods back.
+        # The existing Pods will be skipped in _create_resource_thread.
+        created_resources = subprocess_utils.run_in_parallel(
+            _create_resource_thread, list(range(config.count)), _NUM_THREADS)

     if to_create_deployment:
         deployments = copy.deepcopy(created_resources)
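Iterating over range(config.count) rather than range(to_start_count) keeps each pod's index, and therefore its identity, stable across recoveries: if a middle pod was lost, relaunching must recreate exactly that slot instead of shifting the surviving pods. A toy illustration of the skip-existing pattern (the pod naming scheme here is hypothetical, used only to make the idea concrete):

from typing import Optional

existing_pods = {'mycluster-head', 'mycluster-worker2'}  # worker1 was lost

def pod_name(cluster: str, i: int) -> str:
    # Hypothetical naming: index 0 is the head pod, 1+ are worker pods.
    return f'{cluster}-head' if i == 0 else f'{cluster}-worker{i}'

def create_pod(i: int) -> Optional[str]:
    name = pod_name('mycluster', i)
    if name in existing_pods:
        return None  # skip pods that still exist, as _create_resource_thread does
    return name

desired_count = 3  # config.count: the full cluster size, not just the delta
created = [name for i in range(desired_count)
           if (name := create_pod(i)) is not None]
print(created)  # ['mycluster-worker1'] -- only the lost slot is recreated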
@@ -1350,6 +1364,9 @@ def get_cluster_info(
             external_ip=None,
             ssh_port=port,
             tags=pod.metadata.labels,
+            # TODO(hailong): `cluster.local` may need to be configurable
+            # Service name is same as the pod name for now.
+            internal_svc=f'{pod_name}.{namespace}.svc.cluster.local',
         )
     ]
     if _is_head(pod):
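Since each pod is fronted by a Service of the same name, the head node becomes reachable at the standard in-cluster DNS name of the form <service>.<namespace>.svc.<cluster-domain>, which stays stable across pod restarts, unlike the pod IP. A small sketch of the address construction (mirroring the line above; the configurable domain is what the TODO suggests, not current behavior):

def internal_service_fqdn(pod_name: str, namespace: str,
                          cluster_domain: str = 'cluster.local') -> str:
    # The diff hardcodes 'cluster.local'; the TODO notes it may need to
    # become configurable for clusters with a non-default DNS domain.
    return f'{pod_name}.{namespace}.svc.{cluster_domain}'

print(internal_service_fqdn('mycluster-head', 'default'))
# mycluster-head.default.svc.cluster.local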
@@ -1388,6 +1405,13 @@ def get_cluster_info(
     logger.debug(
         f'Using ssh user {ssh_user} for cluster {cluster_name_on_cloud}')

+    # cpu_request may be a string like `100m`, need to parse and convert
+    num_cpus = kubernetes_utils.parse_cpu_or_gpu_resource_to_float(cpu_request)
+    # 'num-cpus' for ray must be an integer, but we should not set it to 0 if
+    # cpus is <1.
+    # Keep consistent with the logic in clouds/kubernetes.py
+    str_cpus = str(max(int(num_cpus), 1))
+
     return common.ClusterInfo(
         instances=pods,
         head_instance_id=head_pod_name,
@@ -1397,7 +1421,7 @@ def get_cluster_info(
         # problems for other pods.
         custom_ray_options={
             'object-store-memory': 500000000,
-            'num-cpus': cpu_request,
+            'num-cpus': str_cpus,
         },
         provider_name='kubernetes',
         provider_config=provider_config)
sky/provision/kubernetes/utils.py CHANGED
@@ -2241,6 +2241,15 @@ def get_kube_config_context_namespace(
     return DEFAULT_NAMESPACE


+def parse_cpu_or_gpu_resource_to_float(resource_str: str) -> float:
+    if not resource_str:
+        return 0.0
+    if resource_str[-1] == 'm':
+        return float(resource_str[:-1]) / 1000
+    else:
+        return float(resource_str)
+
+
 def parse_cpu_or_gpu_resource(resource_qty_str: str) -> Union[int, float]:
     resource_str = str(resource_qty_str)
     if resource_str[-1] == 'm':
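Kubernetes expresses fractional CPU requests in millicores, so '100m' means 0.1 CPU. Ray's num-cpus option must be an integer, hence the max(int(num_cpus), 1) clamp in get_cluster_info above: a sub-core request still registers one schedulable CPU instead of zero. A quick standalone illustration of both steps:

def parse_cpu_to_float(resource_str: str) -> float:
    # Mirrors parse_cpu_or_gpu_resource_to_float: '100m' -> 0.1, '2' -> 2.0
    if not resource_str:
        return 0.0
    if resource_str[-1] == 'm':
        return float(resource_str[:-1]) / 1000
    return float(resource_str)

for request in ('100m', '500m', '2', '3.5'):
    num_cpus = parse_cpu_to_float(request)
    ray_num_cpus = max(int(num_cpus), 1)  # never hand 0 CPUs to ray
    print(f'{request!r}: parsed={num_cpus}, ray num-cpus={ray_num_cpus}')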
sky/schemas/generated/jobsv1_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"sky/schemas/generated/jobsv1.proto\x12\x07jobs.v1\"\x85\x01\n\rAddJobRequest\x12\x15\n\x08job_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08username\x18\x02 \x01(\t\x12\x15\n\rrun_timestamp\x18\x03 \x01(\t\x12\x15\n\rresources_str\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x0b\n\t_job_name\"1\n\x0e\x41\x64\x64JobResponse\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x0f\n\x07log_dir\x18\x02 \x01(\t\"\xb3\x01\n\x0fQueueJobRequest\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x14\n\x07\x63odegen\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0bscript_path\x18\x03 \x01(\t\x12\x16\n\x0eremote_log_dir\x18\x04 \x01(\t\x12\x31\n\x0bmanaged_job\x18\x05 \x01(\x0b\x32\x17.jobs.v1.ManagedJobInfoH\x01\x88\x01\x01\x42\n\n\x08_codegenB\x0e\n\x0c_managed_job\"\x89\x01\n\x0eManagedJobInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\x04pool\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x11\n\tworkspace\x18\x03 \x01(\t\x12\x12\n\nentrypoint\x18\x04 \x01(\t\x12&\n\x05tasks\x18\x05 \x03(\x0b\x32\x17.jobs.v1.ManagedJobTaskB\x07\n\x05_pool\"]\n\x0eManagedJobTask\x12\x0f\n\x07task_id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rresources_str\x18\x03 \x01(\t\x12\x15\n\rmetadata_json\x18\x04 \x01(\t\"\x12\n\x10QueueJobResponse\"\x15\n\x13UpdateStatusRequest\"\x16\n\x14UpdateStatusResponse\"L\n\x12GetJobQueueRequest\x12\x16\n\tuser_hash\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08\x61ll_jobs\x18\x02 \x01(\x08\x42\x0c\n\n_user_hash\"\xa3\x02\n\x07JobInfo\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x10\n\x08job_name\x18\x02 \x01(\t\x12\x10\n\x08username\x18\x03 \x01(\t\x12\x14\n\x0csubmitted_at\x18\x04 \x01(\x01\x12\"\n\x06status\x18\x05 \x01(\x0e\x32\x12.jobs.v1.JobStatus\x12\x15\n\rrun_timestamp\x18\x06 \x01(\t\x12\x15\n\x08start_at\x18\x07 \x01(\x01H\x00\x88\x01\x01\x12\x13\n\x06\x65nd_at\x18\x08 \x01(\x01H\x01\x88\x01\x01\x12\x11\n\tresources\x18\t \x01(\t\x12\x10\n\x03pid\x18\n \x01(\x03H\x02\x88\x01\x01\x12\x10\n\x08log_path\x18\x0b \x01(\t\x12\x10\n\x08metadata\x18\x0c \x01(\tB\x0b\n\t_start_atB\t\n\x07_end_atB\x06\n\x04_pid\"5\n\x13GetJobQueueResponse\x12\x1e\n\x04jobs\x18\x01 \x03(\x0b\x32\x10.jobs.v1.JobInfo\"^\n\x11\x43\x61ncelJobsRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\x12\x12\n\ncancel_all\x18\x02 \x01(\x08\x12\x16\n\tuser_hash\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x0c\n\n_user_hash\"/\n\x12\x43\x61ncelJobsResponse\x12\x19\n\x11\x63\x61ncelled_job_ids\x18\x01 \x03(\x03\"\x1e\n\x1c\x46\x61ilAllInProgressJobsRequest\"\x1f\n\x1d\x46\x61ilAllInProgressJobsResponse\"\x7f\n\x0fTailLogsRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x12\x1b\n\x0emanaged_job_id\x18\x02 \x01(\x03H\x01\x88\x01\x01\x12\x0e\n\x06\x66ollow\x18\x03 \x01(\x08\x12\x0c\n\x04tail\x18\x04 \x01(\x05\x42\t\n\x07_job_idB\x11\n\x0f_managed_job_id\"7\n\x10TailLogsResponse\x12\x10\n\x08log_line\x18\x01 \x01(\t\x12\x11\n\texit_code\x18\x02 \x01(\x05\"&\n\x13GetJobStatusRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\"\xa4\x01\n\x14GetJobStatusResponse\x12\x44\n\x0cjob_statuses\x18\x01 \x03(\x0b\x32..jobs.v1.GetJobStatusResponse.JobStatusesEntry\x1a\x46\n\x10JobStatusesEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12!\n\x05value\x18\x02 \x01(\x0e\x32\x12.jobs.v1.JobStatus:\x02\x38\x01\"A\n\x1fGetJobSubmittedTimestampRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x42\t\n\x07_job_id\"5\n GetJobSubmittedTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x02\"=\n\x1bGetJobEndedTimestampRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x42\t\n\x07_job_id\"1\n\x1cGetJobEndedTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x02\"+\n\x18GetLogDirsForJobsRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\"\x98\x01\n\x19GetLogDirsForJobsResponse\x12H\n\x0cjob_log_dirs\x18\x01 \x03(\x0b\x32\x32.jobs.v1.GetLogDirsForJobsResponse.JobLogDirsEntry\x1a\x31\n\x0fJobLogDirsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01*\x8d\x02\n\tJobStatus\x12\x1a\n\x16JOB_STATUS_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOB_STATUS_INIT\x10\x01\x12\x16\n\x12JOB_STATUS_PENDING\x10\x02\x12\x19\n\x15JOB_STATUS_SETTING_UP\x10\x03\x12\x16\n\x12JOB_STATUS_RUNNING\x10\x04\x12\x1c\n\x18JOB_STATUS_FAILED_DRIVER\x10\x05\x12\x18\n\x14JOB_STATUS_SUCCEEDED\x10\x06\x12\x15\n\x11JOB_STATUS_FAILED\x10\x07\x12\x1b\n\x17JOB_STATUS_FAILED_SETUP\x10\x08\x12\x18\n\x14JOB_STATUS_CANCELLED\x10\t2\x91\x07\n\x0bJobsService\x12\x39\n\x06\x41\x64\x64Job\x12\x16.jobs.v1.AddJobRequest\x1a\x17.jobs.v1.AddJobResponse\x12?\n\x08QueueJob\x12\x18.jobs.v1.QueueJobRequest\x1a\x19.jobs.v1.QueueJobResponse\x12K\n\x0cUpdateStatus\x12\x1c.jobs.v1.UpdateStatusRequest\x1a\x1d.jobs.v1.UpdateStatusResponse\x12H\n\x0bGetJobQueue\x12\x1b.jobs.v1.GetJobQueueRequest\x1a\x1c.jobs.v1.GetJobQueueResponse\x12\x45\n\nCancelJobs\x12\x1a.jobs.v1.CancelJobsRequest\x1a\x1b.jobs.v1.CancelJobsResponse\x12\x66\n\x15\x46\x61ilAllInProgressJobs\x12%.jobs.v1.FailAllInProgressJobsRequest\x1a&.jobs.v1.FailAllInProgressJobsResponse\x12\x41\n\x08TailLogs\x12\x18.jobs.v1.TailLogsRequest\x1a\x19.jobs.v1.TailLogsResponse0\x01\x12K\n\x0cGetJobStatus\x12\x1c.jobs.v1.GetJobStatusRequest\x1a\x1d.jobs.v1.GetJobStatusResponse\x12o\n\x18GetJobSubmittedTimestamp\x12(.jobs.v1.GetJobSubmittedTimestampRequest\x1a).jobs.v1.GetJobSubmittedTimestampResponse\x12\x63\n\x14GetJobEndedTimestamp\x12$.jobs.v1.GetJobEndedTimestampRequest\x1a%.jobs.v1.GetJobEndedTimestampResponse\x12Z\n\x11GetLogDirsForJobs\x12!.jobs.v1.GetLogDirsForJobsRequest\x1a\".jobs.v1.GetLogDirsForJobsResponseb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"sky/schemas/generated/jobsv1.proto\x12\x07jobs.v1\"\x85\x01\n\rAddJobRequest\x12\x15\n\x08job_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08username\x18\x02 \x01(\t\x12\x15\n\rrun_timestamp\x18\x03 \x01(\t\x12\x15\n\rresources_str\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x0b\n\t_job_name\"1\n\x0e\x41\x64\x64JobResponse\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x0f\n\x07log_dir\x18\x02 \x01(\t\"\xb3\x01\n\x0fQueueJobRequest\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x14\n\x07\x63odegen\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0bscript_path\x18\x03 \x01(\t\x12\x16\n\x0eremote_log_dir\x18\x04 \x01(\t\x12\x31\n\x0bmanaged_job\x18\x05 \x01(\x0b\x32\x17.jobs.v1.ManagedJobInfoH\x01\x88\x01\x01\x42\n\n\x08_codegenB\x0e\n\x0c_managed_job\"\xab\x01\n\x0eManagedJobInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\x04pool\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x11\n\tworkspace\x18\x03 \x01(\t\x12\x12\n\nentrypoint\x18\x04 \x01(\t\x12&\n\x05tasks\x18\x05 \x03(\x0b\x32\x17.jobs.v1.ManagedJobTask\x12\x14\n\x07user_id\x18\x06 \x01(\tH\x01\x88\x01\x01\x42\x07\n\x05_poolB\n\n\x08_user_id\"]\n\x0eManagedJobTask\x12\x0f\n\x07task_id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rresources_str\x18\x03 \x01(\t\x12\x15\n\rmetadata_json\x18\x04 \x01(\t\"\x12\n\x10QueueJobResponse\"\x15\n\x13UpdateStatusRequest\"\x16\n\x14UpdateStatusResponse\"L\n\x12GetJobQueueRequest\x12\x16\n\tuser_hash\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08\x61ll_jobs\x18\x02 \x01(\x08\x42\x0c\n\n_user_hash\"\xa3\x02\n\x07JobInfo\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x10\n\x08job_name\x18\x02 \x01(\t\x12\x10\n\x08username\x18\x03 \x01(\t\x12\x14\n\x0csubmitted_at\x18\x04 \x01(\x01\x12\"\n\x06status\x18\x05 \x01(\x0e\x32\x12.jobs.v1.JobStatus\x12\x15\n\rrun_timestamp\x18\x06 \x01(\t\x12\x15\n\x08start_at\x18\x07 \x01(\x01H\x00\x88\x01\x01\x12\x13\n\x06\x65nd_at\x18\x08 \x01(\x01H\x01\x88\x01\x01\x12\x11\n\tresources\x18\t \x01(\t\x12\x10\n\x03pid\x18\n \x01(\x03H\x02\x88\x01\x01\x12\x10\n\x08log_path\x18\x0b \x01(\t\x12\x10\n\x08metadata\x18\x0c \x01(\tB\x0b\n\t_start_atB\t\n\x07_end_atB\x06\n\x04_pid\"5\n\x13GetJobQueueResponse\x12\x1e\n\x04jobs\x18\x01 \x03(\x0b\x32\x10.jobs.v1.JobInfo\"^\n\x11\x43\x61ncelJobsRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\x12\x12\n\ncancel_all\x18\x02 \x01(\x08\x12\x16\n\tuser_hash\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x0c\n\n_user_hash\"/\n\x12\x43\x61ncelJobsResponse\x12\x19\n\x11\x63\x61ncelled_job_ids\x18\x01 \x03(\x03\"\x1e\n\x1c\x46\x61ilAllInProgressJobsRequest\"\x1f\n\x1d\x46\x61ilAllInProgressJobsResponse\"\x7f\n\x0fTailLogsRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x12\x1b\n\x0emanaged_job_id\x18\x02 \x01(\x03H\x01\x88\x01\x01\x12\x0e\n\x06\x66ollow\x18\x03 \x01(\x08\x12\x0c\n\x04tail\x18\x04 \x01(\x05\x42\t\n\x07_job_idB\x11\n\x0f_managed_job_id\"7\n\x10TailLogsResponse\x12\x10\n\x08log_line\x18\x01 \x01(\t\x12\x11\n\texit_code\x18\x02 \x01(\x05\"&\n\x13GetJobStatusRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\"\xa4\x01\n\x14GetJobStatusResponse\x12\x44\n\x0cjob_statuses\x18\x01 \x03(\x0b\x32..jobs.v1.GetJobStatusResponse.JobStatusesEntry\x1a\x46\n\x10JobStatusesEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12!\n\x05value\x18\x02 \x01(\x0e\x32\x12.jobs.v1.JobStatus:\x02\x38\x01\"A\n\x1fGetJobSubmittedTimestampRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x42\t\n\x07_job_id\"5\n GetJobSubmittedTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x02\"=\n\x1bGetJobEndedTimestampRequest\x12\x13\n\x06job_id\x18\x01 \x01(\x03H\x00\x88\x01\x01\x42\t\n\x07_job_id\"1\n\x1cGetJobEndedTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x02\"+\n\x18GetLogDirsForJobsRequest\x12\x0f\n\x07job_ids\x18\x01 \x03(\x03\"\x98\x01\n\x19GetLogDirsForJobsResponse\x12H\n\x0cjob_log_dirs\x18\x01 \x03(\x0b\x32\x32.jobs.v1.GetLogDirsForJobsResponse.JobLogDirsEntry\x1a\x31\n\x0fJobLogDirsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01*\x8d\x02\n\tJobStatus\x12\x1a\n\x16JOB_STATUS_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOB_STATUS_INIT\x10\x01\x12\x16\n\x12JOB_STATUS_PENDING\x10\x02\x12\x19\n\x15JOB_STATUS_SETTING_UP\x10\x03\x12\x16\n\x12JOB_STATUS_RUNNING\x10\x04\x12\x1c\n\x18JOB_STATUS_FAILED_DRIVER\x10\x05\x12\x18\n\x14JOB_STATUS_SUCCEEDED\x10\x06\x12\x15\n\x11JOB_STATUS_FAILED\x10\x07\x12\x1b\n\x17JOB_STATUS_FAILED_SETUP\x10\x08\x12\x18\n\x14JOB_STATUS_CANCELLED\x10\t2\x91\x07\n\x0bJobsService\x12\x39\n\x06\x41\x64\x64Job\x12\x16.jobs.v1.AddJobRequest\x1a\x17.jobs.v1.AddJobResponse\x12?\n\x08QueueJob\x12\x18.jobs.v1.QueueJobRequest\x1a\x19.jobs.v1.QueueJobResponse\x12K\n\x0cUpdateStatus\x12\x1c.jobs.v1.UpdateStatusRequest\x1a\x1d.jobs.v1.UpdateStatusResponse\x12H\n\x0bGetJobQueue\x12\x1b.jobs.v1.GetJobQueueRequest\x1a\x1c.jobs.v1.GetJobQueueResponse\x12\x45\n\nCancelJobs\x12\x1a.jobs.v1.CancelJobsRequest\x1a\x1b.jobs.v1.CancelJobsResponse\x12\x66\n\x15\x46\x61ilAllInProgressJobs\x12%.jobs.v1.FailAllInProgressJobsRequest\x1a&.jobs.v1.FailAllInProgressJobsResponse\x12\x41\n\x08TailLogs\x12\x18.jobs.v1.TailLogsRequest\x1a\x19.jobs.v1.TailLogsResponse0\x01\x12K\n\x0cGetJobStatus\x12\x1c.jobs.v1.GetJobStatusRequest\x1a\x1d.jobs.v1.GetJobStatusResponse\x12o\n\x18GetJobSubmittedTimestamp\x12(.jobs.v1.GetJobSubmittedTimestampRequest\x1a).jobs.v1.GetJobSubmittedTimestampResponse\x12\x63\n\x14GetJobEndedTimestamp\x12$.jobs.v1.GetJobEndedTimestampRequest\x1a%.jobs.v1.GetJobEndedTimestampResponse\x12Z\n\x11GetLogDirsForJobs\x12!.jobs.v1.GetLogDirsForJobsRequest\x1a\".jobs.v1.GetLogDirsForJobsResponseb\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,8 +25,8 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_GETJOBSTATUSRESPONSE_JOBSTATUSESENTRY']._serialized_options = b'8\001'
   _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._loaded_options = None
   _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._serialized_options = b'8\001'
-  _globals['_JOBSTATUS']._serialized_start=2185
-  _globals['_JOBSTATUS']._serialized_end=2454
+  _globals['_JOBSTATUS']._serialized_start=2219
+  _globals['_JOBSTATUS']._serialized_end=2488
   _globals['_ADDJOBREQUEST']._serialized_start=48
   _globals['_ADDJOBREQUEST']._serialized_end=181
   _globals['_ADDJOBRESPONSE']._serialized_start=183
@@ -34,53 +34,53 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_QUEUEJOBREQUEST']._serialized_start=235
   _globals['_QUEUEJOBREQUEST']._serialized_end=414
   _globals['_MANAGEDJOBINFO']._serialized_start=417
-  _globals['_MANAGEDJOBINFO']._serialized_end=554
-  _globals['_MANAGEDJOBTASK']._serialized_start=556
-  _globals['_MANAGEDJOBTASK']._serialized_end=649
-  _globals['_QUEUEJOBRESPONSE']._serialized_start=651
-  _globals['_QUEUEJOBRESPONSE']._serialized_end=669
-  _globals['_UPDATESTATUSREQUEST']._serialized_start=671
-  _globals['_UPDATESTATUSREQUEST']._serialized_end=692
-  _globals['_UPDATESTATUSRESPONSE']._serialized_start=694
-  _globals['_UPDATESTATUSRESPONSE']._serialized_end=716
-  _globals['_GETJOBQUEUEREQUEST']._serialized_start=718
-  _globals['_GETJOBQUEUEREQUEST']._serialized_end=794
-  _globals['_JOBINFO']._serialized_start=797
-  _globals['_JOBINFO']._serialized_end=1088
-  _globals['_GETJOBQUEUERESPONSE']._serialized_start=1090
-  _globals['_GETJOBQUEUERESPONSE']._serialized_end=1143
-  _globals['_CANCELJOBSREQUEST']._serialized_start=1145
-  _globals['_CANCELJOBSREQUEST']._serialized_end=1239
-  _globals['_CANCELJOBSRESPONSE']._serialized_start=1241
-  _globals['_CANCELJOBSRESPONSE']._serialized_end=1288
-  _globals['_FAILALLINPROGRESSJOBSREQUEST']._serialized_start=1290
-  _globals['_FAILALLINPROGRESSJOBSREQUEST']._serialized_end=1320
-  _globals['_FAILALLINPROGRESSJOBSRESPONSE']._serialized_start=1322
-  _globals['_FAILALLINPROGRESSJOBSRESPONSE']._serialized_end=1353
-  _globals['_TAILLOGSREQUEST']._serialized_start=1355
-  _globals['_TAILLOGSREQUEST']._serialized_end=1482
-  _globals['_TAILLOGSRESPONSE']._serialized_start=1484
-  _globals['_TAILLOGSRESPONSE']._serialized_end=1539
-  _globals['_GETJOBSTATUSREQUEST']._serialized_start=1541
-  _globals['_GETJOBSTATUSREQUEST']._serialized_end=1579
-  _globals['_GETJOBSTATUSRESPONSE']._serialized_start=1582
-  _globals['_GETJOBSTATUSRESPONSE']._serialized_end=1746
-  _globals['_GETJOBSTATUSRESPONSE_JOBSTATUSESENTRY']._serialized_start=1676
-  _globals['_GETJOBSTATUSRESPONSE_JOBSTATUSESENTRY']._serialized_end=1746
-  _globals['_GETJOBSUBMITTEDTIMESTAMPREQUEST']._serialized_start=1748
-  _globals['_GETJOBSUBMITTEDTIMESTAMPREQUEST']._serialized_end=1813
-  _globals['_GETJOBSUBMITTEDTIMESTAMPRESPONSE']._serialized_start=1815
-  _globals['_GETJOBSUBMITTEDTIMESTAMPRESPONSE']._serialized_end=1868
-  _globals['_GETJOBENDEDTIMESTAMPREQUEST']._serialized_start=1870
-  _globals['_GETJOBENDEDTIMESTAMPREQUEST']._serialized_end=1931
-  _globals['_GETJOBENDEDTIMESTAMPRESPONSE']._serialized_start=1933
-  _globals['_GETJOBENDEDTIMESTAMPRESPONSE']._serialized_end=1982
-  _globals['_GETLOGDIRSFORJOBSREQUEST']._serialized_start=1984
-  _globals['_GETLOGDIRSFORJOBSREQUEST']._serialized_end=2027
-  _globals['_GETLOGDIRSFORJOBSRESPONSE']._serialized_start=2030
-  _globals['_GETLOGDIRSFORJOBSRESPONSE']._serialized_end=2182
-  _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._serialized_start=2133
-  _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._serialized_end=2182
-  _globals['_JOBSSERVICE']._serialized_start=2457
-  _globals['_JOBSSERVICE']._serialized_end=3370
+  _globals['_MANAGEDJOBINFO']._serialized_end=588
+  _globals['_MANAGEDJOBTASK']._serialized_start=590
+  _globals['_MANAGEDJOBTASK']._serialized_end=683
+  _globals['_QUEUEJOBRESPONSE']._serialized_start=685
+  _globals['_QUEUEJOBRESPONSE']._serialized_end=703
+  _globals['_UPDATESTATUSREQUEST']._serialized_start=705
+  _globals['_UPDATESTATUSREQUEST']._serialized_end=726
+  _globals['_UPDATESTATUSRESPONSE']._serialized_start=728
+  _globals['_UPDATESTATUSRESPONSE']._serialized_end=750
+  _globals['_GETJOBQUEUEREQUEST']._serialized_start=752
+  _globals['_GETJOBQUEUEREQUEST']._serialized_end=828
+  _globals['_JOBINFO']._serialized_start=831
+  _globals['_JOBINFO']._serialized_end=1122
+  _globals['_GETJOBQUEUERESPONSE']._serialized_start=1124
+  _globals['_GETJOBQUEUERESPONSE']._serialized_end=1177
+  _globals['_CANCELJOBSREQUEST']._serialized_start=1179
+  _globals['_CANCELJOBSREQUEST']._serialized_end=1273
+  _globals['_CANCELJOBSRESPONSE']._serialized_start=1275
+  _globals['_CANCELJOBSRESPONSE']._serialized_end=1322
+  _globals['_FAILALLINPROGRESSJOBSREQUEST']._serialized_start=1324
+  _globals['_FAILALLINPROGRESSJOBSREQUEST']._serialized_end=1354
+  _globals['_FAILALLINPROGRESSJOBSRESPONSE']._serialized_start=1356
+  _globals['_FAILALLINPROGRESSJOBSRESPONSE']._serialized_end=1387
+  _globals['_TAILLOGSREQUEST']._serialized_start=1389
+  _globals['_TAILLOGSREQUEST']._serialized_end=1516
+  _globals['_TAILLOGSRESPONSE']._serialized_start=1518
+  _globals['_TAILLOGSRESPONSE']._serialized_end=1573
+  _globals['_GETJOBSTATUSREQUEST']._serialized_start=1575
+  _globals['_GETJOBSTATUSREQUEST']._serialized_end=1613
+  _globals['_GETJOBSTATUSRESPONSE']._serialized_start=1616
+  _globals['_GETJOBSTATUSRESPONSE']._serialized_end=1780
+  _globals['_GETJOBSTATUSRESPONSE_JOBSTATUSESENTRY']._serialized_start=1710
+  _globals['_GETJOBSTATUSRESPONSE_JOBSTATUSESENTRY']._serialized_end=1780
+  _globals['_GETJOBSUBMITTEDTIMESTAMPREQUEST']._serialized_start=1782
+  _globals['_GETJOBSUBMITTEDTIMESTAMPREQUEST']._serialized_end=1847
+  _globals['_GETJOBSUBMITTEDTIMESTAMPRESPONSE']._serialized_start=1849
+  _globals['_GETJOBSUBMITTEDTIMESTAMPRESPONSE']._serialized_end=1902
+  _globals['_GETJOBENDEDTIMESTAMPREQUEST']._serialized_start=1904
+  _globals['_GETJOBENDEDTIMESTAMPREQUEST']._serialized_end=1965
+  _globals['_GETJOBENDEDTIMESTAMPRESPONSE']._serialized_start=1967
+  _globals['_GETJOBENDEDTIMESTAMPRESPONSE']._serialized_end=2016
+  _globals['_GETLOGDIRSFORJOBSREQUEST']._serialized_start=2018
+  _globals['_GETLOGDIRSFORJOBSREQUEST']._serialized_end=2061
+  _globals['_GETLOGDIRSFORJOBSRESPONSE']._serialized_start=2064
+  _globals['_GETLOGDIRSFORJOBSRESPONSE']._serialized_end=2216
+  _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._serialized_start=2167
+  _globals['_GETLOGDIRSFORJOBSRESPONSE_JOBLOGDIRSENTRY']._serialized_end=2216
+  _globals['_JOBSSERVICE']._serialized_start=2491
+  _globals['_JOBSSERVICE']._serialized_end=3404
 # @@protoc_insertion_point(module_scope)
sky/schemas/generated/jobsv1_pb2.pyi CHANGED
@@ -66,18 +66,20 @@ class QueueJobRequest(_message.Message):
     def __init__(self, job_id: _Optional[int] = ..., codegen: _Optional[str] = ..., script_path: _Optional[str] = ..., remote_log_dir: _Optional[str] = ..., managed_job: _Optional[_Union[ManagedJobInfo, _Mapping]] = ...) -> None: ...

 class ManagedJobInfo(_message.Message):
-    __slots__ = ("name", "pool", "workspace", "entrypoint", "tasks")
+    __slots__ = ("name", "pool", "workspace", "entrypoint", "tasks", "user_id")
     NAME_FIELD_NUMBER: _ClassVar[int]
     POOL_FIELD_NUMBER: _ClassVar[int]
     WORKSPACE_FIELD_NUMBER: _ClassVar[int]
     ENTRYPOINT_FIELD_NUMBER: _ClassVar[int]
     TASKS_FIELD_NUMBER: _ClassVar[int]
+    USER_ID_FIELD_NUMBER: _ClassVar[int]
     name: str
     pool: str
     workspace: str
     entrypoint: str
     tasks: _containers.RepeatedCompositeFieldContainer[ManagedJobTask]
-    def __init__(self, name: _Optional[str] = ..., pool: _Optional[str] = ..., workspace: _Optional[str] = ..., entrypoint: _Optional[str] = ..., tasks: _Optional[_Iterable[_Union[ManagedJobTask, _Mapping]]] = ...) -> None: ...
+    user_id: str
+    def __init__(self, name: _Optional[str] = ..., pool: _Optional[str] = ..., workspace: _Optional[str] = ..., entrypoint: _Optional[str] = ..., tasks: _Optional[_Iterable[_Union[ManagedJobTask, _Mapping]]] = ..., user_id: _Optional[str] = ...) -> None: ...

 class ManagedJobTask(_message.Message):
     __slots__ = ("task_id", "name", "resources_str", "metadata_json")
sky/serve/server/server.py CHANGED
@@ -98,6 +98,7 @@ async def tail_logs(
     request: fastapi.Request, log_body: payloads.ServeLogsBody,
     background_tasks: fastapi.BackgroundTasks
 ) -> fastapi.responses.StreamingResponse:
+    executor.check_request_thread_executor_available()
     request_task = executor.prepare_request(
         request_id=request.state.request_id,
         request_name='serve.logs',
sky/server/requests/executor.py CHANGED
@@ -48,6 +48,7 @@ from sky.server.requests import payloads
 from sky.server.requests import preconditions
 from sky.server.requests import process
 from sky.server.requests import requests as api_requests
+from sky.server.requests import threads
 from sky.server.requests.queues import local_queue
 from sky.server.requests.queues import mp_queue
 from sky.skylet import constants
@@ -81,23 +82,28 @@ logger = sky_logging.init_logger(__name__)
 # platforms, including macOS.
 multiprocessing.set_start_method('spawn', force=True)

-# Max threads that is equivalent to the number of thread workers in the
-# default thread pool executor of event loop.
-_REQUEST_THREADS_LIMIT = min(32, (os.cpu_count() or 0) + 4)
+# An upper limit of max threads for request execution per server process that
+# is unlikely to be reached, to allow higher concurrency while still preventing
+# the server process from becoming overloaded.
+_REQUEST_THREADS_LIMIT = 128

 _REQUEST_THREAD_EXECUTOR_LOCK = threading.Lock()
-# A dedicated thread pool executor for synced requests execution in coroutine
-_REQUEST_THREAD_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None
+# A dedicated thread pool executor for synced requests execution in coroutine to
+# avoid:
+# 1. blocking the event loop;
+# 2. exhausting the default thread pool executor of event loop;
+_REQUEST_THREAD_EXECUTOR: Optional[threads.OnDemandThreadExecutor] = None


-def get_request_thread_executor() -> concurrent.futures.ThreadPoolExecutor:
+def get_request_thread_executor() -> threads.OnDemandThreadExecutor:
     """Lazy init and return the request thread executor for current process."""
     global _REQUEST_THREAD_EXECUTOR
     if _REQUEST_THREAD_EXECUTOR is not None:
         return _REQUEST_THREAD_EXECUTOR
     with _REQUEST_THREAD_EXECUTOR_LOCK:
         if _REQUEST_THREAD_EXECUTOR is None:
-            _REQUEST_THREAD_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
+            _REQUEST_THREAD_EXECUTOR = threads.OnDemandThreadExecutor(
+                name='request_thread_executor',
                 max_workers=_REQUEST_THREADS_LIMIT)
     return _REQUEST_THREAD_EXECUTOR

@@ -561,6 +567,21 @@ class CoroutineTask:
         pass


+def check_request_thread_executor_available() -> None:
+    """Check if the request thread executor is available.
+
+    This is a best effort check to hint the client to retry other server
+    processes when there is no available thread worker in the current one. But
+    a request may pass this check and still fail to get a worker at execution
+    time due to a race condition. In this case, the client will see a failed
+    request instead of a retry.
+
+    TODO(aylei): this can be refined with a refactor of our coroutine
+    execution flow.
+    """
+    get_request_thread_executor().check_available()
+
+
 def execute_request_in_coroutine(
         request: api_requests.Request) -> CoroutineTask:
     """Execute a request in current event loop.
@@ -575,6 +596,18 @@ def execute_request_in_coroutine(
     return CoroutineTask(task)


+def _execute_with_config_override(func: Callable,
+                                  request_body: payloads.RequestBody,
+                                  request_id: str, request_name: str,
+                                  **kwargs) -> Any:
+    """Execute a function with env and config override inside a thread."""
+    # Override the environment and config within this thread's context,
+    # which gets copied when we call to_thread.
+    with override_request_env_and_config(request_body, request_id,
+                                         request_name):
+        return func(**kwargs)
+
+
 async def _execute_request_coroutine(request: api_requests.Request):
     """Execute a request in current event loop.
@@ -592,14 +625,17 @@ async def _execute_request_coroutine(request: api_requests.Request):
     request_task.status = api_requests.RequestStatus.RUNNING
     # Redirect stdout and stderr to the request log path.
     original_output = ctx.redirect_log(request.log_path)
-    # Override environment variables that backs env_options.Options
-    # TODO(aylei): compared to process executor, running task in coroutine has
-    # two issues to fix:
-    # 1. skypilot config is not contextual
-    # 2. envs that read directly from os.environ are not contextual
-    ctx.override_envs(request_body.env_vars)
-    fut: asyncio.Future = context_utils.to_thread_with_executor(
-        get_request_thread_executor(), func, **request_body.to_kwargs())
+    try:
+        fut: asyncio.Future = context_utils.to_thread_with_executor(
+            get_request_thread_executor(), _execute_with_config_override, func,
+            request_body, request.request_id, request.name,
+            **request_body.to_kwargs())
+    except Exception as e:  # pylint: disable=broad-except
+        ctx.redirect_log(original_output)
+        api_requests.set_request_failed(request.request_id, e)
+        logger.error(f'Failed to run request {request.request_id} due to '
+                     f'{common_utils.format_exception(e)}')
+        return

     async def poll_task(request_id: str) -> bool:
         req_status = await api_requests.get_request_status_async(request_id)
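Running a blocking request handler on an explicit executor while preserving the request's context (overridden env vars, config, redirected logs) requires copying the calling context into the worker thread. context_utils.to_thread_with_executor itself is not shown in this diff; below is a minimal sketch of how such a helper can be built from contextvars.copy_context and loop.run_in_executor. This is an assumption based on the call sites above, not SkyPilot's actual implementation:

import asyncio
import concurrent.futures
import contextvars
import functools

def to_thread_with_executor(executor: concurrent.futures.Executor,
                            func, *args, **kwargs) -> asyncio.Future:
    # Snapshot the caller's contextvars so the worker thread observes the
    # same request-scoped state (env/config overrides, redirected logs, ...),
    # then schedule the call on the explicit executor instead of the event
    # loop's default thread pool.
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    call = functools.partial(ctx.run, func, *args, **kwargs)
    return loop.run_in_executor(executor, call)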
sky/server/requests/preconditions.py CHANGED
@@ -112,10 +112,8 @@ class Precondition(abc.ABC):
                     return True
                 if status_msg is not None and status_msg != last_status_msg:
                     # Update the status message if it has changed.
-                    async with api_requests.update_request_async(
-                            self.request_id) as req:
-                        assert req is not None, self.request_id
-                        req.status_msg = status_msg
+                    await api_requests.update_status_msg_async(
+                        self.request_id, status_msg)
                     last_status_msg = status_msg
         except (Exception, SystemExit, KeyboardInterrupt) as e:  # pylint: disable=broad-except
             api_requests.set_request_failed(self.request_id, e)
sky/server/requests/requests.py CHANGED
@@ -14,8 +14,8 @@ import sqlite3
 import threading
 import time
 import traceback
-from typing import (Any, AsyncContextManager, Callable, Dict, Generator, List,
-                    NamedTuple, Optional, Tuple)
+from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
+                    Tuple)

 import anyio
 import colorama
@@ -32,6 +32,7 @@ from sky.server import daemons
 from sky.server.requests import payloads
 from sky.server.requests.serializers import decoders
 from sky.server.requests.serializers import encoders
+from sky.utils import asyncio_utils
 from sky.utils import common_utils
 from sky.utils import ux_utils
 from sky.utils.db import db_utils
@@ -578,27 +579,14 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:

 @init_db
 @metrics_lib.time_me
-def update_request_async(
-        request_id: str) -> AsyncContextManager[Optional[Request]]:
-    """Async version of update_request.
-
-    Returns an async context manager that yields the request record and
-    persists any in-place updates upon exit.
-    """
-
-    @contextlib.asynccontextmanager
-    async def _cm():
-        # Acquire the lock to avoid race conditions between multiple request
-        # operations, e.g. execute and cancel.
-        async with filelock.AsyncFileLock(request_lock_path(request_id)):
-            request = await _get_request_no_lock_async(request_id)
-            try:
-                yield request
-            finally:
-                if request is not None:
-                    await _add_or_update_request_no_lock_async(request)
-
-    return _cm()
+@asyncio_utils.shield
+async def update_status_msg_async(request_id: str, status_msg: str) -> None:
+    """Update the status message of a request"""
+    async with filelock.AsyncFileLock(request_lock_path(request_id)):
+        request = await _get_request_no_lock_async(request_id)
+        if request is not None:
+            request.status_msg = status_msg
+            await _add_or_update_request_no_lock_async(request)


 _get_request_sql = (f'SELECT {", ".join(REQUEST_COLUMNS)} FROM {REQUEST_TABLE} '
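asyncio_utils.shield is added in this release (sky/utils/asyncio_utils.py, +18 lines, not shown in the diff) and is applied to these DB writes, evidently so that a cancelled caller cannot interrupt them mid-transaction. A plausible minimal sketch of such a decorator built on asyncio.shield, hedged as an assumption from the name and usage rather than the actual SkyPilot code:

import asyncio
import functools

def shield(coro_func):
    """Run the wrapped coroutine under asyncio.shield.

    If the awaiting caller is cancelled, the caller still sees
    CancelledError, but the inner coroutine runs to completion, so a
    half-finished write cannot be left behind.
    """
    @functools.wraps(coro_func)
    async def wrapper(*args, **kwargs):
        return await asyncio.shield(coro_func(*args, **kwargs))
    return wrapper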
@@ -651,8 +639,10 @@ def get_request(request_id: str) -> Optional[Request]:

 @init_db_async
 @metrics_lib.time_me_async
+@asyncio_utils.shield
 async def get_request_async(request_id: str) -> Optional[Request]:
     """Async version of get_request."""
+    # TODO(aylei): figure out how to remove FileLock here to avoid the overhead
     async with filelock.AsyncFileLock(request_lock_path(request_id)):
         return await _get_request_no_lock_async(request_id)

@@ -704,6 +694,7 @@ def create_if_not_exists(request: Request) -> bool:

 @init_db_async
 @metrics_lib.time_me_async
+@asyncio_utils.shield
 async def create_if_not_exists_async(request: Request) -> bool:
     """Async version of create_if_not_exists."""
     async with filelock.AsyncFileLock(request_lock_path(request.request_id)):
sky/server/requests/threads.py ADDED
@@ -0,0 +1,106 @@
+"""Request execution threads management."""
+
+import concurrent.futures
+import threading
+from typing import Callable, Set
+
+from sky import exceptions
+from sky import sky_logging
+from sky.utils import atomic
+
+logger = sky_logging.init_logger(__name__)
+
+
+class OnDemandThreadExecutor(concurrent.futures.Executor):
+    """An executor that creates a new thread for each task and destroys it
+    after the task is completed.
+
+    Note(dev):
+    We raise an error instead of queuing the request if the limit is reached,
+    so that:
+    1. the request might be handled by other processes that have idle workers
+       upon retry;
+    2. if not, then users can be clearly hinted that they need to scale the
+       API server to support higher concurrency.
+    So this executor is only suitable for carefully selected cases where the
+    error can be properly handled by the caller. To make this executor
+    general, we would need to support configuring the queuing behavior
+    (exception or queueing).
+    """
+
+    def __init__(self, name: str, max_workers: int):
+        self.name: str = name
+        self.max_workers: int = max_workers
+        self.running: atomic.AtomicInt = atomic.AtomicInt(0)
+        self._shutdown: bool = False
+        self._shutdown_lock: threading.Lock = threading.Lock()
+        self._threads: Set[threading.Thread] = set()
+        self._threads_lock: threading.Lock = threading.Lock()
+
+    def _cleanup_thread(self, thread: threading.Thread):
+        with self._threads_lock:
+            self._threads.discard(thread)
+
+    def _task_wrapper(self, fn: Callable, fut: concurrent.futures.Future, /,
+                      *args, **kwargs):
+        try:
+            result = fn(*args, **kwargs)
+            fut.set_result(result)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.debug(f'Executor [{self.name}] error executing {fn}: {e}')
+            fut.set_exception(e)
+        finally:
+            self.running.decrement()
+            self._cleanup_thread(threading.current_thread())
+
+    def check_available(self, borrow: bool = False) -> int:
+        """Check if there are available workers.
+
+        Args:
+            borrow: If True, the caller borrows a worker from the executor.
+                The caller is responsible for returning the worker to the
+                executor after the task is completed.
+        """
+        count = self.running.increment()
+        if count > self.max_workers:
+            self.running.decrement()
+            raise exceptions.ConcurrentWorkerExhaustedError(
+                f'Maximum concurrent workers {self.max_workers} of threads '
+                f'executor [{self.name}] reached')
+        if not borrow:
+            self.running.decrement()
+        return count
+
+    def submit(self, fn, /, *args, **kwargs):
+        with self._shutdown_lock:
+            if self._shutdown:
+                raise RuntimeError(
+                    'Cannot submit task after executor is shutdown')
+            count = self.check_available(borrow=True)
+            fut: concurrent.futures.Future = concurrent.futures.Future()
+            # Name is assigned for debugging purpose, duplication is fine
+            thread = threading.Thread(target=self._task_wrapper,
+                                      name=f'{self.name}-{count}',
+                                      args=(fn, fut, *args),
+                                      kwargs=kwargs,
+                                      daemon=True)
+            with self._threads_lock:
+                self._threads.add(thread)
+            try:
+                thread.start()
+            except Exception as e:
+                self.running.decrement()
+                self._cleanup_thread(thread)
+                fut.set_exception(e)
+                raise
+            assert thread.ident is not None, 'Thread should be started'
+            return fut
+
+    def shutdown(self, wait=True):
+        with self._shutdown_lock:
+            self._shutdown = True
+        if not wait:
+            return
+        with self._threads_lock:
+            threads = list(self._threads)
+        for t in threads:
+            t.join()
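Callers first probe availability (as serve.logs does above via check_request_thread_executor_available) and translate ConcurrentWorkerExhaustedError into a retryable response; submit then borrows a worker for the task's lifetime. A usage sketch of the fail-fast behavior, assuming the sky modules above are importable from this wheel:

import time

from sky import exceptions
from sky.server.requests import threads

executor = threads.OnDemandThreadExecutor(name='example', max_workers=2)

def work(seconds: float) -> str:
    time.sleep(seconds)
    return f'slept {seconds}s'

futures = []
try:
    for _ in range(3):  # the third submit exceeds max_workers=2
        futures.append(executor.submit(work, 0.5))
except exceptions.ConcurrentWorkerExhaustedError as e:
    # Rejected instead of queued: the client is expected to retry,
    # ideally landing on a server process with idle workers.
    print(f'rejected: {e}')

print([f.result() for f in futures])
executor.shutdown(wait=True)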