modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_ipython.py CHANGED
@@ -1,21 +1,11 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import sys
3
- import warnings
4
-
5
- ipy_outstream = None
6
- try:
7
- with warnings.catch_warnings():
8
- warnings.simplefilter("ignore")
9
- import ipykernel.iostream
10
-
11
- ipy_outstream = ipykernel.iostream.OutStream
12
- except ImportError:
13
- pass
14
3
 
15
4
 
16
5
  def is_notebook(stdout=None):
17
- if ipy_outstream is None:
6
+ ipykernel_iostream = sys.modules.get("ipykernel.iostream")
7
+ if ipykernel_iostream is None:
18
8
  return False
19
9
  if stdout is None:
20
10
  stdout = sys.stdout
21
- return isinstance(stdout, ipy_outstream)
11
+ return isinstance(stdout, ipykernel_iostream.OutStream)
modal/_location.py CHANGED
@@ -1,33 +1,40 @@
1
1
  # Copyright Modal Labs 2022
2
2
  from enum import Enum
3
3
 
4
- from modal_proto import api_pb2
4
+ import modal_proto.api_pb2
5
5
 
6
6
  from .exception import InvalidError
7
7
 
8
8
 
9
9
  class CloudProvider(Enum):
10
- AWS = api_pb2.CLOUD_PROVIDER_AWS
11
- GCP = api_pb2.CLOUD_PROVIDER_GCP
12
- AUTO = api_pb2.CLOUD_PROVIDER_AUTO
13
- OCI = api_pb2.CLOUD_PROVIDER_OCI
10
+ AWS = modal_proto.api_pb2.CLOUD_PROVIDER_AWS
11
+ GCP = modal_proto.api_pb2.CLOUD_PROVIDER_GCP
12
+ AUTO = modal_proto.api_pb2.CLOUD_PROVIDER_AUTO
13
+ OCI = modal_proto.api_pb2.CLOUD_PROVIDER_OCI
14
14
 
15
15
 
16
- def parse_cloud_provider(value: str) -> "api_pb2.CloudProvider.V":
16
+ def parse_cloud_provider(value: str) -> "modal_proto.api_pb2.CloudProvider.V":
17
17
  try:
18
18
  cloud_provider = CloudProvider[value.upper()]
19
19
  except KeyError:
20
+ # provider's int identifier may be directly specified
21
+ try:
22
+ return int(value) # type: ignore
23
+ except ValueError:
24
+ pass
25
+
20
26
  raise InvalidError(
21
- f"Invalid cloud provider: {value}. Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)."
27
+ f"Invalid cloud provider: {value}. "
28
+ f"Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)."
22
29
  )
23
30
 
24
31
  return cloud_provider.value
25
32
 
26
33
 
27
- def display_location(cloud_provider: "api_pb2.CloudProvider.V") -> str:
28
- if cloud_provider == api_pb2.CLOUD_PROVIDER_GCP:
34
+ def display_location(cloud_provider: "modal_proto.api_pb2.CloudProvider.V") -> str:
35
+ if cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_GCP:
29
36
  return "GCP (us-east1)"
30
- elif cloud_provider == api_pb2.CLOUD_PROVIDER_AWS:
37
+ elif cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_AWS:
31
38
  return "AWS (us-east-1)"
32
39
  else:
33
40
  return ""
modal/_output.py CHANGED
@@ -7,9 +7,11 @@ import functools
7
7
  import io
8
8
  import platform
9
9
  import re
10
+ import socket
10
11
  import sys
12
+ from collections.abc import Generator
11
13
  from datetime import timedelta
12
- from typing import Callable, Dict, Optional, Tuple
14
+ from typing import Callable, ClassVar
13
15
 
14
16
  from grpclib.exceptions import GRPCError, StreamTerminatedError
15
17
  from rich.console import Console, Group, RenderableType
@@ -32,7 +34,7 @@ from rich.text import Text
32
34
 
33
35
  from modal_proto import api_pb2
34
36
 
35
- from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
37
+ from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
36
38
  from ._utils.shell_utils import stream_from_stdin
37
39
  from .client import _Client
38
40
  from .config import logger
@@ -60,25 +62,6 @@ class FunctionQueuingColumn(ProgressColumn):
60
62
  return Text(str(delta), style="progress.elapsed")
61
63
 
62
64
 
63
- def step_progress(text: str = "") -> Spinner:
64
- """Returns the element to be rendered when a step is in progress."""
65
- return Spinner(default_spinner, text, style="blue")
66
-
67
-
68
- def step_progress_update(spinner: Spinner, message: str):
69
- spinner.update(text=message)
70
-
71
-
72
- def step_completed(message: str, is_substep: bool = False) -> RenderableType:
73
- """Returns the element to be rendered when a step is completed."""
74
-
75
- STEP_COMPLETED = "[green]✓[/green]"
76
- SUBSTEP_COMPLETED = "🔨"
77
-
78
- symbol = SUBSTEP_COMPLETED if is_substep else STEP_COMPLETED
79
- return f"{symbol} {message}"
80
-
81
-
82
65
  def download_progress_bar() -> Progress:
83
66
  """
84
67
  Returns a progress bar suitable for showing file download progress.
@@ -139,24 +122,28 @@ class LineBufferedOutput(io.StringIO):
139
122
 
140
123
 
141
124
  class OutputManager:
142
- _visible_progress: bool
125
+ _instance: ClassVar[OutputManager | None] = None
126
+
143
127
  _console: Console
144
- _task_states: Dict[str, int]
145
- _task_progress_items: Dict[Tuple[str, int], TaskID]
146
- _current_render_group: Optional[Group]
147
- _function_progress: Optional[Progress]
148
- _function_queueing_progress: Optional[Progress]
149
- _snapshot_progress: Optional[Progress]
150
- _line_buffers: Dict[int, LineBufferedOutput]
128
+ _task_states: dict[str, int]
129
+ _task_progress_items: dict[tuple[str, int], TaskID]
130
+ _current_render_group: Group | None
131
+ _function_progress: Progress | None
132
+ _function_queueing_progress: Progress | None
133
+ _snapshot_progress: Progress | None
134
+ _line_buffers: dict[int, LineBufferedOutput]
151
135
  _status_spinner: Spinner
152
- _app_page_url: Optional[str]
136
+ _app_page_url: str | None
153
137
  _show_image_logs: bool
154
- _status_spinner_live: Optional[Live]
155
-
156
- def __init__(self, stdout: io.TextIOWrapper, show_progress: bool, status_spinner_text: str = "Running app..."):
157
- self.stdout = stdout or sys.stdout
158
-
159
- self._visible_progress = show_progress
138
+ _status_spinner_live: Live | None
139
+
140
+ def __init__(
141
+ self,
142
+ *,
143
+ stdout: io.TextIOWrapper | None = None,
144
+ status_spinner_text: str = "Running app...",
145
+ ):
146
+ self._stdout = stdout or sys.stdout
160
147
  self._console = Console(file=stdout, highlight=False)
161
148
  self._task_states = {}
162
149
  self._task_progress_items = {}
@@ -165,18 +152,47 @@ class OutputManager:
165
152
  self._function_queueing_progress = None
166
153
  self._snapshot_progress = None
167
154
  self._line_buffers = {}
168
- self._status_spinner = step_progress(status_spinner_text)
155
+ self._status_spinner = OutputManager.step_progress(status_spinner_text)
169
156
  self._app_page_url = None
170
157
  self._show_image_logs = False
158
+ self._status_spinner_live = None
171
159
 
172
- def print_if_visible(self, renderable) -> None:
173
- if self._visible_progress:
174
- self._console.print(renderable)
160
+ @classmethod
161
+ def disable(cls):
162
+ cls._instance.flush_lines()
163
+ if cls._instance._status_spinner_live:
164
+ cls._instance._status_spinner_live.stop()
165
+ cls._instance = None
175
166
 
176
- def ctx_if_visible(self, context_mgr):
177
- if self._visible_progress:
178
- return context_mgr
179
- return contextlib.nullcontext()
167
+ @classmethod
168
+ def get(cls) -> OutputManager | None:
169
+ return cls._instance
170
+
171
+ @classmethod
172
+ @contextlib.contextmanager
173
+ def enable_output(cls, show_progress: bool = True) -> Generator[None]:
174
+ if show_progress:
175
+ cls._instance = OutputManager()
176
+ try:
177
+ yield
178
+ finally:
179
+ cls._instance = None
180
+
181
+ @staticmethod
182
+ def step_progress(text: str = "") -> Spinner:
183
+ """Returns the element to be rendered when a step is in progress."""
184
+ return Spinner(default_spinner, text, style="blue")
185
+
186
+ @staticmethod
187
+ def step_completed(message: str) -> RenderableType:
188
+ return f"[green]✓[/green] {message}"
189
+
190
+ @staticmethod
191
+ def substep_completed(message: str) -> RenderableType:
192
+ return f"🔨 {message}"
193
+
194
+ def print(self, renderable) -> None:
195
+ self._console.print(renderable)
180
196
 
181
197
  def make_live(self, renderable: RenderableType) -> Live:
182
198
  """Creates a customized `rich.Live` instance with the given renderable. The renderable
@@ -237,7 +253,7 @@ class OutputManager:
237
253
  self._current_render_group.renderables.append(self._function_queueing_progress)
238
254
  return self._function_queueing_progress
239
255
 
240
- def function_progress_callback(self, tag: str, total: Optional[int]) -> Callable[[int, int], None]:
256
+ def function_progress_callback(self, tag: str, total: int | None) -> Callable[[int, int], None]:
241
257
  """Adds a task to the current function_progress instance, and returns a callback
242
258
  to update task progress with new completed and total counts."""
243
259
 
@@ -292,7 +308,7 @@ class OutputManager:
292
308
  message = f"[blue]{message}[/blue] [grey70]View app at [underline]{self._app_page_url}[/underline][/grey70]"
293
309
 
294
310
  # Set the new message
295
- step_progress_update(self._status_spinner, message)
311
+ self._status_spinner.update(text=message)
296
312
 
297
313
  def update_snapshot_progress(self, image_id: str, task_progress: api_pb2.TaskProgress) -> None:
298
314
  # TODO(erikbern): move this to sit on the resolver object, mostly
@@ -315,7 +331,7 @@ class OutputManager:
315
331
  pass
316
332
 
317
333
  def update_queueing_progress(
318
- self, *, function_id: str, completed: int, total: Optional[int], description: Optional[str]
334
+ self, *, function_id: str, completed: int, total: int | None, description: str | None
319
335
  ) -> None:
320
336
  """Handle queueing updates, ignoring completion updates for functions that have no queue progress bar."""
321
337
  task_key = (function_id, api_pb2.FUNCTION_QUEUED)
@@ -335,33 +351,11 @@ class OutputManager:
335
351
  self._task_progress_items[task_key] = progress_task_id
336
352
 
337
353
  async def put_log_content(self, log: api_pb2.TaskLogs):
338
- if self._visible_progress:
339
- stream = self._line_buffers.get(log.file_descriptor)
340
- if stream is None:
341
- stream = LineBufferedOutput(functools.partial(self._print_log, log.file_descriptor))
342
- self._line_buffers[log.file_descriptor] = stream
343
- stream.write(log.data)
344
- elif hasattr(self.stdout, "buffer"):
345
- # If we're not showing progress, there's no need to buffer lines,
346
- # because the progress spinner can't interfere with output.
347
-
348
- data = log.data.encode("utf-8")
349
- written = 0
350
- n_retries = 0
351
- while written < len(data):
352
- try:
353
- written += self.stdout.buffer.write(data[written:])
354
- self.stdout.flush()
355
- except BlockingIOError:
356
- if n_retries >= 5:
357
- raise
358
- n_retries += 1
359
- await asyncio.sleep(0.1)
360
- else:
361
- # `stdout` isn't always buffered (e.g. %%capture in Jupyter notebooks redirects it to
362
- # io.StringIO).
363
- self.stdout.write(log.data)
364
- self.stdout.flush()
354
+ stream = self._line_buffers.get(log.file_descriptor)
355
+ if stream is None:
356
+ stream = LineBufferedOutput(functools.partial(self._print_log, log.file_descriptor))
357
+ self._line_buffers[log.file_descriptor] = stream
358
+ stream.write(log.data)
365
359
 
366
360
  def flush_lines(self):
367
361
  for stream in self._line_buffers.values():
@@ -370,12 +364,123 @@ class OutputManager:
370
364
  @contextlib.contextmanager
371
365
  def show_status_spinner(self):
372
366
  self._status_spinner_live = self.make_live(self._status_spinner)
373
- with self.ctx_if_visible(self._status_spinner_live):
367
+ with self._status_spinner_live:
374
368
  yield
375
369
 
376
- def hide_status_spinner(self):
377
- if self._status_spinner_live:
378
- self._status_spinner_live.stop()
370
+
371
+ class ProgressHandler:
372
+ live: Live
373
+ _type: str
374
+ _spinner: Spinner
375
+ _overall_progress: Progress
376
+ _download_progress: Progress
377
+ _overall_progress_task_id: TaskID
378
+ _total_tasks: int
379
+ _completed_tasks: int
380
+
381
+ def __init__(self, type: str, console: Console):
382
+ self._type = type
383
+
384
+ if self._type == "download":
385
+ title = "Downloading file(s) to local..."
386
+ elif self._type == "upload":
387
+ title = "Uploading file(s) to volume..."
388
+ else:
389
+ raise NotImplementedError(f"Progress handler of type: `{type}` not yet implemented")
390
+
391
+ self._spinner = OutputManager.step_progress(title)
392
+
393
+ self._overall_progress = Progress(
394
+ TextColumn(f"[bold white]{title}", justify="right"),
395
+ TimeElapsedColumn(),
396
+ BarColumn(bar_width=None),
397
+ TextColumn("[bold white]{task.description}"),
398
+ transient=True,
399
+ console=console,
400
+ )
401
+ self._download_progress = Progress(
402
+ TextColumn("[bold white]{task.fields[path]}", justify="right"),
403
+ BarColumn(bar_width=None),
404
+ "[progress.percentage]{task.percentage:>3.1f}%",
405
+ "•",
406
+ DownloadColumn(),
407
+ "•",
408
+ TransferSpeedColumn(),
409
+ "•",
410
+ TimeRemainingColumn(),
411
+ transient=True,
412
+ console=console,
413
+ )
414
+
415
+ self.live = Live(
416
+ Group(self._spinner, self._overall_progress, self._download_progress), transient=True, refresh_per_second=4
417
+ )
418
+
419
+ self._overall_progress_task_id = self._overall_progress.add_task(".", start=True)
420
+ self._total_tasks = 0
421
+ self._completed_tasks = 0
422
+
423
+ def _add_sub_task(self, name: str, size: float) -> TaskID:
424
+ task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
425
+ self._total_tasks += 1
426
+ self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
427
+ return task_id
428
+
429
+ def _reset_sub_task(self, task_id: TaskID):
430
+ self._download_progress.reset(task_id)
431
+
432
+ def _complete_progress(self):
433
+ # TODO: we could probably implement some callback progression from the server
434
+ # to get progress reports for the post processing too
435
+ # so we don't have to just spin here
436
+ self._overall_progress.remove_task(self._overall_progress_task_id)
437
+ self._spinner.update(text="Post processing...")
438
+
439
+ def _complete_sub_task(self, task_id: TaskID):
440
+ self._completed_tasks += 1
441
+ self._download_progress.remove_task(task_id)
442
+ self._overall_progress.update(
443
+ self._overall_progress_task_id,
444
+ advance=1,
445
+ description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
446
+ )
447
+
448
+ def _advance_sub_task(self, task_id: TaskID, advance: float):
449
+ self._download_progress.update(task_id, advance=advance)
450
+
451
+ def progress(
452
+ self,
453
+ task_id: TaskID | None = None,
454
+ advance: float | None = None,
455
+ name: str | None = None,
456
+ size: float | None = None,
457
+ reset: bool | None = False,
458
+ complete: bool | None = False,
459
+ ) -> TaskID | None:
460
+ try:
461
+ if task_id is not None:
462
+ if reset:
463
+ return self._reset_sub_task(task_id)
464
+ elif complete:
465
+ return self._complete_sub_task(task_id)
466
+ elif advance is not None:
467
+ return self._advance_sub_task(task_id, advance)
468
+ elif name is not None and size is not None:
469
+ return self._add_sub_task(name, size)
470
+ elif complete:
471
+ return self._complete_progress()
472
+ except Exception as exc:
473
+ # Liberal exception handling to avoid crashing downloads and uploads.
474
+ logger.error(f"failed progress update: {exc}")
475
+ raise NotImplementedError(
476
+ "Unknown action to take with args: "
477
+ + f"name={name} "
478
+ + f"size={size} "
479
+ + f"task_id={task_id} "
480
+ + f"advance={advance} "
481
+ + f"reset={reset} "
482
+ + f"complete={complete} "
483
+ )
379
484
 
380
485
 
381
486
  async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: asyncio.Event):
@@ -396,10 +501,42 @@ async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: as
396
501
  await finish_event.wait()
397
502
 
398
503
 
399
- async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputManager):
504
+ async def put_pty_content(log: api_pb2.TaskLogs, stdout):
505
+ if hasattr(stdout, "buffer"):
506
+ # If we're not showing progress, there's no need to buffer lines,
507
+ # because the progress spinner can't interfere with output.
508
+
509
+ data = log.data.encode("utf-8")
510
+ written = 0
511
+ n_retries = 0
512
+ while written < len(data):
513
+ try:
514
+ written += stdout.buffer.write(data[written:])
515
+ stdout.flush()
516
+ except BlockingIOError:
517
+ if n_retries >= 5:
518
+ raise
519
+ n_retries += 1
520
+ await asyncio.sleep(0.1)
521
+ else:
522
+ # `stdout` isn't always buffered (e.g. %%capture in Jupyter notebooks redirects it to
523
+ # io.StringIO).
524
+ stdout.write(log.data)
525
+ stdout.flush()
526
+
527
+
528
+ async def get_app_logs_loop(
529
+ client: _Client,
530
+ output_mgr: OutputManager,
531
+ app_id: str | None = None,
532
+ task_id: str | None = None,
533
+ app_logs_url: str | None = None,
534
+ ):
400
535
  last_log_batch_entry_id = ""
401
- pty_shell_finish_event: Optional[asyncio.Event] = None
402
- pty_shell_task_id: Optional[str] = None
536
+
537
+ pty_shell_stdout = None
538
+ pty_shell_finish_event: asyncio.Event | None = None
539
+ pty_shell_task_id: str | None = None
403
540
 
404
541
  async def stop_pty_shell():
405
542
  nonlocal pty_shell_finish_event
@@ -428,18 +565,23 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
428
565
  else: # Ensure forward-compatible with new types.
429
566
  logger.debug(f"Received unrecognized progress type: {log.task_progress.progress_type}")
430
567
  elif log.data:
431
- await output_mgr.put_log_content(log)
568
+ if pty_shell_finish_event:
569
+ await put_pty_content(log, pty_shell_stdout)
570
+ else:
571
+ await output_mgr.put_log_content(log)
432
572
 
433
573
  async def _get_logs():
434
- nonlocal last_log_batch_entry_id, pty_shell_finish_event, pty_shell_task_id
574
+ nonlocal last_log_batch_entry_id
575
+ nonlocal pty_shell_stdout, pty_shell_finish_event, pty_shell_task_id
435
576
 
436
577
  request = api_pb2.AppGetLogsRequest(
437
- app_id=app_id,
578
+ app_id=app_id or "",
579
+ task_id=task_id or "",
438
580
  timeout=55,
439
581
  last_entry_id=last_log_batch_entry_id,
440
582
  )
441
583
  log_batch: api_pb2.TaskLogsBatch
442
- async for log_batch in unary_stream(client.stub.AppGetLogs, request):
584
+ async for log_batch in client.stub.AppGetLogs.unary_stream(request):
443
585
  if log_batch.entry_id:
444
586
  # log_batch entry_id is empty for fd="server" messages from AppGetLogs
445
587
  last_log_batch_entry_id = log_batch.entry_id
@@ -456,14 +598,15 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
456
598
  # statically and dynamically built images.
457
599
  pass
458
600
  elif log_batch.pty_exec_id:
601
+ # This corresponds to the `modal run -i` use case where a breakpoint
602
+ # triggers and the task drops into an interactive PTY mode
459
603
  if pty_shell_finish_event:
460
604
  print("ERROR: concurrent PTY shells are not supported.")
461
605
  else:
462
- output_mgr.flush_lines()
463
- output_mgr.hide_status_spinner()
464
- output_mgr._visible_progress = False
606
+ pty_shell_stdout = output_mgr._stdout
465
607
  pty_shell_finish_event = asyncio.Event()
466
608
  pty_shell_task_id = log_batch.task_id
609
+ output_mgr.disable()
467
610
  asyncio.create_task(stream_pty_shell_input(client, log_batch.pty_exec_id, pty_shell_finish_event))
468
611
  else:
469
612
  for log in log_batch.items:
@@ -477,14 +620,7 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
477
620
  while True:
478
621
  try:
479
622
  await _get_logs()
480
- except asyncio.CancelledError:
481
- # TODO: this should come from the backend maybe
482
- app_logs_url = f"https://modal.com/logs/{app_id}"
483
- output_mgr.print_if_visible(
484
- f"[red]Timed out waiting for logs. [grey70]View logs at [underline]{app_logs_url}[/underline] for remaining output.[/grey70]"
485
- )
486
- raise
487
- except (GRPCError, StreamTerminatedError) as exc:
623
+ except (GRPCError, StreamTerminatedError, socket.gaierror, AttributeError) as exc:
488
624
  if isinstance(exc, GRPCError):
489
625
  if exc.status in RETRYABLE_GRPC_STATUS_CODES:
490
626
  # Try again if we had a temporary connection drop,
@@ -494,10 +630,18 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
494
630
  elif isinstance(exc, StreamTerminatedError):
495
631
  logger.debug("Stream closed. Retrying ...")
496
632
  continue
633
+ elif isinstance(exc, socket.gaierror):
634
+ logger.debug("Lost connection. Retrying ...")
635
+ continue
636
+ elif isinstance(exc, AttributeError):
637
+ if "_write_appdata" in str(exc):
638
+ # Happens after losing connection
639
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
640
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
641
+ # TODO: update to newer version (>=0.4.8) once stable
642
+ logger.debug("Lost connection. Retrying ...")
643
+ continue
497
644
  raise
498
- except Exception as exc:
499
- logger.exception(f"Failed to fetch logs: {exc}")
500
- await asyncio.sleep(1)
501
645
 
502
646
  if last_log_batch_entry_id is None:
503
647
  break
modal/_pty.py CHANGED
@@ -2,12 +2,12 @@
2
2
  import contextlib
3
3
  import os
4
4
  import sys
5
- from typing import Optional, Tuple
5
+ from typing import Optional
6
6
 
7
7
  from modal_proto import api_pb2
8
8
 
9
9
 
10
- def get_winsz(fd) -> Tuple[Optional[int], Optional[int]]:
10
+ def get_winsz(fd) -> tuple[Optional[int], Optional[int]]:
11
11
  try:
12
12
  import fcntl
13
13
  import struct