dstack 0.18.41__py3-none-any.whl → 0.18.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. dstack/_internal/cli/commands/__init__.py +2 -1
  2. dstack/_internal/cli/commands/apply.py +4 -2
  3. dstack/_internal/cli/commands/attach.py +21 -1
  4. dstack/_internal/cli/commands/completion.py +20 -0
  5. dstack/_internal/cli/commands/delete.py +3 -1
  6. dstack/_internal/cli/commands/fleet.py +2 -1
  7. dstack/_internal/cli/commands/gateway.py +7 -2
  8. dstack/_internal/cli/commands/logs.py +3 -2
  9. dstack/_internal/cli/commands/stats.py +2 -1
  10. dstack/_internal/cli/commands/stop.py +2 -1
  11. dstack/_internal/cli/commands/volume.py +2 -1
  12. dstack/_internal/cli/main.py +6 -0
  13. dstack/_internal/cli/services/completion.py +86 -0
  14. dstack/_internal/cli/services/configurators/run.py +10 -17
  15. dstack/_internal/cli/utils/fleet.py +5 -1
  16. dstack/_internal/cli/utils/volume.py +9 -0
  17. dstack/_internal/core/backends/aws/compute.py +24 -11
  18. dstack/_internal/core/backends/aws/resources.py +3 -3
  19. dstack/_internal/core/backends/azure/compute.py +14 -8
  20. dstack/_internal/core/backends/azure/resources.py +2 -0
  21. dstack/_internal/core/backends/base/compute.py +102 -2
  22. dstack/_internal/core/backends/base/offers.py +7 -1
  23. dstack/_internal/core/backends/cudo/compute.py +8 -4
  24. dstack/_internal/core/backends/datacrunch/compute.py +10 -4
  25. dstack/_internal/core/backends/gcp/auth.py +19 -13
  26. dstack/_internal/core/backends/gcp/compute.py +27 -20
  27. dstack/_internal/core/backends/gcp/resources.py +3 -10
  28. dstack/_internal/core/backends/kubernetes/compute.py +4 -3
  29. dstack/_internal/core/backends/lambdalabs/compute.py +9 -3
  30. dstack/_internal/core/backends/nebius/compute.py +2 -2
  31. dstack/_internal/core/backends/oci/compute.py +10 -4
  32. dstack/_internal/core/backends/runpod/compute.py +11 -4
  33. dstack/_internal/core/backends/tensordock/compute.py +14 -3
  34. dstack/_internal/core/backends/vastai/compute.py +12 -2
  35. dstack/_internal/core/backends/vultr/api_client.py +3 -3
  36. dstack/_internal/core/backends/vultr/compute.py +9 -3
  37. dstack/_internal/core/models/backends/aws.py +2 -0
  38. dstack/_internal/core/models/backends/base.py +1 -0
  39. dstack/_internal/core/models/configurations.py +0 -1
  40. dstack/_internal/core/models/runs.py +3 -3
  41. dstack/_internal/core/models/volumes.py +23 -0
  42. dstack/_internal/core/services/__init__.py +5 -1
  43. dstack/_internal/core/services/configs/__init__.py +3 -0
  44. dstack/_internal/server/background/tasks/common.py +22 -0
  45. dstack/_internal/server/background/tasks/process_instances.py +13 -21
  46. dstack/_internal/server/background/tasks/process_running_jobs.py +13 -16
  47. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -7
  48. dstack/_internal/server/background/tasks/process_terminating_jobs.py +7 -2
  49. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  50. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  51. dstack/_internal/server/models.py +17 -19
  52. dstack/_internal/server/routers/logs.py +3 -0
  53. dstack/_internal/server/services/backends/configurators/aws.py +31 -1
  54. dstack/_internal/server/services/backends/configurators/gcp.py +8 -15
  55. dstack/_internal/server/services/config.py +11 -1
  56. dstack/_internal/server/services/fleets.py +5 -1
  57. dstack/_internal/server/services/jobs/__init__.py +14 -11
  58. dstack/_internal/server/services/jobs/configurators/dev.py +1 -3
  59. dstack/_internal/server/services/jobs/configurators/task.py +1 -3
  60. dstack/_internal/server/services/logs/__init__.py +78 -0
  61. dstack/_internal/server/services/{logs.py → logs/aws.py} +12 -207
  62. dstack/_internal/server/services/logs/base.py +47 -0
  63. dstack/_internal/server/services/logs/filelog.py +110 -0
  64. dstack/_internal/server/services/logs/gcp.py +165 -0
  65. dstack/_internal/server/services/offers.py +7 -7
  66. dstack/_internal/server/services/pools.py +19 -20
  67. dstack/_internal/server/services/proxy/routers/service_proxy.py +14 -7
  68. dstack/_internal/server/services/runner/client.py +8 -5
  69. dstack/_internal/server/services/volumes.py +68 -9
  70. dstack/_internal/server/settings.py +3 -0
  71. dstack/_internal/server/statics/index.html +1 -1
  72. dstack/_internal/server/statics/{main-ad5150a441de98cd8987.css → main-7510e71dfa9749a4e70e.css} +1 -1
  73. dstack/_internal/server/statics/{main-2ac66bfcbd2e39830b88.js → main-fe8fd9db55df8d10e648.js} +66 -66
  74. dstack/_internal/server/statics/{main-2ac66bfcbd2e39830b88.js.map → main-fe8fd9db55df8d10e648.js.map} +1 -1
  75. dstack/_internal/server/testing/common.py +46 -17
  76. dstack/api/_public/runs.py +1 -1
  77. dstack/version.py +2 -2
  78. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/METADATA +4 -3
  79. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/RECORD +97 -86
  80. tests/_internal/core/backends/base/__init__.py +0 -0
  81. tests/_internal/core/backends/base/test_compute.py +56 -0
  82. tests/_internal/server/background/tasks/test_process_running_jobs.py +2 -1
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +5 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +11 -6
  85. tests/_internal/server/conftest.py +4 -5
  86. tests/_internal/server/routers/test_backends.py +1 -0
  87. tests/_internal/server/routers/test_logs.py +1 -1
  88. tests/_internal/server/routers/test_runs.py +2 -2
  89. tests/_internal/server/routers/test_volumes.py +9 -2
  90. tests/_internal/server/services/runner/test_client.py +22 -3
  91. tests/_internal/server/services/test_logs.py +3 -3
  92. tests/_internal/server/services/test_offers.py +167 -0
  93. tests/_internal/server/services/test_pools.py +105 -1
  94. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/LICENSE.md +0 -0
  95. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/WHEEL +0 -0
  96. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/entry_points.txt +0 -0
  97. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/top_level.txt +0 -0
dstack/_internal/server/services/{logs.py → logs/aws.py}

@@ -1,26 +1,26 @@
- import atexit
- import base64
  import itertools
  import operator
- from abc import ABC, abstractmethod
  from contextlib import contextmanager
  from datetime import datetime, timedelta, timezone
- from pathlib import Path
- from typing import Iterator, List, Optional, Set, Tuple, TypedDict, Union
+ from typing import Iterator, List, Optional, Set, Tuple, TypedDict
  from uuid import UUID

- from dstack._internal.core.errors import DstackError
  from dstack._internal.core.models.logs import (
      JobSubmissionLogs,
      LogEvent,
      LogEventSource,
      LogProducer,
  )
- from dstack._internal.server import settings
  from dstack._internal.server.models import ProjectModel
  from dstack._internal.server.schemas.logs import PollLogsRequest
  from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
- from dstack._internal.utils.common import run_async
+ from dstack._internal.server.services.logs.base import (
+     LogStorage,
+     LogStorageError,
+     b64encode_raw_message,
+     datetime_to_unix_time_ms,
+     unix_time_ms_to_datetime,
+ )
  from dstack._internal.utils.logging import get_logger

  BOTO_AVAILABLE = True
@@ -33,30 +33,6 @@ except ImportError:
  logger = get_logger(__name__)


- class LogStorageError(DstackError):
-     pass
-
-
- class LogStorage(ABC):
-     @abstractmethod
-     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
-         pass
-
-     @abstractmethod
-     def write_logs(
-         self,
-         project: ProjectModel,
-         run_name: str,
-         job_submission_id: UUID,
-         runner_logs: List[RunnerLogEvent],
-         job_logs: List[RunnerLogEvent],
-     ) -> None:
-         pass
-
-     def close(self) -> None:
-         pass
-
-
  class _CloudWatchLogEvent(TypedDict):
      timestamp: int  # unix time in milliseconds
      message: str
@@ -119,7 +95,7 @@ class CloudWatchLogStorage(LogStorage):
          cw_events_iter = iter(cw_events)
          logs = [
              LogEvent(
-                 timestamp=_unix_time_ms_to_datetime(cw_event["timestamp"]),
+                 timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]),
                  log_source=LogEventSource.STDOUT,
                  message=cw_event["message"],
              )
@@ -138,11 +114,11 @@ class CloudWatchLogStorage(LogStorage):
          if request.start_time:
              # XXX: Since callers use start_time/end_time for pagination, one millisecond is added
              # to avoid an infinite loop because startTime boundary is inclusive.
-             parameters["startTime"] = _datetime_to_unix_time_ms(request.start_time) + 1
+             parameters["startTime"] = datetime_to_unix_time_ms(request.start_time) + 1
          if request.end_time:
              # No need to substract one millisecond in this case, though, seems that endTime is
              # exclusive, that is, time interval boundaries are [startTime, entTime)
-             parameters["endTime"] = _datetime_to_unix_time_ms(request.end_time)
+             parameters["endTime"] = datetime_to_unix_time_ms(request.end_time)
          response = self._client.get_log_events(**parameters)
          events: List[_CloudWatchLogEvent] = response["events"]
          if start_from_head or events:
@@ -294,7 +270,7 @@ class CloudWatchLogStorage(LogStorage):
      ) -> _CloudWatchLogEvent:
          return {
              "timestamp": runner_log_event.timestamp,
-             "message": _b64encode_raw_message(runner_log_event.message),
+             "message": b64encode_raw_message(runner_log_event.message),
          }

      @contextmanager
@@ -339,174 +315,3 @@ class CloudWatchLogStorage(LogStorage):
          producer: LogProducer,
      ) -> str:
          return f"{project_name}/{run_name}/{job_submission_id}/{producer.value}"
-
-
- class FileLogStorage(LogStorage):
-     root: Path
-
-     def __init__(self, root: Union[Path, str, None] = None) -> None:
-         if root is None:
-             self.root = settings.SERVER_DIR_PATH
-         else:
-             self.root = Path(root)
-
-     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
-         # TODO Respect request.limit to support pagination
-         log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
-         log_file_path = self._get_log_file_path(
-             project_name=project.name,
-             run_name=request.run_name,
-             job_submission_id=request.job_submission_id,
-             producer=log_producer,
-         )
-         logs = []
-         try:
-             with open(log_file_path) as f:
-                 for line in f:
-                     log_event = LogEvent.__response__.parse_raw(line)
-                     if request.start_time and log_event.timestamp <= request.start_time:
-                         continue
-                     if request.end_time is None or log_event.timestamp < request.end_time:
-                         logs.append(log_event)
-                     else:
-                         break
-         except IOError:
-             pass
-         if request.descending:
-             logs = list(reversed(logs))
-         return JobSubmissionLogs(logs=logs)
-
-     def write_logs(
-         self,
-         project: ProjectModel,
-         run_name: str,
-         job_submission_id: UUID,
-         runner_logs: List[RunnerLogEvent],
-         job_logs: List[RunnerLogEvent],
-     ):
-         if len(runner_logs) > 0:
-             runner_log_file_path = self._get_log_file_path(
-                 project.name, run_name, job_submission_id, LogProducer.RUNNER
-             )
-             self._write_logs(
-                 log_file_path=runner_log_file_path,
-                 log_events=runner_logs,
-             )
-         if len(job_logs) > 0:
-             job_log_file_path = self._get_log_file_path(
-                 project.name, run_name, job_submission_id, LogProducer.JOB
-             )
-             self._write_logs(
-                 log_file_path=job_log_file_path,
-                 log_events=job_logs,
-             )
-
-     def _write_logs(self, log_file_path: Path, log_events: List[RunnerLogEvent]) -> None:
-         log_events_parsed = [self._runner_log_event_to_log_event(event) for event in log_events]
-         log_file_path.parent.mkdir(exist_ok=True, parents=True)
-         with open(log_file_path, "a") as f:
-             f.writelines(log.json() + "\n" for log in log_events_parsed)
-
-     def _get_log_file_path(
-         self,
-         project_name: str,
-         run_name: str,
-         job_submission_id: UUID,
-         producer: LogProducer,
-     ) -> Path:
-         return (
-             self.root
-             / "projects"
-             / project_name
-             / "logs"
-             / run_name
-             / str(job_submission_id)
-             / f"{producer.value}.log"
-         )
-
-     def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> LogEvent:
-         return LogEvent(
-             timestamp=_unix_time_ms_to_datetime(runner_log_event.timestamp),
-             log_source=LogEventSource.STDOUT,
-             message=_b64encode_raw_message(runner_log_event.message),
-         )
-
-
- def _unix_time_ms_to_datetime(unix_time_ms: int) -> datetime:
-     return datetime.fromtimestamp(unix_time_ms / 1000, tz=timezone.utc)
-
-
- def _datetime_to_unix_time_ms(dt: datetime) -> int:
-     return int(dt.timestamp() * 1000)
-
-
- def _b64encode_raw_message(message: bytes) -> str:
-     return base64.b64encode(message).decode()
-
-
- _default_log_storage: Optional[LogStorage] = None
-
-
- def get_default_log_storage() -> LogStorage:
-     global _default_log_storage
-     if _default_log_storage is not None:
-         return _default_log_storage
-     if settings.SERVER_CLOUDWATCH_LOG_GROUP:
-         if BOTO_AVAILABLE:
-             try:
-                 _default_log_storage = CloudWatchLogStorage(
-                     group=settings.SERVER_CLOUDWATCH_LOG_GROUP,
-                     region=settings.SERVER_CLOUDWATCH_LOG_REGION,
-                 )
-             except LogStorageError as e:
-                 logger.error("Failed to initialize CloudWatch Logs storage: %s", e)
-             else:
-                 logger.debug("Using CloudWatch Logs storage")
-         else:
-             logger.error("Cannot use CloudWatch Logs storage, boto3 is not installed")
-     if _default_log_storage is None:
-         logger.debug("Using file-based storage")
-         _default_log_storage = FileLogStorage()
-     atexit.register(_default_log_storage.close)
-     return _default_log_storage
-
-
- def poll_logs(project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
-     return get_default_log_storage().poll_logs(project=project, request=request)
-
-
- def write_logs(
-     project: ProjectModel,
-     run_name: str,
-     job_submission_id: UUID,
-     runner_logs: List[RunnerLogEvent],
-     job_logs: List[RunnerLogEvent],
- ) -> None:
-     return get_default_log_storage().write_logs(
-         project=project,
-         run_name=run_name,
-         job_submission_id=job_submission_id,
-         runner_logs=runner_logs,
-         job_logs=job_logs,
-     )
-
-
- async def poll_logs_async(project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
-     return await run_async(get_default_log_storage().poll_logs, project=project, request=request)
-
-
- async def write_logs_async(
-     project: ProjectModel,
-     run_name: str,
-     job_submission_id: UUID,
-     runner_logs: List[RunnerLogEvent],
-     job_logs: List[RunnerLogEvent],
- ) -> None:
-     return await run_async(
-         get_default_log_storage().write_logs,
-         project=project,
-         run_name=run_name,
-         job_submission_id=job_submission_id,
-         runner_logs=runner_logs,
-         job_logs=job_logs,
-     )
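The storage-selection code deleted above (get_default_log_storage and the module-level poll_logs/write_logs wrappers) does not reappear in the hunks shown here; going by the file list, it presumably moves into the new dstack/_internal/server/services/logs/__init__.py (+78 lines), whose contents are not included in this diff. Below is a standalone sketch of that selection pattern, reconstructed only from the deleted code above and using dummy classes rather than dstack's real storages:

# Sketch of the "first storage backend that initializes wins" pattern used by the removed
# get_default_log_storage(); the classes below are stubs, not dstack's real storages.
import atexit
from typing import Callable, List, Optional


class LogStorageError(Exception):
    pass


class CloudStorageStub:
    def __init__(self) -> None:
        raise LogStorageError("no credentials")  # simulate a cloud backend that fails to start

    def close(self) -> None: ...


class FileStorageStub:
    def close(self) -> None: ...


_default_storage: Optional[object] = None


def get_default_storage(candidates: List[Callable[[], object]]) -> object:
    # Cache the first candidate that initializes and register its close() to run at exit,
    # mirroring the atexit.register call in the removed code.
    global _default_storage
    if _default_storage is None:
        for factory in candidates:
            try:
                _default_storage = factory()
                break
            except LogStorageError as e:
                print(f"skipping {factory.__name__}: {e}")
        atexit.register(_default_storage.close)
    return _default_storage


print(type(get_default_storage([CloudStorageStub, FileStorageStub])).__name__)  # FileStorageStub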
dstack/_internal/server/services/logs/base.py

@@ -0,0 +1,47 @@
+ import base64
+ from abc import ABC, abstractmethod
+ from datetime import datetime, timezone
+ from typing import List
+ from uuid import UUID
+
+ from dstack._internal.core.errors import DstackError
+ from dstack._internal.core.models.logs import JobSubmissionLogs
+ from dstack._internal.server.models import ProjectModel
+ from dstack._internal.server.schemas.logs import PollLogsRequest
+ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
+
+
+ class LogStorageError(DstackError):
+     pass
+
+
+ class LogStorage(ABC):
+     @abstractmethod
+     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
+         pass
+
+     @abstractmethod
+     def write_logs(
+         self,
+         project: ProjectModel,
+         run_name: str,
+         job_submission_id: UUID,
+         runner_logs: List[RunnerLogEvent],
+         job_logs: List[RunnerLogEvent],
+     ) -> None:
+         pass
+
+     def close(self) -> None:
+         pass
+
+
+ def unix_time_ms_to_datetime(unix_time_ms: int) -> datetime:
+     return datetime.fromtimestamp(unix_time_ms / 1000, tz=timezone.utc)
+
+
+ def datetime_to_unix_time_ms(dt: datetime) -> int:
+     return int(dt.timestamp() * 1000)
+
+
+ def b64encode_raw_message(message: bytes) -> str:
+     return base64.b64encode(message).decode()
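The helpers at the bottom of logs/base.py are shared by the CloudWatch, GCP, and file backends for timestamp and message handling. A quick standalone check of how they behave (the functions are copied from the new module above, so the snippet runs without dstack installed):

import base64
from datetime import datetime, timezone


def unix_time_ms_to_datetime(unix_time_ms: int) -> datetime:
    return datetime.fromtimestamp(unix_time_ms / 1000, tz=timezone.utc)


def datetime_to_unix_time_ms(dt: datetime) -> int:
    return int(dt.timestamp() * 1000)


def b64encode_raw_message(message: bytes) -> str:
    return base64.b64encode(message).decode()


# Runner timestamps are unix milliseconds; they survive the round trip through datetime.
assert datetime_to_unix_time_ms(unix_time_ms_to_datetime(1700000000500)) == 1700000000500
# Raw runner output may not be valid UTF-8, so it is stored base64-encoded as ASCII text.
assert b64encode_raw_message(b"\xff\xfehello") == "//5oZWxsbw=="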
dstack/_internal/server/services/logs/filelog.py

@@ -0,0 +1,110 @@
+ from pathlib import Path
+ from typing import List, Union
+ from uuid import UUID
+
+ from dstack._internal.core.models.logs import (
+     JobSubmissionLogs,
+     LogEvent,
+     LogEventSource,
+     LogProducer,
+ )
+ from dstack._internal.server import settings
+ from dstack._internal.server.models import ProjectModel
+ from dstack._internal.server.schemas.logs import PollLogsRequest
+ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
+ from dstack._internal.server.services.logs.base import (
+     LogStorage,
+     b64encode_raw_message,
+     unix_time_ms_to_datetime,
+ )
+
+
+ class FileLogStorage(LogStorage):
+     root: Path
+
+     def __init__(self, root: Union[Path, str, None] = None) -> None:
+         if root is None:
+             self.root = settings.SERVER_DIR_PATH
+         else:
+             self.root = Path(root)
+
+     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
+         # TODO Respect request.limit to support pagination
+         log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
+         log_file_path = self._get_log_file_path(
+             project_name=project.name,
+             run_name=request.run_name,
+             job_submission_id=request.job_submission_id,
+             producer=log_producer,
+         )
+         logs = []
+         try:
+             with open(log_file_path) as f:
+                 for line in f:
+                     log_event = LogEvent.__response__.parse_raw(line)
+                     if request.start_time and log_event.timestamp <= request.start_time:
+                         continue
+                     if request.end_time is None or log_event.timestamp < request.end_time:
+                         logs.append(log_event)
+                     else:
+                         break
+         except IOError:
+             pass
+         if request.descending:
+             logs = list(reversed(logs))
+         return JobSubmissionLogs(logs=logs)
+
+     def write_logs(
+         self,
+         project: ProjectModel,
+         run_name: str,
+         job_submission_id: UUID,
+         runner_logs: List[RunnerLogEvent],
+         job_logs: List[RunnerLogEvent],
+     ):
+         if len(runner_logs) > 0:
+             runner_log_file_path = self._get_log_file_path(
+                 project.name, run_name, job_submission_id, LogProducer.RUNNER
+             )
+             self._write_logs(
+                 log_file_path=runner_log_file_path,
+                 log_events=runner_logs,
+             )
+         if len(job_logs) > 0:
+             job_log_file_path = self._get_log_file_path(
+                 project.name, run_name, job_submission_id, LogProducer.JOB
+             )
+             self._write_logs(
+                 log_file_path=job_log_file_path,
+                 log_events=job_logs,
+             )
+
+     def _write_logs(self, log_file_path: Path, log_events: List[RunnerLogEvent]) -> None:
+         log_events_parsed = [self._runner_log_event_to_log_event(event) for event in log_events]
+         log_file_path.parent.mkdir(exist_ok=True, parents=True)
+         with open(log_file_path, "a") as f:
+             f.writelines(log.json() + "\n" for log in log_events_parsed)
+
+     def _get_log_file_path(
+         self,
+         project_name: str,
+         run_name: str,
+         job_submission_id: UUID,
+         producer: LogProducer,
+     ) -> Path:
+         return (
+             self.root
+             / "projects"
+             / project_name
+             / "logs"
+             / run_name
+             / str(job_submission_id)
+             / f"{producer.value}.log"
+         )
+
+     def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> LogEvent:
+         return LogEvent(
+             timestamp=unix_time_ms_to_datetime(runner_log_event.timestamp),
+             log_source=LogEventSource.STDOUT,
+             message=b64encode_raw_message(runner_log_event.message),
+         )
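FileLogStorage, the default backend when no cloud logging is configured, keeps one append-only JSON-lines file per producer under the server directory. A standalone illustration of the resulting layout and the append-then-filter pattern poll_logs relies on, with plain dicts standing in for dstack's LogEvent model:

import json
from datetime import datetime, timezone
from pathlib import Path
from tempfile import TemporaryDirectory


def log_file_path(root: Path, project: str, run: str, submission_id: str, producer: str) -> Path:
    # Mirrors FileLogStorage._get_log_file_path: one file per (project, run, submission, producer).
    return root / "projects" / project / "logs" / run / submission_id / f"{producer}.log"


with TemporaryDirectory() as tmp:
    path = log_file_path(Path(tmp), "main", "keen-iguana-1", "submission-1", "job")
    path.parent.mkdir(parents=True, exist_ok=True)
    # write_logs appends one JSON document per line, so the file only ever grows.
    with open(path, "a") as f:
        for second in range(3):
            event = {
                "timestamp": datetime(2024, 1, 1, 12, 0, second, tzinfo=timezone.utc).isoformat(),
                "log_source": "stdout",
                "message": f"line {second}",
            }
            f.write(json.dumps(event) + "\n")
    # poll_logs re-reads the file, keeping events strictly after start_time (and before end_time if set).
    start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc).isoformat()
    with open(path) as f:
        events = [json.loads(line) for line in f]
    print(len([e for e in events if e["timestamp"] > start]))  # 2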
dstack/_internal/server/services/logs/gcp.py

@@ -0,0 +1,165 @@
+ import time
+ from typing import Iterable, List
+ from uuid import UUID
+
+ from dstack._internal.core.errors import ServerClientError
+ from dstack._internal.core.models.logs import (
+     JobSubmissionLogs,
+     LogEvent,
+     LogEventSource,
+     LogProducer,
+ )
+ from dstack._internal.server.models import ProjectModel
+ from dstack._internal.server.schemas.logs import PollLogsRequest
+ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
+ from dstack._internal.server.services.logs.base import (
+     LogStorage,
+     LogStorageError,
+     b64encode_raw_message,
+     unix_time_ms_to_datetime,
+ )
+ from dstack._internal.utils.common import batched
+ from dstack._internal.utils.logging import get_logger
+
+ GCP_LOGGING_AVAILABLE = True
+ try:
+     import google.api_core.exceptions
+     import google.auth.exceptions
+     from google.cloud import logging
+ except ImportError:
+     GCP_LOGGING_AVAILABLE = False
+
+
+ logger = get_logger(__name__)
+
+
+ class GCPLogStorage(LogStorage):
+     # Max expected message size from runner is 32KB.
+     # Max expected LogEntry size is 32KB + metadata < 50KB < 256KB limit.
+     # With MAX_BATCH_SIZE = 100, max write request size < 5MB < 10 MB limit.
+     # See: https://cloud.google.com/logging/quotas.
+     MAX_RUNNER_MESSAGE_SIZE = 32 * 1024
+     MAX_BATCH_SIZE = 100
+
+     # Use the same log name for all run logs so that it's easy to manage all dstack-related logs.
+     LOG_NAME = "dstack-run-logs"
+     # Logs from different jobs belong to different "streams".
+     # GCP Logging has no built-in concepts of streams, so we implement them with labels.
+     # It should be fast to filter by labels since labels are indexed by default
+     # (https://cloud.google.com/logging/docs/analyze/custom-index).
+
+     def __init__(self, project_id: str):
+         try:
+             self.client = logging.Client(project=project_id)
+             self.logger = self.client.logger(name=self.LOG_NAME)
+             self.logger.list_entries(max_results=1)
+             # Python client doesn't seem to support dry_run,
+             # so emit an empty log to check permissions.
+             self.logger.log_empty()
+         except google.auth.exceptions.DefaultCredentialsError:
+             raise LogStorageError("Default credentials not found")
+         except google.api_core.exceptions.NotFound:
+             raise LogStorageError(f"Project {project_id} not found")
+         except google.api_core.exceptions.PermissionDenied:
+             raise LogStorageError("Insufficient permissions")
+
+     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
+         producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
+         stream_name = self._get_stream_name(
+             project_name=project.name,
+             run_name=request.run_name,
+             job_submission_id=request.job_submission_id,
+             producer=producer,
+         )
+         log_filters = [f'labels.stream = "{stream_name}"']
+         if request.start_time:
+             log_filters.append(f'timestamp > "{request.start_time.isoformat()}"')
+         if request.end_time:
+             log_filters.append(f'timestamp < "{request.end_time.isoformat()}"')
+         log_filter = " AND ".join(log_filters)
+
+         order_by = logging.DESCENDING if request.descending else logging.ASCENDING
+         try:
+             entries: Iterable[logging.LogEntry] = self.logger.list_entries(
+                 filter_=log_filter,
+                 order_by=order_by,
+                 max_results=request.limit,
+                 # Specify max possible page_size (<=1000) to reduce number of API calls.
+                 page_size=request.limit,
+             )
+             logs = [
+                 LogEvent(
+                     timestamp=entry.timestamp,
+                     message=entry.payload["message"],
+                     log_source=LogEventSource.STDOUT,
+                 )
+                 for entry in entries
+             ]
+         except google.api_core.exceptions.ResourceExhausted as e:
+             logger.warning("GCP Logging exception: %s", repr(e))
+             # GCP Logging has severely low quota of 60 reads/min for entries.list
+             raise ServerClientError(
+                 "GCP Logging read request limit exceeded."
+                 " It's recommended to increase default entries.list request quota from 60 per minute."
+             )
+         # We intentionally make reading logs slow to prevent hitting GCP quota.
+         # This doesn't help with many concurrent clients but
+         # should help with one client reading all logs sequentially.
+         time.sleep(1)
+         return JobSubmissionLogs(logs=logs)
+
+     def write_logs(
+         self,
+         project: ProjectModel,
+         run_name: str,
+         job_submission_id: UUID,
+         runner_logs: List[RunnerLogEvent],
+         job_logs: List[RunnerLogEvent],
+     ):
+         producers_with_logs = [(LogProducer.RUNNER, runner_logs), (LogProducer.JOB, job_logs)]
+         for producer, producer_logs in producers_with_logs:
+             stream_name = self._get_stream_name(
+                 project_name=project.name,
+                 run_name=run_name,
+                 job_submission_id=job_submission_id,
+                 producer=producer,
+             )
+             self._write_logs_to_stream(
+                 stream_name=stream_name,
+                 logs=producer_logs,
+             )
+
+     def close(self):
+         self.client.close()
+
+     def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]):
+         with self.logger.batch() as batcher:
+             for batch in batched(logs, self.MAX_BATCH_SIZE):
+                 for log in batch:
+                     message = b64encode_raw_message(log.message)
+                     timestamp = unix_time_ms_to_datetime(log.timestamp)
+                     # as message is base64-encoded, length in bytes = length in code points
+                     if len(message) > self.MAX_RUNNER_MESSAGE_SIZE:
+                         logger.error(
+                             "Stream %s: skipping event at %s, message exceeds max size: %d > %d",
+                             stream_name,
+                             timestamp.isoformat(),
+                             len(message),
+                             self.MAX_RUNNER_MESSAGE_SIZE,
+                         )
+                         continue
+                     batcher.log_struct(
+                         {
+                             "message": message,
+                         },
+                         labels={
+                             "stream": stream_name,
+                         },
+                         timestamp=timestamp,
+                     )
+                 batcher.commit()
+
+     def _get_stream_name(
+         self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer
+     ) -> str:
+         return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}"
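GCPLogStorage maps dstack's per-job log streams onto Cloud Logging label values and writes entries in bounded batches to stay under the per-entry and per-request size quotas noted in the class comments. A standalone sketch of that batching and size-guard logic (no google-cloud-logging needed; `batched` is re-implemented here because dstack's own utility is not shown in this diff):

import base64
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

MAX_RUNNER_MESSAGE_SIZE = 32 * 1024  # per-entry budget, as in GCPLogStorage
MAX_BATCH_SIZE = 100  # entries per write request


def batched(items: Iterable[T], n: int) -> Iterator[List[T]]:
    # Assumed to match dstack._internal.utils.common.batched: yield successive chunks of size n.
    it = iter(items)
    while chunk := list(islice(it, n)):
        yield chunk


def to_write_batches(stream: str, raw_messages: List[bytes]) -> Iterator[List[dict]]:
    for batch in batched(raw_messages, MAX_BATCH_SIZE):
        entries = []
        for raw in batch:
            message = base64.b64encode(raw).decode()
            # Base64 output is ASCII, so its length in code points equals its length in bytes.
            if len(message) > MAX_RUNNER_MESSAGE_SIZE:
                continue  # GCPLogStorage logs an error and skips oversized events
            entries.append({"message": message, "labels": {"stream": stream}})
        yield entries


batches = list(to_write_batches("main-run-1-sub-1-job", [b"ok"] * 250 + [b"x" * 64 * 1024]))
print([len(b) for b in batches])  # [100, 100, 50]: the oversized event in the last chunk is skipped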
dstack/_internal/server/services/offers.py

@@ -50,35 +50,35 @@ async def get_offers_by_requirements(
      if volumes:
          mount_point_volumes = volumes[0]
          volumes_backend_types = [v.configuration.backend for v in mount_point_volumes]
-         if not backend_types:
+         if backend_types is None:
              backend_types = volumes_backend_types
          backend_types = [b for b in backend_types if b in volumes_backend_types]
          volumes_regions = [v.configuration.region for v in mount_point_volumes]
-         if not regions:
+         if regions is None:
              regions = volumes_regions
          regions = [r for r in regions if r in volumes_regions]

      if multinode:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_MULTINODE_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_MULTINODE_SUPPORT]

      if privileged or instance_mounts:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT]

      if profile.reservation is not None:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_RESERVATION_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_RESERVATION_SUPPORT]

      # For multi-node, restrict backend and region.
      # The default behavior is to provision all nodes in the same backend and region.
      if master_job_provisioning_data is not None:
-         if not backend_types:
+         if backend_types is None:
              backend_types = [master_job_provisioning_data.get_base_backend()]
-         if not regions:
+         if regions is None:
              regions = [master_job_provisioning_data.region]
          backend_types = [
              b for b in backend_types if b == master_job_provisioning_data.get_base_backend()
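The offers.py hunk above replaces truthiness checks with explicit `is None` checks, so a backend or region list that has already been narrowed down to empty no longer falls back to the defaults and gets widened again. A minimal illustration of the behavioral difference (the constant below is an illustrative stand-in, not the real BACKENDS_WITH_MULTINODE_SUPPORT):

from typing import List, Optional

MULTINODE_BACKENDS = ["aws", "azure", "gcp", "oci"]  # stand-in for BACKENDS_WITH_MULTINODE_SUPPORT


def filter_old(backend_types: Optional[List[str]]) -> List[str]:
    if not backend_types:  # 0.18.41: [] is treated like None and replaced with the defaults
        backend_types = MULTINODE_BACKENDS
    return [b for b in backend_types if b in MULTINODE_BACKENDS]


def filter_new(backend_types: Optional[List[str]]) -> List[str]:
    if backend_types is None:  # 0.18.43: only None means "no restriction requested"
        backend_types = MULTINODE_BACKENDS
    return [b for b in backend_types if b in MULTINODE_BACKENDS]


print(filter_old([]))    # ['aws', 'azure', 'gcp', 'oci']: an already-empty selection is widened
print(filter_new([]))    # []: stays empty, so no offers match an impossible combination
print(filter_new(None))  # ['aws', 'azure', 'gcp', 'oci']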