modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modal might be problematic.

Files changed (147)
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/_runtime/container_io_manager.pyi

@@ -27,6 +27,7 @@ class IOContext:
  input_ids: list[str]
  retry_counts: list[int]
  function_call_ids: list[str]
+ attempt_tokens: list[str]
  function_inputs: list[modal_proto.api_pb2.FunctionInput]
  finalized_function: modal._runtime.user_code_imports.FinalizedFunction
  _cancel_issued: bool
@@ -37,6 +38,7 @@ class IOContext:
  input_ids: list[str],
  retry_counts: list[int],
  function_call_ids: list[str],
+ attempt_tokens: list[str],
  finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
  function_inputs: list[modal_proto.api_pb2.FunctionInput],
  is_batched: bool,
@@ -50,14 +52,29 @@ class IOContext:
  cls,
  client: modal.client._Client,
  finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
- inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
+ inputs: list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]],
  is_batched: bool,
  ) -> IOContext: ...
  def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
  def cancel(self): ...
  def _args_and_kwargs(self) -> tuple[tuple[typing.Any, ...], dict[str, list[typing.Any]]]: ...
- def call_finalized_function(self) -> typing.Any: ...
- def validate_output_data(self, data: typing.Any) -> list[typing.Any]: ...
+ def _generator_output_format(self) -> int: ...
+ def _prepare_batch_output(self, data: typing.Any) -> list[typing.Any]: ...
+ def call_function_sync(self) -> list[typing.Any]: ...
+ async def call_function_async(self) -> list[typing.Any]: ...
+ def call_generator_sync(self) -> typing.Generator[typing.Any, None, None]: ...
+ def call_generator_async(self) -> collections.abc.AsyncGenerator[typing.Any, None]: ...
+ async def output_items_cancellation(self, started_at: float): ...
+ def _determine_output_format(self, input_format: int) -> int: ...
+ async def output_items_exception(
+ self, started_at: float, task_id: str, exc: BaseException
+ ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+ def output_items_generator_done(
+ self, started_at: float, items_total: int
+ ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...
+ async def output_items(
+ self, started_at: float, data: list[typing.Any]
+ ) -> list[modal_proto.api_pb2.FunctionPutOutputsItem]: ...

  class InputSlots:
  """A semaphore that allows dynamically adjusting the concurrency."""
@@ -131,14 +148,19 @@ class _ContainerIOManager:
  def stop_heartbeat(self): ...
  def dynamic_concurrency_manager(self) -> typing.AsyncContextManager[None]: ...
  async def _dynamic_concurrency_loop(self): ...
- def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
- async def format_blob_data(self, data: bytes) -> dict[str, typing.Any]: ...
- def get_data_in(self, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+ def get_data_in(
+ self, function_call_id: str, attempt_token: typing.Optional[str]
+ ) -> collections.abc.AsyncIterator[typing.Any]:
  """Read from the `data_in` stream of a function call."""
  ...

  async def put_data_out(
- self, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+ self,
+ function_call_id: str,
+ attempt_token: str,
+ start_index: int,
+ data_format: int,
+ serialized_messages: list[typing.Any],
  ) -> None:
  """Put data onto the `data_out` stream of a function call.

@@ -149,7 +171,7 @@ class _ContainerIOManager:
  ...

  def generator_output_sender(
- self, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+ self, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
  ) -> typing.AsyncContextManager[None]:
  """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
  ...
@@ -166,22 +188,17 @@ class _ContainerIOManager:
  def get_max_inputs_to_fetch(self): ...
  def _generate_inputs(
  self, batch_max_size: int, batch_wait_ms: int
- ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+ ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
  def run_inputs_outputs(
  self,
  finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
  batch_max_size: int = 0,
  batch_wait_ms: int = 0,
  ) -> collections.abc.AsyncIterator[IOContext]: ...
- async def _push_outputs(
- self,
- io_context: IOContext,
- started_at: float,
- data_format: int,
- results: list[modal_proto.api_pb2.GenericResult],
- ) -> None: ...
- def serialize_exception(self, exc: BaseException) -> bytes: ...
- def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+ async def _send_outputs(self, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+ """Send pre-built output items with retry and chunking."""
+ ...
+
  def handle_user_exception(self) -> typing.AsyncContextManager[None]:
  """Sets the task as failed in a way where it's not retried.

@@ -195,9 +212,7 @@ class _ContainerIOManager:
  ...

  def exit_context(self, started_at, input_ids: list[str]): ...
- async def push_outputs(
- self, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
- ) -> None: ...
+ async def push_outputs(self, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
  async def memory_restore(self) -> None: ...
  async def memory_snapshot(self) -> None:
  """Message server indicating that function is ready to be checkpointed."""
@@ -323,20 +338,16 @@ class ContainerIOManager:

  _dynamic_concurrency_loop: ___dynamic_concurrency_loop_spec[typing_extensions.Self]

- def serialize_data_format(self, obj: typing.Any, data_format: int) -> bytes: ...
-
- class __format_blob_data_spec(typing_extensions.Protocol[SUPERSELF]):
- def __call__(self, /, data: bytes) -> dict[str, typing.Any]: ...
- async def aio(self, /, data: bytes) -> dict[str, typing.Any]: ...
-
- format_blob_data: __format_blob_data_spec[typing_extensions.Self]
-
  class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
- def __call__(self, /, function_call_id: str) -> typing.Iterator[typing.Any]:
+ def __call__(
+ self, /, function_call_id: str, attempt_token: typing.Optional[str]
+ ) -> typing.Iterator[typing.Any]:
  """Read from the `data_in` stream of a function call."""
  ...

- def aio(self, /, function_call_id: str) -> collections.abc.AsyncIterator[typing.Any]:
+ def aio(
+ self, /, function_call_id: str, attempt_token: typing.Optional[str]
+ ) -> collections.abc.AsyncIterator[typing.Any]:
  """Read from the `data_in` stream of a function call."""
  ...

@@ -344,7 +355,13 @@ class ContainerIOManager:

  class __put_data_out_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
- self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+ self,
+ /,
+ function_call_id: str,
+ attempt_token: str,
+ start_index: int,
+ data_format: int,
+ serialized_messages: list[typing.Any],
  ) -> None:
  """Put data onto the `data_out` stream of a function call.

@@ -355,7 +372,13 @@ class ContainerIOManager:
  ...

  async def aio(
- self, /, function_call_id: str, start_index: int, data_format: int, serialized_messages: list[typing.Any]
+ self,
+ /,
+ function_call_id: str,
+ attempt_token: str,
+ start_index: int,
+ data_format: int,
+ serialized_messages: list[typing.Any],
  ) -> None:
  """Put data onto the `data_out` stream of a function call.

@@ -369,13 +392,13 @@ class ContainerIOManager:

  class __generator_output_sender_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
- self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+ self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
  ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
  """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
  ...

  def aio(
- self, /, function_call_id: str, data_format: int, message_rx: asyncio.queues.Queue
+ self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
  ) -> typing.AsyncContextManager[None]:
  """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
  ...
@@ -410,10 +433,10 @@ class ContainerIOManager:
  class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(
  self, /, batch_max_size: int, batch_wait_ms: int
- ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+ ) -> typing.Iterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
  def aio(
  self, /, batch_max_size: int, batch_wait_ms: int
- ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
+ ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...

  _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]

@@ -435,28 +458,16 @@ class ContainerIOManager:

  run_inputs_outputs: __run_inputs_outputs_spec[typing_extensions.Self]

- class ___push_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
- def __call__(
- self,
- /,
- io_context: IOContext,
- started_at: float,
- data_format: int,
- results: list[modal_proto.api_pb2.GenericResult],
- ) -> None: ...
- async def aio(
- self,
- /,
- io_context: IOContext,
- started_at: float,
- data_format: int,
- results: list[modal_proto.api_pb2.GenericResult],
- ) -> None: ...
+ class ___send_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
+ def __call__(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+ """Send pre-built output items with retry and chunking."""
+ ...

- _push_outputs: ___push_outputs_spec[typing_extensions.Self]
+ async def aio(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
+ """Send pre-built output items with retry and chunking."""
+ ...

- def serialize_exception(self, exc: BaseException) -> bytes: ...
- def serialize_traceback(self, exc: BaseException) -> tuple[typing.Optional[bytes], typing.Optional[bytes]]: ...
+ _send_outputs: ___send_outputs_spec[typing_extensions.Self]

  class __handle_user_exception_spec(typing_extensions.Protocol[SUPERSELF]):
  def __call__(self, /) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
@@ -493,10 +504,8 @@ class ContainerIOManager:
  def exit_context(self, started_at, input_ids: list[str]): ...

  class __push_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
- def __call__(self, /, io_context: IOContext, started_at: float, data: typing.Any, data_format: int) -> None: ...
- async def aio(
- self, /, io_context: IOContext, started_at: float, data: typing.Any, data_format: int
- ) -> None: ...
+ def __call__(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
+ async def aio(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...

  push_outputs: __push_outputs_spec[typing_extensions.Self]
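Note on the change above: the per-input tuples handled by `_generate_inputs` and `IOContext.create` gain a fourth element, an `attempt_token` string, and the data-plane methods (`get_data_in`, `put_data_out`, `generator_output_sender`) now take the token explicitly. A minimal sketch of the new tuple shape, using made-up placeholder values rather than real protobuf messages:

```python
# Illustrative only: the tuples yielded by _generate_inputs and consumed by
# IOContext.create gain a fourth element (attempt_token) in 1.2.x.
# fn_input stands in for a modal_proto.api_pb2.FunctionInput message.
fn_input = object()

old_shape = [("in-abc", 0, "fc-def", fn_input)]            # 1.0.x: (input_id, retry_count, function_call_id, input)
new_shape = [("in-abc", 0, "fc-def", "at-123", fn_input)]  # 1.2.x: attempt_token inserted before the input

for input_id, retry_count, function_call_id, attempt_token, proto_input in new_shape:
    assert isinstance(attempt_token, str)
```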
modal/_runtime/execution_context.py

@@ -72,22 +72,38 @@ def current_function_call_id() -> Optional[str]:
  return None


- def _set_current_context_ids(input_ids: list[str], function_call_ids: list[str]) -> Callable[[], None]:
- assert len(input_ids) == len(function_call_ids) and len(input_ids) > 0
+ def current_attempt_token() -> Optional[str]:
+ # This ContextVar isn't useful to expose to users.
+ try:
+ return _current_attempt_token.get()
+ except LookupError:
+ return None
+
+
+ def _set_current_context_ids(
+ input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
+ ) -> Callable[[], None]:
+ assert len(input_ids) == len(function_call_ids) == len(attempt_tokens) and input_ids
+
  input_id = input_ids[0]
  function_call_id = function_call_ids[0]
+ attempt_token = attempt_tokens[0]
+
  input_token = _current_input_id.set(input_id)
  function_call_token = _current_function_call_id.set(function_call_id)
+ attempt_token_token = _current_attempt_token.set(attempt_token)

  def _reset_current_context_ids():
  _current_input_id.reset(input_token)
  _current_function_call_id.reset(function_call_token)
+ _current_attempt_token.reset(attempt_token_token)

  return _reset_current_context_ids


  _current_input_id: ContextVar = ContextVar("_current_input_id")
  _current_function_call_id: ContextVar = ContextVar("_current_function_call_id")
+ _current_attempt_token: ContextVar = ContextVar("_current_attempt_token")

  _is_currently_importing = False # we set this to True while a container is importing user code
modal/_runtime/execution_context.pyi

@@ -68,11 +68,14 @@ def current_function_call_id() -> typing.Optional[str]:
  """
  ...

+ def current_attempt_token() -> typing.Optional[str]: ...
  def _set_current_context_ids(
- input_ids: list[str], function_call_ids: list[str]
+ input_ids: list[str], function_call_ids: list[str], attempt_tokens: list[str]
  ) -> collections.abc.Callable[[], None]: ...
  def _import_context(): ...

  _current_input_id: contextvars.ContextVar

  _current_function_call_id: contextvars.ContextVar
+
+ _current_attempt_token: contextvars.ContextVar
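The `attempt_token` plumbing above uses the standard `ContextVar` set/reset pattern. A self-contained sketch of the same bookkeeping, assuming nothing beyond the standard library (names mirror the diff, but this snippet is illustrative, not modal code):

```python
from contextvars import ContextVar
from typing import Callable, Optional

_current_attempt_token: ContextVar = ContextVar("_current_attempt_token")

def current_attempt_token() -> Optional[str]:
    # As in the diff: an unset ContextVar raises LookupError, mapped to None.
    try:
        return _current_attempt_token.get()
    except LookupError:
        return None

def set_attempt_token(token: str) -> Callable[[], None]:
    # .set() returns a Token that restores the previous value on .reset().
    reset_token = _current_attempt_token.set(token)
    return lambda: _current_attempt_token.reset(reset_token)

reset = set_attempt_token("at-123")  # placeholder value
assert current_attempt_token() == "at-123"
reset()
assert current_attempt_token() is None
```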
modal/_runtime/gpu_memory_snapshot.py

@@ -1,25 +1,34 @@
  # Copyright Modal Labs 2022
  #
  # This module provides a simple interface for creating GPU memory snapshots,
- # provising a convenient interface to `cuda-checkpoint` [1]. This is intended
+ # providing a convenient interface to `cuda-checkpoint` [1]. This is intended
  # to be used in conjunction with memory snapshots.
  #
  # [1] https://github.com/NVIDIA/cuda-checkpoint

  import subprocess
  import time
- from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from dataclasses import dataclass
  from enum import Enum
  from pathlib import Path
+ from typing import List, Optional

  from modal.config import config, logger

  CUDA_CHECKPOINT_PATH: str = config.get("cuda_checkpoint_path")

+ # Maximum total duration for an entire toggle operation.
+ CUDA_CHECKPOINT_TOGGLE_TIMEOUT: float = 5 * 60.0
+
+ # Maximum total duration for each individual `cuda-checkpoint` invocation.
+ CUDA_CHECKPOINT_TIMEOUT: float = 90
+

  class CudaCheckpointState(Enum):
- """State representation from the CUDA API: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc96cdda177a2b8c296144567cbea4f23"""
+ """State representation from the CUDA API [1].
+
+ [1] https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html"""

  RUNNING = "running"
  LOCKED = "locked"
@@ -28,6 +37,8 @@ class CudaCheckpointState(Enum):


  class CudaCheckpointException(Exception):
+ """Exception raised for CUDA checkpoint operations."""
+
  pass


@@ -39,24 +50,44 @@ class CudaCheckpointProcess:
  pid: int
  state: CudaCheckpointState

- def toggle(self, target_state: CudaCheckpointState, timeout_secs: float = 5 * 60.0):
+ def toggle(self, target_state: CudaCheckpointState, skip_first_refresh: bool = False) -> None:
  """Toggle CUDA checkpoint state for current process, moving GPU memory to the
- CPU and back depending on the current process state when called."""
+ CPU and back depending on the current process state when called.
+ """
  logger.debug(f"PID: {self.pid} Toggling CUDA checkpoint state to {target_state.value}")

  start_time = time.monotonic()
-
- while self._should_continue_toggle(target_state, start_time, timeout_secs):
- self._execute_toggle_command()
- time.sleep(0.1)
+ retry_count = 0
+ max_retries = 3
+
+ attempts = 0
+ while self._should_continue_toggle(
+ target_state, start_time, refresh=not (skip_first_refresh and attempts == 0)
+ ):
+ attempts += 1
+ try:
+ self._execute_toggle_command()
+ # Use exponential backoff for retries
+ sleep_time = min(0.1 * (2**retry_count), 1.0)
+ time.sleep(sleep_time)
+ retry_count = 0
+ except CudaCheckpointException as e:
+ retry_count += 1
+ if retry_count >= max_retries:
+ raise CudaCheckpointException(
+ f"PID: {self.pid} Failed to toggle state after {max_retries} retries: {e}"
+ )
+ logger.debug(f"PID: {self.pid} Retry {retry_count}/{max_retries} after error: {e}")
+ time.sleep(0.5 * retry_count)

  logger.debug(f"PID: {self.pid} Target state {target_state.value} reached")

  def _should_continue_toggle(
- self, target_state: CudaCheckpointState, start_time: float, timeout_secs: float
+ self, target_state: CudaCheckpointState, start_time: float, refresh: bool = True
  ) -> bool:
  """Check if toggle operation should continue based on current state and timeout."""
- self.refresh_state()
+ if refresh:
+ self.refresh_state()

  if self.state == target_state:
  return False
@@ -65,7 +96,7 @@ class CudaCheckpointProcess:
  raise CudaCheckpointException(f"PID: {self.pid} CUDA process state is {self.state}")

  elapsed = time.monotonic() - start_time
- if elapsed >= timeout_secs:
+ if elapsed >= CUDA_CHECKPOINT_TOGGLE_TIMEOUT:
  raise CudaCheckpointException(
  f"PID: {self.pid} Timeout after {elapsed:.2f}s waiting for state {target_state.value}. "
  f"Current state: {self.state}"
@@ -73,19 +104,25 @@ class CudaCheckpointProcess:

  return True

- def _execute_toggle_command(self):
+ def _execute_toggle_command(self) -> None:
  """Execute the cuda-checkpoint toggle command."""
  try:
- subprocess.run(
+ _ = subprocess.run(
  [CUDA_CHECKPOINT_PATH, "--toggle", "--pid", str(self.pid)],
  check=True,
  capture_output=True,
  text=True,
+ timeout=CUDA_CHECKPOINT_TIMEOUT,
  )
  logger.debug(f"PID: {self.pid} Successfully toggled CUDA checkpoint state")
  except subprocess.CalledProcessError as e:
- logger.debug(f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}")
- raise CudaCheckpointException(e.stderr)
+ error_msg = f"PID: {self.pid} Failed to toggle CUDA checkpoint state: {e.stderr}"
+ logger.debug(error_msg)
+ raise CudaCheckpointException(error_msg)
+ except subprocess.TimeoutExpired:
+ error_msg = f"PID: {self.pid} Toggle command timed out"
+ logger.debug(error_msg)
+ raise CudaCheckpointException(error_msg)

  def refresh_state(self) -> None:
  """Refreshes the current CUDA checkpoint state for this process."""
@@ -95,15 +132,20 @@ class CudaCheckpointProcess:
  check=True,
  capture_output=True,
  text=True,
- timeout=5,
+ timeout=CUDA_CHECKPOINT_TIMEOUT,
  )

  state_str = result.stdout.strip().lower()
  self.state = CudaCheckpointState(state_str)

  except subprocess.CalledProcessError as e:
- logger.debug(f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}")
- raise CudaCheckpointException(e.stderr)
+ error_msg = f"PID: {self.pid} Failed to get CUDA checkpoint state: {e.stderr}"
+ logger.debug(error_msg)
+ raise CudaCheckpointException(error_msg)
+ except subprocess.TimeoutExpired:
+ error_msg = f"PID: {self.pid} Get state command timed out"
+ logger.debug(error_msg)
+ raise CudaCheckpointException(error_msg)


  class CudaCheckpointSession:
@@ -111,12 +153,17 @@ class CudaCheckpointSession:

  def __init__(self):
  self.cuda_processes = self._get_cuda_pids()
- logger.debug(f"PIDs with CUDA sessions: {[c.pid for c in self.cuda_processes]}")
+ if self.cuda_processes:
+ logger.debug(
+ f"Found {len(self.cuda_processes)} PID(s) with CUDA sessions: {[c.pid for c in self.cuda_processes]}"
+ )
+ else:
+ logger.debug("No CUDA sessions found.")

- def _get_cuda_pids(self) -> list[CudaCheckpointProcess]:
+ def _get_cuda_pids(self) -> List[CudaCheckpointProcess]:
  """Iterates over all PIDs and identifies the ones that have running
  CUDA sessions."""
- cuda_pids: list[CudaCheckpointProcess] = []
+ cuda_pids: List[CudaCheckpointProcess] = []

  # Get all active process IDs from /proc directory
  proc_dir = Path("/proc")
@@ -125,75 +172,135 @@ class CudaCheckpointSession:
  "OS does not have /proc path rendering it incompatible with GPU memory snapshots."
  )

- for entry in proc_dir.iterdir():
- if not entry.name.isdigit():
- continue
+ # Get all numeric directories (PIDs) from /proc
+ pid_dirs = [entry for entry in proc_dir.iterdir() if entry.name.isdigit()]

- pid = int(entry.name)
- try:
- # Call cuda-checkpoint to check if this PID has a CUDA session
- result = subprocess.run(
- [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
- capture_output=True,
- text=True,
- timeout=10,
- )
-
- # If the command succeeds (return code 0), this PID has a CUDA session
- if result.returncode == 0:
- state_str = result.stdout.strip().lower()
- state = CudaCheckpointState(state_str)
-
- cuda_checkpoint_process = CudaCheckpointProcess(pid=pid, state=state)
- cuda_pids.append(cuda_checkpoint_process)
-
- # Command failed, which is expected for PIDs without CUDA sessions
- except subprocess.CalledProcessError:
- continue
+ # Use ThreadPoolExecutor to check PIDs in parallel for better performance
+ with ThreadPoolExecutor(max_workers=min(50, len(pid_dirs))) as executor:
+ future_to_pid = {
+ executor.submit(self._check_cuda_session, int(entry.name)): int(entry.name) for entry in pid_dirs
+ }

- # Raise other exceptions
- except subprocess.TimeoutExpired:
- raise CudaCheckpointException(f"Failed to get CUDA state for PID {pid}")
- except Exception as e:
- raise CudaCheckpointException(e)
+ for future in as_completed(future_to_pid):
+ pid = future_to_pid[future]
+ try:
+ cuda_process = future.result()
+ if cuda_process:
+ cuda_pids.append(cuda_process)
+ except Exception as e:
+ logger.debug(f"Error checking PID {pid}: {e}")

  # Sort PIDs for ordered checkpointing
  cuda_pids.sort(key=lambda x: x.pid)
  return cuda_pids

+ def _check_cuda_session(self, pid: int) -> Optional[CudaCheckpointProcess]:
+ """Check if a specific PID has a CUDA session."""
+ try:
+ result = subprocess.run(
+ [CUDA_CHECKPOINT_PATH, "--get-state", "--pid", str(pid)],
+ capture_output=True,
+ text=True,
+ # This should be quick since no checkpoint has taken place yet
+ timeout=5,
+ )
+
+ # If the command succeeds (return code 0), this PID has a CUDA session
+ if result.returncode == 0:
+ state_str = result.stdout.strip().lower()
+ state = CudaCheckpointState(state_str)
+ return CudaCheckpointProcess(pid=pid, state=state)
+
+ except subprocess.CalledProcessError:
+ # Command failed, which is expected for PIDs without CUDA sessions
+ pass
+ except subprocess.TimeoutExpired:
+ logger.debug(f"Timeout checking CUDA state for PID {pid}")
+ except Exception as e:
+ logger.debug(f"Error checking PID {pid}: {e}")
+
+ return None
+
  def checkpoint(self) -> None:
+ """Checkpoint all CUDA processes, moving GPU memory to CPU."""
+ if not self.cuda_processes:
+ logger.debug("No CUDA processes to checkpoint.")
+ return
+
  # Validate all states first
  for proc in self.cuda_processes:
+ proc.refresh_state() # Refresh state before validation
  if proc.state != CudaCheckpointState.RUNNING:
- raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.RUNNING} state.")
+ raise CudaCheckpointException(
+ f"PID {proc.pid}: CUDA session not in {CudaCheckpointState.RUNNING.value} state. "
+ f"Current state: {proc.state.value}"
+ )

  # Moving state from GPU to CPU can take several seconds per CUDA session.
  # Make a parallel call per CUDA session.
  start = time.perf_counter()

- def checkpoint_impl(proc: CudaCheckpointProcess):
+ def checkpoint_impl(proc: CudaCheckpointProcess) -> None:
  proc.toggle(CudaCheckpointState.CHECKPOINTED)

  with ThreadPoolExecutor() as executor:
- list(executor.map(checkpoint_impl, self.cuda_processes))
+ futures = [executor.submit(checkpoint_impl, proc) for proc in self.cuda_processes]
+
+ # Wait for all futures and collect any exceptions
+ exceptions = []
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ exceptions.append(e)
+
+ if exceptions:
+ raise CudaCheckpointException(
+ f"Failed to checkpoint {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+ )

  elapsed = time.perf_counter() - start
- logger.debug(f"Checkpointing CUDA sessions took => {elapsed:.3f}s")
+ logger.debug(f"Checkpointing {len(self.cuda_processes)} CUDA sessions took => {elapsed:.3f}s")

  def restore(self) -> None:
- # Validate all states first
- for proc in self.cuda_processes:
- if proc.state != CudaCheckpointState.CHECKPOINTED:
- raise CudaCheckpointException(f"CUDA session not in {CudaCheckpointState.CHECKPOINTED} state.")
+ """Restore all CUDA processes, moving memory back from CPU to GPU."""
+ if not self.cuda_processes:
+ logger.debug("No CUDA sessions to restore.")
+ return

  # See checkpoint() for rationale about parallelism.
  start = time.perf_counter()

- def restore_process(proc: CudaCheckpointProcess):
- proc.toggle(CudaCheckpointState.RUNNING)
+ def restore_process(proc: CudaCheckpointProcess) -> None:
+ proc.toggle(CudaCheckpointState.RUNNING, skip_first_refresh=True)

  with ThreadPoolExecutor() as executor:
- list(executor.map(restore_process, self.cuda_processes))
+ futures = [executor.submit(restore_process, proc) for proc in self.cuda_processes]
+
+ # Wait for all futures and collect any exceptions
+ exceptions = []
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ exceptions.append(e)
+
+ if exceptions:
+ raise CudaCheckpointException(
+ f"Failed to restore {len(exceptions)} processes: {'; '.join(str(e) for e in exceptions)}"
+ )

  elapsed = time.perf_counter() - start
- logger.debug(f"Restoring CUDA sessions took => {elapsed:.3f}s")
+ logger.debug(f"Restoring {len(self.cuda_processes)} CUDA session(s) took => {elapsed:.3f}s")
+
+ def get_process_count(self) -> int:
+ """Get the number of CUDA processes managed by this session."""
+ return len(self.cuda_processes)
+
+ def get_process_states(self) -> List[tuple[int, CudaCheckpointState]]:
+ """Get current states of all managed processes."""
+ states = []
+ for proc in self.cuda_processes:
+ proc.refresh_state()
+ states.append((proc.pid, proc.state))
+ return states
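Taken together, the rewrite scans `/proc` in parallel, adds per-invocation timeouts and retry with exponential backoff to `cuda-checkpoint` calls, and aggregates worker failures. A hypothetical usage sketch of the resulting `CudaCheckpointSession` API (requires the `cuda-checkpoint` binary and a Linux `/proc` filesystem; these are internal modal runtime helpers, not public API):

```python
from modal._runtime.gpu_memory_snapshot import (
    CudaCheckpointSession,
    CudaCheckpointState,
)

session = CudaCheckpointSession()  # scans /proc for CUDA PIDs in parallel
print(f"{session.get_process_count()} CUDA process(es) found")

# Move GPU memory to the CPU (one worker thread per CUDA process).
session.checkpoint()
assert all(
    state == CudaCheckpointState.CHECKPOINTED
    for _, state in session.get_process_states()
)

# ...take the memory snapshot here...

# Move memory back to the GPU; restore() skips the first state refresh.
session.restore()
```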