modal 1.1.0.tar.gz → 1.1.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {modal-1.1.0 → modal-1.1.1}/PKG-INFO +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/__main__.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/_clustered_functions.py +3 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_clustered_functions.pyi +3 -2
- {modal-1.1.0 → modal-1.1.1}/modal/_functions.py +78 -26
- {modal-1.1.0 → modal-1.1.1}/modal/_object.py +9 -1
- {modal-1.1.0 → modal-1.1.1}/modal/_output.py +14 -25
- modal-1.1.1/modal/_runtime/gpu_memory_snapshot.py +303 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/async_utils.py +6 -4
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/auth_token_manager.py +1 -1
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/blob_utils.py +16 -21
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/function_utils.py +16 -4
- modal-1.1.1/modal/_utils/time_utils.py +19 -0
- {modal-1.1.0 → modal-1.1.1}/modal/app.py +0 -4
- {modal-1.1.0 → modal-1.1.1}/modal/app.pyi +0 -4
- {modal-1.1.0 → modal-1.1.1}/modal/cli/_traceback.py +3 -2
- {modal-1.1.0 → modal-1.1.1}/modal/cli/app.py +4 -4
- {modal-1.1.0 → modal-1.1.1}/modal/cli/cluster.py +4 -4
- {modal-1.1.0 → modal-1.1.1}/modal/cli/config.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/cli/container.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/cli/dict.py +4 -4
- {modal-1.1.0 → modal-1.1.1}/modal/cli/entry_point.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/cli/import_refs.py +3 -3
- {modal-1.1.0 → modal-1.1.1}/modal/cli/network_file_system.py +8 -9
- {modal-1.1.0 → modal-1.1.1}/modal/cli/profile.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/cli/queues.py +5 -5
- {modal-1.1.0 → modal-1.1.1}/modal/cli/secret.py +5 -5
- {modal-1.1.0 → modal-1.1.1}/modal/cli/utils.py +3 -4
- {modal-1.1.0 → modal-1.1.1}/modal/cli/volume.py +8 -9
- {modal-1.1.0 → modal-1.1.1}/modal/client.py +8 -1
- {modal-1.1.0 → modal-1.1.1}/modal/client.pyi +9 -2
- {modal-1.1.0 → modal-1.1.1}/modal/container_process.py +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal/dict.py +47 -3
- {modal-1.1.0 → modal-1.1.1}/modal/dict.pyi +55 -0
- {modal-1.1.0 → modal-1.1.1}/modal/exception.py +4 -0
- {modal-1.1.0 → modal-1.1.1}/modal/experimental/__init__.py +1 -1
- {modal-1.1.0 → modal-1.1.1}/modal/experimental/flash.py +18 -2
- {modal-1.1.0 → modal-1.1.1}/modal/experimental/flash.pyi +19 -0
- {modal-1.1.0 → modal-1.1.1}/modal/functions.pyi +0 -1
- {modal-1.1.0 → modal-1.1.1}/modal/image.py +26 -10
- {modal-1.1.0 → modal-1.1.1}/modal/image.pyi +12 -4
- {modal-1.1.0 → modal-1.1.1}/modal/mount.py +1 -1
- {modal-1.1.0 → modal-1.1.1}/modal/object.pyi +4 -0
- {modal-1.1.0 → modal-1.1.1}/modal/parallel_map.py +432 -4
- {modal-1.1.0 → modal-1.1.1}/modal/parallel_map.pyi +28 -0
- {modal-1.1.0 → modal-1.1.1}/modal/queue.py +46 -3
- {modal-1.1.0 → modal-1.1.1}/modal/queue.pyi +53 -0
- {modal-1.1.0 → modal-1.1.1}/modal/sandbox.py +105 -25
- {modal-1.1.0 → modal-1.1.1}/modal/sandbox.pyi +108 -18
- {modal-1.1.0 → modal-1.1.1}/modal/secret.py +48 -5
- {modal-1.1.0 → modal-1.1.1}/modal/secret.pyi +55 -0
- {modal-1.1.0 → modal-1.1.1}/modal/token_flow.py +3 -3
- {modal-1.1.0 → modal-1.1.1}/modal/volume.py +49 -18
- {modal-1.1.0 → modal-1.1.1}/modal/volume.pyi +50 -8
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/PKG-INFO +2 -2
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/SOURCES.txt +8 -8
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/requires.txt +1 -1
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api.proto +140 -14
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api_grpc.py +80 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api_pb2.py +927 -756
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api_pb2.pyi +488 -34
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api_pb2_grpc.py +166 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/api_pb2_grpc.pyi +52 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/modal_api_grpc.py +5 -0
- {modal-1.1.0 → modal-1.1.1}/modal_version/__init__.py +1 -1
- {modal-1.1.0 → modal-1.1.1}/pyproject.toml +2 -2
- modal-1.1.0/modal/_runtime/gpu_memory_snapshot.py +0 -199
- modal-1.1.0/modal/_utils/time_utils.py +0 -15
- {modal-1.1.0 → modal-1.1.1}/LICENSE +0 -0
- {modal-1.1.0 → modal-1.1.1}/README.md +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_container_entrypoint.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_ipython.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_location.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_partial_function.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_pty.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_resolver.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_resources.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/asgi.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/container_io_manager.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/container_io_manager.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/execution_context.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/execution_context.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/telemetry.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_runtime/user_code_imports.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_serialization.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_traceback.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_tunnel.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_tunnel.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_type_manager.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/app_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/bytes_io_segment_payload.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/deprecation.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/docker_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/git_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/grpc_testing.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/grpc_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/hash_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/http_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/jwt_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/logger.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/mount_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/name_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/package_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/pattern_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/rand_pb_testing.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_utils/shell_utils.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_vendor/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_vendor/a2wsgi_wsgi.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_vendor/cloudpickle.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_vendor/tblib.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/_watcher.py +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/2023.12.312.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/2023.12.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/2024.04.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/2024.10.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/2025.06.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/PREVIEW.txt +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/README.md +0 -0
- {modal-1.1.0/modal/requirements → modal-1.1.1/modal/builder}/base-images.json +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/call_graph.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/_download.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/environment.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/launch.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/programs/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/programs/run_jupyter.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/programs/vscode.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/run.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cli/token.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cloud_bucket_mount.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cloud_bucket_mount.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cls.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/cls.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/config.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/container_process.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/environments.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/environments.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/experimental/ipython.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/file_io.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/file_io.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/file_pattern_matcher.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/functions.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/gpu.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/io_streams.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/io_streams.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/mount.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/network_file_system.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/network_file_system.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/object.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/output.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/partial_function.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/partial_function.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/proxy.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/proxy.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/py.typed +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/retries.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/runner.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/runner.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/running_app.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/schedule.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/scheduler_placement.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/serving.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/serving.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/snapshot.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/snapshot.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/stream_type.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal/token_flow.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/dependency_links.txt +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/entry_points.txt +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal.egg-info/top_level.txt +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/gen_cli_docs.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/gen_reference_docs.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/mdmd/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/mdmd/mdmd.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_docs/mdmd/signatures.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/__init__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/modal_options_grpc.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options.proto +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options_grpc.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options_pb2.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options_pb2.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options_pb2_grpc.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/options_pb2_grpc.pyi +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_proto/py.typed +0 -0
- {modal-1.1.0 → modal-1.1.1}/modal_version/__main__.py +0 -0
- {modal-1.1.0 → modal-1.1.1}/setup.cfg +0 -0
--- modal-1.1.0/PKG-INFO
+++ modal-1.1.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: modal
-Version: 1.1.0
+Version: 1.1.1
 Summary: Python client library for Modal
 Author-email: Modal Labs <support@modal.com>
 License: Apache-2.0
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: aiohttp
 Requires-Dist: certifi
-Requires-Dist: click~=8.1
+Requires-Dist: click~=8.1
 Requires-Dist: grpclib<0.4.9,>=0.4.7
 Requires-Dist: protobuf!=4.24.0,<7.0,>=3.19
 Requires-Dist: rich>=12.0.0
--- modal-1.1.0/modal/__main__.py
+++ modal-1.1.1/modal/__main__.py
@@ -1,6 +1,7 @@
 # Copyright Modal Labs 2022
 import sys
 
+from ._output import make_console
 from ._traceback import reduce_traceback_to_user_code
 from .cli._traceback import highlight_modal_warnings, setup_rich_traceback
 from .cli.entry_point import entrypoint_cli
@@ -35,7 +36,6 @@ def main():
         raise
 
     from grpclib import GRPCError, Status
-    from rich.console import Console
     from rich.panel import Panel
     from rich.text import Text
 
@@ -68,7 +68,7 @@ def main():
     if notes := getattr(exc, "__notes__", []):
         content = f"{content}\n\nNote: {' '.join(notes)}"
 
-    console = Console(stderr=True)
+    console = make_console(stderr=True)
     panel = Panel(Text(content), title=title, title_align="left", border_style="red")
     console.print(panel, highlight=False)
     sys.exit(1)
--- modal-1.1.0/modal/_clustered_functions.py
+++ modal-1.1.1/modal/_clustered_functions.py
@@ -15,6 +15,7 @@ from modal_proto import api_pb2
 class ClusterInfo:
     rank: int
     container_ips: list[str]
+    container_ipv4_ips: list[str]
 
 
 cluster_info: Optional[ClusterInfo] = None
@@ -69,11 +70,13 @@ async def _initialize_clustered_function(client: _Client, task_id: str, world_si
         cluster_info = ClusterInfo(
             rank=resp.cluster_rank,
             container_ips=resp.container_ips,
+            container_ipv4_ips=resp.container_ipv4_ips,
         )
     else:
         cluster_info = ClusterInfo(
             rank=0,
             container_ips=[container_ip],
+            container_ipv4_ips=[],  # No IPv4 IPs for single-node
         )
 
 
--- modal-1.1.0/modal/_clustered_functions.pyi
+++ modal-1.1.1/modal/_clustered_functions.pyi
@@ -3,12 +3,13 @@ import typing
 import typing_extensions
 
 class ClusterInfo:
-    """ClusterInfo(rank: int, container_ips: list[str])"""
+    """ClusterInfo(rank: int, container_ips: list[str], container_ipv4_ips: list[str])"""
 
     rank: int
     container_ips: list[str]
+    container_ipv4_ips: list[str]
 
-    def __init__(self, rank: int, container_ips: list[str]) -> None:
+    def __init__(self, rank: int, container_ips: list[str], container_ipv4_ips: list[str]) -> None:
         """Initialize self. See help(type(self)) for accurate signature."""
         ...
 
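For context, the new `container_ipv4_ips` field simply rides along on the existing `ClusterInfo` dataclass, and the single-node path leaves it empty. Below is a minimal, self-contained sketch of how consuming code might use it; the local `ClusterInfo` definition and the `peer_addresses` helper are illustrative stand-ins, not Modal's API (the real class lives in `modal/_clustered_functions.py`).

# Illustrative sketch only: mirrors the ClusterInfo shape from the .pyi above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ClusterInfo:
    rank: int
    container_ips: list[str]
    container_ipv4_ips: list[str]  # new in 1.1.1


def peer_addresses(info: Optional[ClusterInfo]) -> list[str]:
    # Prefer the new IPv4 list when the server populated it; otherwise fall back
    # to container_ips (single-node containers get an empty IPv4 list per the diff).
    if info is None:
        return []
    return info.container_ipv4_ips or info.container_ips


single_node = ClusterInfo(rank=0, container_ips=["10.0.0.7"], container_ipv4_ips=[])
assert peer_addresses(single_node) == ["10.0.0.7"]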
--- modal-1.1.0/modal/_functions.py
+++ modal-1.1.1/modal/_functions.py
@@ -75,6 +75,7 @@ from .parallel_map import (
     _for_each_sync,
     _map_async,
    _map_invocation,
+    _map_invocation_inputplane,
     _map_sync,
     _spawn_map_async,
     _spawn_map_sync,
@@ -399,7 +400,8 @@ class _InputPlaneInvocation:
                 parent_input_id=current_input_id() or "",
                 input=input_item,
             )
-
+
+            metadata = await client.get_input_plane_metadata(input_plane_region)
             response = await retry_transient_errors(stub.AttemptStart, request, metadata=metadata)
             attempt_token = response.attempt_token
 
@@ -415,7 +417,7 @@ class _InputPlaneInvocation:
             timeout_secs=OUTPUTS_TIMEOUT,
             requested_at=time.time(),
         )
-        metadata = await self.
+        metadata = await self.client.get_input_plane_metadata(self.input_plane_region)
         await_response: api_pb2.AttemptAwaitResponse = await retry_transient_errors(
             self.stub.AttemptAwait,
             await_request,
@@ -451,6 +453,33 @@ class _InputPlaneInvocation:
             await_response.output.result, await_response.output.data_format, control_plane_stub, self.client
         )
 
+    async def run_generator(self):
+        items_received = 0
+        # populated when self.run_function() completes
+        items_total: Union[int, None] = None
+        async with aclosing(
+            async_merge(
+                _stream_function_call_data(
+                    self.client,
+                    self.stub,
+                    "",
+                    variant="data_out",
+                    attempt_token=self.attempt_token,
+                ),
+                callable_to_agen(self.run_function),
+            )
+        ) as streamer:
+            async for item in streamer:
+                if isinstance(item, api_pb2.GeneratorDone):
+                    items_total = item.items_total
+                else:
+                    yield item
+                    items_received += 1
+                    # The comparison avoids infinite loops if a non-deterministic generator is retried
+                    # and produces less data in the second run than what was already sent.
+                    if items_total is not None and items_received >= items_total:
+                        break
+
     @staticmethod
     async def _get_metadata(input_plane_region: str, client: _Client) -> list[tuple[str, str]]:
         if not input_plane_region:
@@ -600,7 +629,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
         experimental_options: Optional[dict[str, str]] = None,
         _experimental_proxy_ip: Optional[str] = None,
         _experimental_custom_scaling_factor: Optional[float] = None,
-        _experimental_enable_gpu_snapshot: bool = False,
     ) -> "_Function":
         """mdmd:hidden"""
         # Needed to avoid circular imports
@@ -901,7 +929,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
             _experimental_concurrent_cancellations=True,
             _experimental_proxy_ip=_experimental_proxy_ip,
             _experimental_custom_scaling=_experimental_custom_scaling_factor is not None,
-            _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
             # --- These are deprecated in favor of autoscaler_settings
             warm_pool_size=min_containers or 0,
             concurrency_limit=max_containers or 0,
@@ -938,7 +965,6 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
             _experimental_group_size=function_definition._experimental_group_size,
             _experimental_buffer_containers=function_definition._experimental_buffer_containers,
             _experimental_custom_scaling=function_definition._experimental_custom_scaling,
-            _experimental_enable_gpu_snapshot=_experimental_enable_gpu_snapshot,
             _experimental_proxy_ip=function_definition._experimental_proxy_ip,
             snapshot_debug=function_definition.snapshot_debug,
             runtime_perf_record=function_definition.runtime_perf_record,
@@ -1487,20 +1513,35 @@ Use the `Function.get_web_url()` method instead.
         else:
             count_update_callback = None
 
-        async with aclosing(
-            _map_invocation(
-                self,
-                input_queue,
-                self.client,
-                order_outputs,
-                return_exceptions,
-                wrap_returned_exceptions,
-                count_update_callback,
-                api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
-            )
-        ) as stream:
-            async for item in stream:
-                yield item
+        if self._input_plane_url:
+            async with aclosing(
+                _map_invocation_inputplane(
+                    self,
+                    input_queue,
+                    self.client,
+                    order_outputs,
+                    return_exceptions,
+                    wrap_returned_exceptions,
+                    count_update_callback,
+                )
+            ) as stream:
+                async for item in stream:
+                    yield item
+        else:
+            async with aclosing(
+                _map_invocation(
+                    self,
+                    input_queue,
+                    self.client,
+                    order_outputs,
+                    return_exceptions,
+                    wrap_returned_exceptions,
+                    count_update_callback,
+                    api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
+                )
+            ) as stream:
+                async for item in stream:
+                    yield item
 
     async def _call_function(self, args, kwargs) -> ReturnType:
         invocation: Union[_Invocation, _InputPlaneInvocation]
@@ -1544,13 +1585,24 @@ Use the `Function.get_web_url()` method instead.
     @live_method_gen
     @synchronizer.no_input_translation
     async def _call_generator(self, args, kwargs):
-        invocation = await _Invocation.create(
-            self,
-            args,
-            kwargs,
-            client=self.client,
-            function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
-        )
+        invocation: Union[_Invocation, _InputPlaneInvocation]
+        if self._input_plane_url:
+            invocation = await _InputPlaneInvocation.create(
+                self,
+                args,
+                kwargs,
+                client=self.client,
+                input_plane_url=self._input_plane_url,
+                input_plane_region=self._input_plane_region,
+            )
+        else:
+            invocation = await _Invocation.create(
+                self,
+                args,
+                kwargs,
+                client=self.client,
+                function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
+            )
         async for res in invocation.run_generator():
             yield res
 
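The new `_InputPlaneInvocation.run_generator` merges the function's data-out stream with the coroutine that drives the call (`async_merge`, `_stream_function_call_data`, `callable_to_agen` in the diff) and stops once it has yielded the number of items reported by the `GeneratorDone` sentinel. The following is a simplified, self-contained sketch of just that termination logic; the stream and sentinel here are faked with plain asyncio stand-ins and are not Modal's internals.

# Simplified sketch of the run_generator termination pattern shown above.
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Union


@dataclass
class GeneratorDone:  # stand-in for api_pb2.GeneratorDone
    items_total: int


async def drain(stream: AsyncIterator[Union[bytes, GeneratorDone]]) -> list[bytes]:
    items_received = 0
    items_total: Union[int, None] = None  # populated when the sentinel arrives
    out: list[bytes] = []
    async for item in stream:
        if isinstance(item, GeneratorDone):
            items_total = item.items_total
        else:
            out.append(item)
            items_received += 1
            # Stop once everything the sentinel promised has been yielded; this
            # guards against retried, non-deterministic generators looping forever.
            if items_total is not None and items_received >= items_total:
                break
    return out


async def fake_stream() -> AsyncIterator[Union[bytes, GeneratorDone]]:
    yield b"a"
    yield GeneratorDone(items_total=2)  # sentinel can arrive before the last item
    yield b"b"
    yield b"c"  # never consumed: drain() breaks after two items


print(asyncio.run(drain(fake_stream())))  # [b'a', b'b']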
--- modal-1.1.0/modal/_object.py
+++ modal-1.1.1/modal/_object.py
@@ -48,6 +48,10 @@ class _Object:
     _is_hydrated: bool
     _is_rehydrated: bool
 
+    # Not all object subclasses have a meaningful "name" concept
+    # So whether they expose this is a matter of having a name property
+    _name: Optional[str]
+
     @classmethod
     def __init_subclass__(cls, type_prefix: Optional[str] = None):
         super().__init_subclass__()
@@ -68,6 +72,7 @@ class _Object:
         hydrate_lazily: bool = False,
         deps: Optional[Callable[..., Sequence["_Object"]]] = None,
         deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
+        name: Optional[str] = None,
     ):
         self._local_uuid = str(uuid.uuid4())
         self._load = load
@@ -83,6 +88,8 @@ class _Object:
         self._is_hydrated = False
         self._is_rehydrated = False
 
+        self._name = name
+
         self._initialize_from_empty()
 
     def _unhydrate(self):
@@ -163,10 +170,11 @@ class _Object:
         hydrate_lazily: bool = False,
         deps: Optional[Callable[..., Sequence["_Object"]]] = None,
         deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
+        name: Optional[str] = None,
     ):
         # TODO(erikbern): flip the order of the two first arguments
         obj = _Object.__new__(cls)
-        obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key)
+        obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key, name)
         return obj
 
     @staticmethod
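Per the comment added to `_Object`, the base class only stores `_name`; subclasses for which a name is meaningful expose it through a `name` property. A tiny hypothetical sketch of that pattern follows; the class names and the pared-down `_init` signature are illustrative only, not Modal's actual classes.

# Hypothetical sketch of the "_name stored on the base, name property on subclasses" pattern.
from typing import Optional


class _ObjectSketch:
    _name: Optional[str]

    def _init(self, name: Optional[str] = None) -> None:
        self._name = name  # always stored, even if the subclass never exposes it


class _NamedDictSketch(_ObjectSketch):
    @property
    def name(self) -> Optional[str]:
        return self._name


d = _NamedDictSketch()
d._init(name="my-dict")
print(d.name)  # my-dict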
--- modal-1.1.0/modal/_output.py
+++ modal-1.1.1/modal/_output.py
@@ -4,7 +4,6 @@ from __future__ import annotations
 import asyncio
 import contextlib
 import functools
-import io
 import platform
 import re
 import socket
@@ -32,7 +31,7 @@ from rich.progress import (
 from rich.spinner import Spinner
 from rich.text import Text
 
-from modal._utils.time_utils import
+from modal._utils.time_utils import timestamp_to_localized_str
 from modal_proto import api_pb2
 
 from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
@@ -46,6 +45,16 @@ else:
     default_spinner = "dots"
 
 
+def make_console(*, stderr: bool = False, highlight: bool = True) -> Console:
+    """Create a rich Console tuned for Modal CLI output."""
+    return Console(
+        stderr=stderr,
+        highlight=highlight,
+        # CLI does not work with auto-detected Jupyter HTML display_data.
+        force_jupyter=False,
+    )
+
+
 class FunctionQueuingColumn(ProgressColumn):
     """Renders time elapsed, including task.completed as additional elapsed time."""
 
@@ -63,25 +72,6 @@ class FunctionQueuingColumn(ProgressColumn):
         return Text(str(delta), style="progress.elapsed")
 
 
-def download_progress_bar() -> Progress:
-    """
-    Returns a progress bar suitable for showing file download progress.
-    Requires passing a `path: str` data field for rendering.
-    """
-    return Progress(
-        TextColumn("[bold white]{task.fields[path]}", justify="right"),
-        BarColumn(bar_width=None),
-        "[progress.percentage]{task.percentage:>3.1f}%",
-        "•",
-        DownloadColumn(),
-        "•",
-        TransferSpeedColumn(),
-        "•",
-        TimeRemainingColumn(),
-        transient=True,
-    )
-
-
 class LineBufferedOutput:
     """Output stream that buffers lines and passes them to a callback."""
 
@@ -101,7 +91,7 @@ class LineBufferedOutput:
 
         if self._show_timestamps:
             for i in range(0, len(chunks) - 1, 2):
-                chunks[i] = f"{
+                chunks[i] = f"{timestamp_to_localized_str(log.timestamp)} {chunks[i]}"
 
         completed_lines = "".join(chunks[:-1])
         remainder = chunks[-1]
@@ -147,12 +137,11 @@ class OutputManager:
     def __init__(
         self,
         *,
-        stdout: io.TextIOWrapper | None = None,
         status_spinner_text: str = "Running app...",
         show_timestamps: bool = False,
     ):
-        self._stdout =
-        self._console =
+        self._stdout = sys.stdout
+        self._console = make_console(highlight=False)
         self._task_states = {}
         self._task_progress_items = {}
         self._current_render_group = None
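The common thread in the console changes is the new `make_console()` helper in `modal/_output.py`, which replaces ad hoc `rich.console.Console(...)` construction (see the `modal/__main__.py` hunks above and `OutputManager.__init__`) and pins `force_jupyter=False` so CLI output never falls back to Jupyter HTML rendering. The sketch below reproduces the helper from the diff as standalone code together with a call-site in the style of `__main__.py`; the `report_error` wrapper is only for illustration.

# Standalone sketch of make_console() from modal/_output.py and its call-site pattern.
from rich.console import Console
from rich.panel import Panel
from rich.text import Text


def make_console(*, stderr: bool = False, highlight: bool = True) -> Console:
    """Create a rich Console tuned for Modal CLI output."""
    return Console(
        stderr=stderr,
        highlight=highlight,
        # CLI does not work with auto-detected Jupyter HTML display_data.
        force_jupyter=False,
    )


def report_error(title: str, content: str) -> None:
    # Mirrors modal/__main__.py: errors are rendered to stderr via the shared helper.
    console = make_console(stderr=True)
    panel = Panel(Text(content), title=title, title_align="left", border_style="red")
    console.print(panel, highlight=False)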
--- /dev/null
+++ modal-1.1.1/modal/_runtime/gpu_memory_snapshot.py
@@ -0,0 +1,303 @@
+# Copyright Modal Labs 2022
+#
+# This module provides a simple interface for creating GPU memory snapshots,
+# providing a convenient interface to `cuda-checkpoint` [1]. This is intended
+# to be used in conjunction with memory snapshots.
+#
+# [1] https://github.com/NVIDIA/cuda-checkpoint
+
+import subprocess
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import List, Optional
+
+from modal.config import config, logger
+
+CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")
+
+
+class CudaCheckpointState(Enum):
+    """State representation from the CUDA API [1].
+
+    [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""
+
+    RUNNING = "running"
+    LOCKED = "locked"
+    CHECKPOINTED = "checkpointed"
+    FAILED = "failed"
+
+
+class CudaCheckpointException(Exception):
+    """Exception raised for CUDA checkpoint operations."""
+
+    pass
+
+
+@dataclass
+class CudaCheckpointProcess:
+    """Contains a reference to a PID with active CUDA session. This also provides
+    methods for checkpointing and restoring GPU memory."""
+
+    pid: int
+    state: CudaCheckpointState
+
+    def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0) -> None:
+        """Toggle CUDA checkpoint state for current process, moving GPU memory to the
+        CPU and back depending on the current process state when called.
+        """
+        logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")
+
+        start_time = time.monotonic()
+        retry_count = 0
+        max_retries = 3
+
+        while self._should_continue_toggle(target_state, start_time, timeout_secs):
+            try:
+                self._execute_toggle_command()
+                # Use exponential backoff for retries
+                sleep_time = min(0.1 * (2**retry_count), 1.0)
+                time.sleep(sleep_time)
+                retry_count = 0
+            except CudaCheckpointException as e:
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise CudaCheckpointException(
+                        f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+                    )
+                logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+                time.sleep(0.5 * retry_count)
+
+        logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")
+
+    def _should_continue_toggle(
+        self, target_state: CudaCheckpointState, start_time: float, timeout_secs: float
+    ) -> bool:
+        """Check if toggle operation should continue based on current state and timeout."""
+        self.refresh_state()
+
+        if self.state == target_state:
+            return False
+
+        if self.state == CudaCheckpointState.FAILED:
+            raise CudaCheckpointException(f"PID: {self.pid} CUDA process state is {self.state}")
+
+        elapsed = time.monotonic() - start_time
+        if elapsed >= timeout_secs:
+            raise CudaCheckpointException(
+                f"PID: {self.pid} Timeout after {elapsed:.2f}s waiting for state {target_state.value}. "
+                f"Current state: {self.state}"
+            )
+
+        return True
+
+    def _execute_toggle_command(self) -> None:
+        """Execute the cuda-checkpoint toggle command."""
+        try:
+            _ = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
+                check=True,
+                capture_output=True,
+                text=True,
+                timeout=30,
+            )
+            logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
+        except subprocess.CalledProcessError as e:
+            error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Toggle command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+
+    def refresh_state(self) -> None:
+        """Refreshes the current CUDA checkpoint state for this process."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(self.pid)],
+                check=True,
+                capture_output=True,
+                text=True,
+                timeout=10,
+            )
+
+            state_str = result.stdout.strip().lower()
+            self.state = CudaCheckpointState(state_str)
+
+        except subprocess.CalledProcessError as e:
+            error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+        except subprocess.TimeoutExpired:
+            error_msg = f"PID: {self.pid} Get state command timed out"
+            logger.debug(error_msg)
+            raise CudaCheckpointException(error_msg)
+
+
+class CudaCheckpointSession:
+    """Manages the checkpointing state of processes with active CUDA sessions."""
+
+    def __init__(self):
+        self.cuda_processes = self._get_cuda_pids()
+        if self.cuda_processes:
+            logger.debug(
+                f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+            )
+        else:
+            logger.debug("No CUDA sessions found.")
+
+    def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
+        """Iterates over all PIDs and identifies the ones that have running
+        CUDA sessions."""
+        cuda_pids: List[CudaCheckpointProcess] = []
+
+        # Get all active process IDs from /proc directory
+        proc_dir = Path("/proc")
+        if not proc_dir.exists():
+            raise CudaCheckpointException(
+                "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
+            )
+
+        # Get all numeric directories (PIDs) from /proc
+        pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]
+
+        # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+        with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+            future_to_pid = {
+                executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+            }
+
+            for future in as_completed(future_to_pid):
+                pid = future_to_pid[future]
+                try:
+                    cuda_process = future.result()
+                    if cuda_process:
+                        cuda_pids.append(cuda_process)
+                except Exception as e:
+                    logger.debug(f"Error checking PID {pid}: {e}")
+
+        # Sort PIDs for ordered checkpointing
+        cuda_pids.sort(key=lambda x: x.pid)
+        return cuda_pids
+
+    def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+        """Check if a specific PID has a CUDA session."""
+        try:
+            result = subprocess.run(
+                [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+
+            # If the command succeeds (return code 0), this PID has a CUDA session
+            if result.returncode == 0:
+                state_str = result.stdout.strip().lower()
+                state = CudaCheckpointState(state_str)
+                return CudaCheckpointProcess(pid=pid, state=state)
+
+        except subprocess.CalledProcessError:
+            # Command failed, which is expected for PIDs without CUDA sessions
+            pass
+        except subprocess.TimeoutExpired:
+            logger.debug(f"Timeout checking CUDA state for PID {pid}")
+        except Exception as e:
+            logger.debug(f"Error checking PID {pid}: {e}")
+
+        return None
+
+    def checkpoint(self) -> None:
+        """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA processes to checkpoint.")
+            return
+
+        # Validate all states first
+        for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
+            if proc.state != CudaCheckpointState.RUNNING:
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+                    f"Current state: {proc.state.value}"
+                )
+
+        # Moving state from GPU to CPU can take several seconds per CUDA session.
+        # Make a parallel call per CUDA session.
+        start = time.perf_counter()
+
+        def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
+            proc.toggle(CudaCheckpointState.CHECKPOINTED)
+
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )
+
+        elapsed = time.perf_counter() - start
+        logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")
+
+    def restore(self) -> None:
+        """Restore all CUDA processes, moving memory back from CPU to GPU."""
+        if not self.cuda_processes:
+            logger.debug("No CUDA sessions to restore.")
+            return
+
+        # Validate all states first
+        for proc in self.cuda_processes:
+            proc.refresh_state()  # Refresh state before validation
+            if proc.state != CudaCheckpointState.CHECKPOINTED:
+                raise CudaCheckpointException(
+                    f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.CHECKPOINTED.value} state. "
+                    f"Current state: {proc.state.value}"
+                )
+
+        # See checkpoint() for rationale about parallelism.
+        start = time.perf_counter()
+
+        def restore_process(proc: CudaCheckpointProcess) -> None:
+            proc.toggle(CudaCheckpointState.RUNNING)
+
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+            # Wait for all futures and collect any exceptions
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    exceptions.append(e)
+
+            if exceptions:
+                raise CudaCheckpointException(
+                    f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+                )
+
+        elapsed = time.perf_counter() - start
+        logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+    def get_process_count(self) -> int:
+        """Get the number of CUDA processes managed by this session."""
+        return len(self.cuda_processes)
+
+    def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+        """Get current states of all managed processes."""
+        states = []
+        for proc in self.cuda_processes:
+            proc.refresh_state()
+            states.append((proc.pid, proc.state))
+        return states