modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/_ipython.py CHANGED
@@ -1,21 +1,11 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import sys
3
- import warnings
4
-
5
- ipy_outstream = None
6
- try:
7
- with warnings.catch_warnings():
8
- warnings.simplefilter("ignore")
9
- import ipykernel.iostream
10
-
11
- ipy_outstream = ipykernel.iostream.OutStream
12
- except ImportError:
13
- pass
14
3
 
15
4
 
16
5
  def is_notebook(stdout=None):
17
- if ipy_outstream is None:
6
+ ipykernel_iostream = sys.modules.get("ipykernel.iostream")
7
+ if ipykernel_iostream is None:
18
8
  return False
19
9
  if stdout is None:
20
10
  stdout = sys.stdout
21
- return isinstance(stdout, ipy_outstream)
11
+ return isinstance(stdout, ipykernel_iostream.OutStream)
modal/_location.py CHANGED
@@ -1,33 +1,40 @@
1
1
  # Copyright Modal Labs 2022
2
2
  from enum import Enum
3
3
 
4
- from modal_proto import api_pb2
4
+ import modal_proto.api_pb2
5
5
 
6
6
  from .exception import InvalidError
7
7
 
8
8
 
9
9
  class CloudProvider(Enum):
10
- AWS = api_pb2.CLOUD_PROVIDER_AWS
11
- GCP = api_pb2.CLOUD_PROVIDER_GCP
12
- AUTO = api_pb2.CLOUD_PROVIDER_AUTO
13
- OCI = api_pb2.CLOUD_PROVIDER_OCI
10
+ AWS = modal_proto.api_pb2.CLOUD_PROVIDER_AWS
11
+ GCP = modal_proto.api_pb2.CLOUD_PROVIDER_GCP
12
+ AUTO = modal_proto.api_pb2.CLOUD_PROVIDER_AUTO
13
+ OCI = modal_proto.api_pb2.CLOUD_PROVIDER_OCI
14
14
 
15
15
 
16
- def parse_cloud_provider(value: str) -> "api_pb2.CloudProvider.V":
16
+ def parse_cloud_provider(value: str) -> "modal_proto.api_pb2.CloudProvider.V":
17
17
  try:
18
18
  cloud_provider = CloudProvider[value.upper()]
19
19
  except KeyError:
20
+ # provider's int identifier may be directly specified
21
+ try:
22
+ return int(value) # type: ignore
23
+ except ValueError:
24
+ pass
25
+
20
26
  raise InvalidError(
21
- f"Invalid cloud provider: {value}. Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)."
27
+ f"Invalid cloud provider: {value}. "
28
+ f"Value must be one of {[x.name.lower() for x in CloudProvider]} (case-insensitive)."
22
29
  )
23
30
 
24
31
  return cloud_provider.value
25
32
 
26
33
 
27
- def display_location(cloud_provider: "api_pb2.CloudProvider.V") -> str:
28
- if cloud_provider == api_pb2.CLOUD_PROVIDER_GCP:
34
+ def display_location(cloud_provider: "modal_proto.api_pb2.CloudProvider.V") -> str:
35
+ if cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_GCP:
29
36
  return "GCP (us-east1)"
30
- elif cloud_provider == api_pb2.CLOUD_PROVIDER_AWS:
37
+ elif cloud_provider == modal_proto.api_pb2.CLOUD_PROVIDER_AWS:
31
38
  return "AWS (us-east-1)"
32
39
  else:
33
40
  return ""
modal/_output.py CHANGED
@@ -7,9 +7,11 @@ import functools
7
7
  import io
8
8
  import platform
9
9
  import re
10
+ import socket
10
11
  import sys
12
+ from collections.abc import Generator
11
13
  from datetime import timedelta
12
- from typing import Callable, Dict, Optional, Tuple
14
+ from typing import Callable, ClassVar
13
15
 
14
16
  from grpclib.exceptions import GRPCError, StreamTerminatedError
15
17
  from rich.console import Console, Group, RenderableType
@@ -32,7 +34,7 @@ from rich.text import Text
32
34
 
33
35
  from modal_proto import api_pb2
34
36
 
35
- from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors, unary_stream
37
+ from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
36
38
  from ._utils.shell_utils import stream_from_stdin
37
39
  from .client import _Client
38
40
  from .config import logger
@@ -60,25 +62,6 @@ class FunctionQueuingColumn(ProgressColumn):
60
62
  return Text(str(delta), style="progress.elapsed")
61
63
 
62
64
 
63
- def step_progress(text: str = "") -> Spinner:
64
- """Returns the element to be rendered when a step is in progress."""
65
- return Spinner(default_spinner, text, style="blue")
66
-
67
-
68
- def step_progress_update(spinner: Spinner, message: str):
69
- spinner.update(text=message)
70
-
71
-
72
- def step_completed(message: str, is_substep: bool = False) -> RenderableType:
73
- """Returns the element to be rendered when a step is completed."""
74
-
75
- STEP_COMPLETED = "[green]✓[/green]"
76
- SUBSTEP_COMPLETED = "🔨"
77
-
78
- symbol = SUBSTEP_COMPLETED if is_substep else STEP_COMPLETED
79
- return f"{symbol} {message}"
80
-
81
-
82
65
  def download_progress_bar() -> Progress:
83
66
  """
84
67
  Returns a progress bar suitable for showing file download progress.
@@ -139,24 +122,28 @@ class LineBufferedOutput(io.StringIO):
139
122
 
140
123
 
141
124
  class OutputManager:
142
- _visible_progress: bool
125
+ _instance: ClassVar[OutputManager | None] = None
126
+
143
127
  _console: Console
144
- _task_states: Dict[str, int]
145
- _task_progress_items: Dict[Tuple[str, int], TaskID]
146
- _current_render_group: Optional[Group]
147
- _function_progress: Optional[Progress]
148
- _function_queueing_progress: Optional[Progress]
149
- _snapshot_progress: Optional[Progress]
150
- _line_buffers: Dict[int, LineBufferedOutput]
128
+ _task_states: dict[str, int]
129
+ _task_progress_items: dict[tuple[str, int], TaskID]
130
+ _current_render_group: Group | None
131
+ _function_progress: Progress | None
132
+ _function_queueing_progress: Progress | None
133
+ _snapshot_progress: Progress | None
134
+ _line_buffers: dict[int, LineBufferedOutput]
151
135
  _status_spinner: Spinner
152
- _app_page_url: Optional[str]
136
+ _app_page_url: str | None
153
137
  _show_image_logs: bool
154
- _status_spinner_live: Optional[Live]
155
-
156
- def __init__(self, stdout: io.TextIOWrapper, show_progress: bool, status_spinner_text: str = "Running app..."):
157
- self.stdout = stdout or sys.stdout
158
-
159
- self._visible_progress = show_progress
138
+ _status_spinner_live: Live | None
139
+
140
+ def __init__(
141
+ self,
142
+ *,
143
+ stdout: io.TextIOWrapper | None = None,
144
+ status_spinner_text: str = "Running app...",
145
+ ):
146
+ self._stdout = stdout or sys.stdout
160
147
  self._console = Console(file=stdout, highlight=False)
161
148
  self._task_states = {}
162
149
  self._task_progress_items = {}
@@ -165,18 +152,47 @@ class OutputManager:
165
152
  self._function_queueing_progress = None
166
153
  self._snapshot_progress = None
167
154
  self._line_buffers = {}
168
- self._status_spinner = step_progress(status_spinner_text)
155
+ self._status_spinner = OutputManager.step_progress(status_spinner_text)
169
156
  self._app_page_url = None
170
157
  self._show_image_logs = False
158
+ self._status_spinner_live = None
171
159
 
172
- def print_if_visible(self, renderable) -> None:
173
- if self._visible_progress:
174
- self._console.print(renderable)
160
+ @classmethod
161
+ def disable(cls):
162
+ cls._instance.flush_lines()
163
+ if cls._instance._status_spinner_live:
164
+ cls._instance._status_spinner_live.stop()
165
+ cls._instance = None
175
166
 
176
- def ctx_if_visible(self, context_mgr):
177
- if self._visible_progress:
178
- return context_mgr
179
- return contextlib.nullcontext()
167
+ @classmethod
168
+ def get(cls) -> OutputManager | None:
169
+ return cls._instance
170
+
171
+ @classmethod
172
+ @contextlib.contextmanager
173
+ def enable_output(cls, show_progress: bool = True) -> Generator[None]:
174
+ if show_progress:
175
+ cls._instance = OutputManager()
176
+ try:
177
+ yield
178
+ finally:
179
+ cls._instance = None
180
+
181
+ @staticmethod
182
+ def step_progress(text: str = "") -> Spinner:
183
+ """Returns the element to be rendered when a step is in progress."""
184
+ return Spinner(default_spinner, text, style="blue")
185
+
186
+ @staticmethod
187
+ def step_completed(message: str) -> RenderableType:
188
+ return f"[green]✓[/green] {message}"
189
+
190
+ @staticmethod
191
+ def substep_completed(message: str) -> RenderableType:
192
+ return f"🔨 {message}"
193
+
194
+ def print(self, renderable) -> None:
195
+ self._console.print(renderable)
180
196
 
181
197
  def make_live(self, renderable: RenderableType) -> Live:
182
198
  """Creates a customized `rich.Live` instance with the given renderable. The renderable
@@ -237,7 +253,7 @@ class OutputManager:
237
253
  self._current_render_group.renderables.append(self._function_queueing_progress)
238
254
  return self._function_queueing_progress
239
255
 
240
- def function_progress_callback(self, tag: str, total: Optional[int]) -> Callable[[int, int], None]:
256
+ def function_progress_callback(self, tag: str, total: int | None) -> Callable[[int, int], None]:
241
257
  """Adds a task to the current function_progress instance, and returns a callback
242
258
  to update task progress with new completed and total counts."""
243
259
 
@@ -292,7 +308,7 @@ class OutputManager:
292
308
  message = f"[blue]{message}[/blue] [grey70]View app at [underline]{self._app_page_url}[/underline][/grey70]"
293
309
 
294
310
  # Set the new message
295
- step_progress_update(self._status_spinner, message)
311
+ self._status_spinner.update(text=message)
296
312
 
297
313
  def update_snapshot_progress(self, image_id: str, task_progress: api_pb2.TaskProgress) -> None:
298
314
  # TODO(erikbern): move this to sit on the resolver object, mostly
@@ -315,7 +331,7 @@ class OutputManager:
315
331
  pass
316
332
 
317
333
  def update_queueing_progress(
318
- self, *, function_id: str, completed: int, total: Optional[int], description: Optional[str]
334
+ self, *, function_id: str, completed: int, total: int | None, description: str | None
319
335
  ) -> None:
320
336
  """Handle queueing updates, ignoring completion updates for functions that have no queue progress bar."""
321
337
  task_key = (function_id, api_pb2.FUNCTION_QUEUED)
@@ -335,33 +351,11 @@ class OutputManager:
335
351
  self._task_progress_items[task_key] = progress_task_id
336
352
 
337
353
  async def put_log_content(self, log: api_pb2.TaskLogs):
338
- if self._visible_progress:
339
- stream = self._line_buffers.get(log.file_descriptor)
340
- if stream is None:
341
- stream = LineBufferedOutput(functools.partial(self._print_log, log.file_descriptor))
342
- self._line_buffers[log.file_descriptor] = stream
343
- stream.write(log.data)
344
- elif hasattr(self.stdout, "buffer"):
345
- # If we're not showing progress, there's no need to buffer lines,
346
- # because the progress spinner can't interfere with output.
347
-
348
- data = log.data.encode("utf-8")
349
- written = 0
350
- n_retries = 0
351
- while written < len(data):
352
- try:
353
- written += self.stdout.buffer.write(data[written:])
354
- self.stdout.flush()
355
- except BlockingIOError:
356
- if n_retries >= 5:
357
- raise
358
- n_retries += 1
359
- await asyncio.sleep(0.1)
360
- else:
361
- # `stdout` isn't always buffered (e.g. %%capture in Jupyter notebooks redirects it to
362
- # io.StringIO).
363
- self.stdout.write(log.data)
364
- self.stdout.flush()
354
+ stream = self._line_buffers.get(log.file_descriptor)
355
+ if stream is None:
356
+ stream = LineBufferedOutput(functools.partial(self._print_log, log.file_descriptor))
357
+ self._line_buffers[log.file_descriptor] = stream
358
+ stream.write(log.data)
365
359
 
366
360
  def flush_lines(self):
367
361
  for stream in self._line_buffers.values():
@@ -370,12 +364,123 @@ class OutputManager:
370
364
  @contextlib.contextmanager
371
365
  def show_status_spinner(self):
372
366
  self._status_spinner_live = self.make_live(self._status_spinner)
373
- with self.ctx_if_visible(self._status_spinner_live):
367
+ with self._status_spinner_live:
374
368
  yield
375
369
 
376
- def hide_status_spinner(self):
377
- if self._status_spinner_live:
378
- self._status_spinner_live.stop()
370
+
371
+ class ProgressHandler:
372
+ live: Live
373
+ _type: str
374
+ _spinner: Spinner
375
+ _overall_progress: Progress
376
+ _download_progress: Progress
377
+ _overall_progress_task_id: TaskID
378
+ _total_tasks: int
379
+ _completed_tasks: int
380
+
381
+ def __init__(self, type: str, console: Console):
382
+ self._type = type
383
+
384
+ if self._type == "download":
385
+ title = "Downloading file(s) to local..."
386
+ elif self._type == "upload":
387
+ title = "Uploading file(s) to volume..."
388
+ else:
389
+ raise NotImplementedError(f"Progress handler of type: `{type}` not yet implemented")
390
+
391
+ self._spinner = OutputManager.step_progress(title)
392
+
393
+ self._overall_progress = Progress(
394
+ TextColumn(f"[bold white]{title}", justify="right"),
395
+ TimeElapsedColumn(),
396
+ BarColumn(bar_width=None),
397
+ TextColumn("[bold white]{task.description}"),
398
+ transient=True,
399
+ console=console,
400
+ )
401
+ self._download_progress = Progress(
402
+ TextColumn("[bold white]{task.fields[path]}", justify="right"),
403
+ BarColumn(bar_width=None),
404
+ "[progress.percentage]{task.percentage:>3.1f}%",
405
+ "•",
406
+ DownloadColumn(),
407
+ "•",
408
+ TransferSpeedColumn(),
409
+ "•",
410
+ TimeRemainingColumn(),
411
+ transient=True,
412
+ console=console,
413
+ )
414
+
415
+ self.live = Live(
416
+ Group(self._spinner, self._overall_progress, self._download_progress), transient=True, refresh_per_second=4
417
+ )
418
+
419
+ self._overall_progress_task_id = self._overall_progress.add_task(".", start=True)
420
+ self._total_tasks = 0
421
+ self._completed_tasks = 0
422
+
423
+ def _add_sub_task(self, name: str, size: float) -> TaskID:
424
+ task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
425
+ self._total_tasks += 1
426
+ self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
427
+ return task_id
428
+
429
+ def _reset_sub_task(self, task_id: TaskID):
430
+ self._download_progress.reset(task_id)
431
+
432
+ def _complete_progress(self):
433
+ # TODO: we could probably implement some callback progression from the server
434
+ # to get progress reports for the post processing too
435
+ # so we don't have to just spin here
436
+ self._overall_progress.remove_task(self._overall_progress_task_id)
437
+ self._spinner.update(text="Post processing...")
438
+
439
+ def _complete_sub_task(self, task_id: TaskID):
440
+ self._completed_tasks += 1
441
+ self._download_progress.remove_task(task_id)
442
+ self._overall_progress.update(
443
+ self._overall_progress_task_id,
444
+ advance=1,
445
+ description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
446
+ )
447
+
448
+ def _advance_sub_task(self, task_id: TaskID, advance: float):
449
+ self._download_progress.update(task_id, advance=advance)
450
+
451
+ def progress(
452
+ self,
453
+ task_id: TaskID | None = None,
454
+ advance: float | None = None,
455
+ name: str | None = None,
456
+ size: float | None = None,
457
+ reset: bool | None = False,
458
+ complete: bool | None = False,
459
+ ) -> TaskID | None:
460
+ try:
461
+ if task_id is not None:
462
+ if reset:
463
+ return self._reset_sub_task(task_id)
464
+ elif complete:
465
+ return self._complete_sub_task(task_id)
466
+ elif advance is not None:
467
+ return self._advance_sub_task(task_id, advance)
468
+ elif name is not None and size is not None:
469
+ return self._add_sub_task(name, size)
470
+ elif complete:
471
+ return self._complete_progress()
472
+ except Exception as exc:
473
+ # Liberal exception handling to avoid crashing downloads and uploads.
474
+ logger.error(f"failed progress update: {exc}")
475
+ raise NotImplementedError(
476
+ "Unknown action to take with args: "
477
+ + f"name={name} "
478
+ + f"size={size} "
479
+ + f"task_id={task_id} "
480
+ + f"advance={advance} "
481
+ + f"reset={reset} "
482
+ + f"complete={complete} "
483
+ )
379
484
 
380
485
 
381
486
  async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: asyncio.Event):
@@ -396,10 +501,42 @@ async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: as
396
501
  await finish_event.wait()
397
502
 
398
503
 
399
- async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputManager):
504
+ async def put_pty_content(log: api_pb2.TaskLogs, stdout):
505
+ if hasattr(stdout, "buffer"):
506
+ # If we're not showing progress, there's no need to buffer lines,
507
+ # because the progress spinner can't interfere with output.
508
+
509
+ data = log.data.encode("utf-8")
510
+ written = 0
511
+ n_retries = 0
512
+ while written < len(data):
513
+ try:
514
+ written += stdout.buffer.write(data[written:])
515
+ stdout.flush()
516
+ except BlockingIOError:
517
+ if n_retries >= 5:
518
+ raise
519
+ n_retries += 1
520
+ await asyncio.sleep(0.1)
521
+ else:
522
+ # `stdout` isn't always buffered (e.g. %%capture in Jupyter notebooks redirects it to
523
+ # io.StringIO).
524
+ stdout.write(log.data)
525
+ stdout.flush()
526
+
527
+
528
+ async def get_app_logs_loop(
529
+ client: _Client,
530
+ output_mgr: OutputManager,
531
+ app_id: str | None = None,
532
+ task_id: str | None = None,
533
+ app_logs_url: str | None = None,
534
+ ):
400
535
  last_log_batch_entry_id = ""
401
- pty_shell_finish_event: Optional[asyncio.Event] = None
402
- pty_shell_task_id: Optional[str] = None
536
+
537
+ pty_shell_stdout = None
538
+ pty_shell_finish_event: asyncio.Event | None = None
539
+ pty_shell_task_id: str | None = None
403
540
 
404
541
  async def stop_pty_shell():
405
542
  nonlocal pty_shell_finish_event
@@ -428,18 +565,23 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
428
565
  else: # Ensure forward-compatible with new types.
429
566
  logger.debug(f"Received unrecognized progress type: {log.task_progress.progress_type}")
430
567
  elif log.data:
431
- await output_mgr.put_log_content(log)
568
+ if pty_shell_finish_event:
569
+ await put_pty_content(log, pty_shell_stdout)
570
+ else:
571
+ await output_mgr.put_log_content(log)
432
572
 
433
573
  async def _get_logs():
434
- nonlocal last_log_batch_entry_id, pty_shell_finish_event, pty_shell_task_id
574
+ nonlocal last_log_batch_entry_id
575
+ nonlocal pty_shell_stdout, pty_shell_finish_event, pty_shell_task_id
435
576
 
436
577
  request = api_pb2.AppGetLogsRequest(
437
- app_id=app_id,
578
+ app_id=app_id or "",
579
+ task_id=task_id or "",
438
580
  timeout=55,
439
581
  last_entry_id=last_log_batch_entry_id,
440
582
  )
441
583
  log_batch: api_pb2.TaskLogsBatch
442
- async for log_batch in unary_stream(client.stub.AppGetLogs, request):
584
+ async for log_batch in client.stub.AppGetLogs.unary_stream(request):
443
585
  if log_batch.entry_id:
444
586
  # log_batch entry_id is empty for fd="server" messages from AppGetLogs
445
587
  last_log_batch_entry_id = log_batch.entry_id
@@ -456,14 +598,15 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
456
598
  # statically and dynamically built images.
457
599
  pass
458
600
  elif log_batch.pty_exec_id:
601
+ # This corresponds to the `modal run -i` use case where a breakpoint
602
+ # triggers and the task drops into an interactive PTY mode
459
603
  if pty_shell_finish_event:
460
604
  print("ERROR: concurrent PTY shells are not supported.")
461
605
  else:
462
- output_mgr.flush_lines()
463
- output_mgr.hide_status_spinner()
464
- output_mgr._visible_progress = False
606
+ pty_shell_stdout = output_mgr._stdout
465
607
  pty_shell_finish_event = asyncio.Event()
466
608
  pty_shell_task_id = log_batch.task_id
609
+ output_mgr.disable()
467
610
  asyncio.create_task(stream_pty_shell_input(client, log_batch.pty_exec_id, pty_shell_finish_event))
468
611
  else:
469
612
  for log in log_batch.items:
@@ -477,14 +620,7 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
477
620
  while True:
478
621
  try:
479
622
  await _get_logs()
480
- except asyncio.CancelledError:
481
- # TODO: this should come from the backend maybe
482
- app_logs_url = f"https://modal.com/logs/{app_id}"
483
- output_mgr.print_if_visible(
484
- f"[red]Timed out waiting for logs. [grey70]View logs at [underline]{app_logs_url}[/underline] for remaining output.[/grey70]"
485
- )
486
- raise
487
- except (GRPCError, StreamTerminatedError) as exc:
623
+ except (GRPCError, StreamTerminatedError, socket.gaierror, AttributeError) as exc:
488
624
  if isinstance(exc, GRPCError):
489
625
  if exc.status in RETRYABLE_GRPC_STATUS_CODES:
490
626
  # Try again if we had a temporary connection drop,
@@ -494,10 +630,18 @@ async def get_app_logs_loop(app_id: str, client: _Client, output_mgr: OutputMana
494
630
  elif isinstance(exc, StreamTerminatedError):
495
631
  logger.debug("Stream closed. Retrying ...")
496
632
  continue
633
+ elif isinstance(exc, socket.gaierror):
634
+ logger.debug("Lost connection. Retrying ...")
635
+ continue
636
+ elif isinstance(exc, AttributeError):
637
+ if "_write_appdata" in str(exc):
638
+ # Happens after losing connection
639
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
640
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
641
+ # TODO: update to newer version (>=0.4.8) once stable
642
+ logger.debug("Lost connection. Retrying ...")
643
+ continue
497
644
  raise
498
- except Exception as exc:
499
- logger.exception(f"Failed to fetch logs: {exc}")
500
- await asyncio.sleep(1)
501
645
 
502
646
  if last_log_batch_entry_id is None:
503
647
  break
modal/_pty.py CHANGED
@@ -2,12 +2,12 @@
2
2
  import contextlib
3
3
  import os
4
4
  import sys
5
- from typing import Optional, Tuple
5
+ from typing import Optional
6
6
 
7
7
  from modal_proto import api_pb2
8
8
 
9
9
 
10
- def get_winsz(fd) -> Tuple[Optional[int], Optional[int]]:
10
+ def get_winsz(fd) -> tuple[Optional[int], Optional[int]]:
11
11
  try:
12
12
  import fcntl
13
13
  import struct