modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. Click here for more details.

Files changed (147) hide show
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,6 @@ from modal_proto.modal_api_grpc import ModalClientModal
27
27
 
28
28
  from ..exception import ExecutionError
29
29
  from .async_utils import TaskContext, retry
30
- from .grpc_utils import retry_transient_errors
31
30
  from .hash_utils import UploadHashes, get_upload_hashes
32
31
  from .http_utils import ClientSessionRegistry
33
32
  from .logger import logger
@@ -85,7 +84,7 @@ async def _upload_to_s3_url(
85
84
  ) as resp:
86
85
  # S3 signal to slow down request rate.
87
86
  if resp.status == 503:
88
- logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
87
+ logger.debug("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
89
88
  await asyncio.sleep(1)
90
89
 
91
90
  if resp.status != 200:
@@ -188,16 +187,10 @@ def get_content_length(data: BinaryIO) -> int:
188
187
  return content_length - pos
189
188
 
190
189
 
191
- async def _measure_endpoint_latency(item: str) -> int:
192
- latency_ms = 0
193
- t0 = time.monotonic_ns()
194
- async with ClientSessionRegistry.get_session().head(item) as _:
195
- latency_ms = (time.monotonic_ns() - t0) // 1_000_000
196
- return latency_ms
197
-
198
-
199
- async def _blob_upload_with_fallback(items, blob_ids: list[str], callback) -> tuple[str, bool, int]:
200
- r2_latency_ms = 0
190
+ async def _blob_upload_with_fallback(
191
+ items, blob_ids: list[str], callback, content_length: int
192
+ ) -> tuple[str, bool, int]:
193
+ r2_throughput_bytes_s = 0
201
194
  r2_failed = False
202
195
  for idx, (item, blob_id) in enumerate(zip(items, blob_ids)):
203
196
  # We want to default to R2 95% of the time and S3 5% of the time.
@@ -206,14 +199,13 @@ async def _blob_upload_with_fallback(items, blob_ids: list[str], callback) -> tu
206
199
  continue
207
200
  try:
208
201
  if blob_id.endswith(":r2"):
209
- # measure the time it takes to contact the bucket endpoint
210
- r2_latency_ms, _ = await asyncio.gather(
211
- _measure_endpoint_latency(item),
212
- callback(item),
213
- )
202
+ t0 = time.monotonic_ns()
203
+ await callback(item)
204
+ dt_ns = time.monotonic_ns() - t0
205
+ r2_throughput_bytes_s = (content_length * 1_000_000_000) // max(dt_ns, 1)
214
206
  else:
215
207
  await callback(item)
216
- return blob_id, r2_failed, r2_latency_ms
208
+ return blob_id, r2_failed, r2_throughput_bytes_s
217
209
  except Exception as _:
218
210
  if blob_id.endswith(":r2"):
219
211
  r2_failed = True
@@ -236,7 +228,7 @@ async def _blob_upload(
236
228
  content_sha256_base64=upload_hashes.sha256_base64,
237
229
  content_length=content_length,
238
230
  )
239
- resp = await retry_transient_errors(stub.BlobCreate, req)
231
+ resp = await stub.BlobCreate(req)
240
232
 
241
233
  if resp.WhichOneof("upload_types_oneof") == "multiparts":
242
234
 
@@ -251,10 +243,11 @@ async def _blob_upload(
251
243
  progress_report_cb=progress_report_cb,
252
244
  )
253
245
 
254
- blob_id, r2_failed, r2_latency_ms = await _blob_upload_with_fallback(
246
+ blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload_with_fallback(
255
247
  resp.multiparts.items,
256
248
  resp.blob_ids,
257
249
  upload_multipart_upload,
250
+ content_length=content_length,
258
251
  )
259
252
  else:
260
253
  from .bytes_io_segment_payload import BytesIOSegmentPayload
@@ -271,16 +264,17 @@ async def _blob_upload(
271
264
  content_md5_b64=upload_hashes.md5_base64,
272
265
  )
273
266
 
274
- blob_id, r2_failed, r2_latency_ms = await _blob_upload_with_fallback(
267
+ blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload_with_fallback(
275
268
  resp.upload_urls.items,
276
269
  resp.blob_ids,
277
270
  upload_to_s3_url,
271
+ content_length=content_length,
278
272
  )
279
273
 
280
274
  if progress_report_cb:
281
275
  progress_report_cb(complete=True)
282
276
 
283
- return blob_id, r2_failed, r2_latency_ms
277
+ return blob_id, r2_failed, r2_throughput_bytes_s
284
278
 
285
279
 
286
280
  async def blob_upload_with_r2_failure_info(payload: bytes, stub: ModalClientModal) -> tuple[str, bool, int]:
@@ -288,16 +282,16 @@ async def blob_upload_with_r2_failure_info(payload: bytes, stub: ModalClientModa
288
282
  logger.debug(f"Uploading large blob of size {size_mib:.2f} MiB")
289
283
  t0 = time.time()
290
284
  if isinstance(payload, str):
291
- logger.warning("Blob uploading string, not bytes - auto-encoding as utf8")
285
+ logger.debug("Blob uploading string, not bytes - auto-encoding as utf8")
292
286
  payload = payload.encode("utf8")
293
287
  upload_hashes = get_upload_hashes(payload)
294
- blob_id, r2_failed, r2_latency_ms = await _blob_upload(upload_hashes, payload, stub)
288
+ blob_id, r2_failed, r2_throughput_bytes_s = await _blob_upload(upload_hashes, payload, stub)
295
289
  dur_s = max(time.time() - t0, 0.001) # avoid division by zero
296
290
  throughput_mib_s = (size_mib) / dur_s
297
291
  logger.debug(
298
292
  f"Uploaded large blob of size {size_mib:.2f} MiB ({throughput_mib_s:.2f} MiB/s, total {dur_s:.2f}s). {blob_id}"
299
293
  )
300
- return blob_id, r2_failed, r2_latency_ms
294
+ return blob_id, r2_failed, r2_throughput_bytes_s
301
295
 
302
296
 
303
297
  async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
@@ -305,6 +299,10 @@ async def blob_upload(payload: bytes, stub: ModalClientModal) -> str:
305
299
  return blob_id
306
300
 
307
301
 
302
+ async def format_blob_data(data: bytes, api_stub: ModalClientModal) -> dict[str, Any]:
303
+ return {"data_blob_id": await blob_upload(data, api_stub)} if len(data) > MAX_OBJECT_SIZE_BYTES else {"data": data}
304
+
305
+
308
306
  async def blob_upload_file(
309
307
  file_obj: BinaryIO,
310
308
  stub: ModalClientModal,
@@ -322,7 +320,7 @@ async def _download_from_url(download_url: str) -> bytes:
322
320
  async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
323
321
  # S3 signal to slow down request rate.
324
322
  if s3_resp.status == 503:
325
- logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
323
+ logger.debug("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
326
324
  await asyncio.sleep(1)
327
325
 
328
326
  if s3_resp.status != 200:
@@ -336,7 +334,7 @@ async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
336
334
  logger.debug(f"Downloading large blob {blob_id}")
337
335
  t0 = time.time()
338
336
  req = api_pb2.BlobGetRequest(blob_id=blob_id)
339
- resp = await retry_transient_errors(stub.BlobGet, req)
337
+ resp = await stub.BlobGet(req)
340
338
  data = await _download_from_url(resp.download_url)
341
339
  size_mib = len(data) / 1024 / 1024
342
340
  dur_s = max(time.time() - t0, 0.001) # avoid division by zero
@@ -349,12 +347,12 @@ async def blob_download(blob_id: str, stub: ModalClientModal) -> bytes:
349
347
 
350
348
  async def blob_iter(blob_id: str, stub: ModalClientModal) -> AsyncIterator[bytes]:
351
349
  req = api_pb2.BlobGetRequest(blob_id=blob_id)
352
- resp = await retry_transient_errors(stub.BlobGet, req)
350
+ resp = await stub.BlobGet(req)
353
351
  download_url = resp.download_url
354
352
  async with ClientSessionRegistry.get_session().get(download_url) as s3_resp:
355
353
  # S3 signal to slow down request rate.
356
354
  if s3_resp.status == 503:
357
- logger.warning("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
355
+ logger.debug("Received SlowDown signal from S3, sleeping for 1 second before retrying.")
358
356
  await asyncio.sleep(1)
359
357
 
360
358
  if s3_resp.status != 200:
@@ -449,14 +447,24 @@ def get_file_upload_spec_from_fileobj(fp: BinaryIO, mount_filename: PurePosixPat
449
447
  _FileUploadSource2 = Callable[[], ContextManager[BinaryIO]]
450
448
 
451
449
 
450
+ @dataclasses.dataclass
451
+ class FileUploadBlock:
452
+ # The start (byte offset, inclusive) of the block within the file
453
+ start: int
454
+ # The end (byte offset, exclusive) of the block, after having removed any trailing zeroes
455
+ end: int
456
+ # Raw (unencoded 32 byte) SHA256 sum of the block, not including trailing zeroes
457
+ contents_sha256: bytes
458
+
459
+
452
460
  @dataclasses.dataclass
453
461
  class FileUploadSpec2:
454
462
  source: _FileUploadSource2
455
463
  source_description: Union[str, Path]
456
464
 
457
465
  path: str
458
- # Raw (unencoded 32 byte) SHA256 sum per 8MiB file block
459
- blocks_sha256: list[bytes]
466
+ # 8MiB file blocks
467
+ blocks: list[FileUploadBlock]
460
468
  mode: int # file permission bits (last 12 bits of st_mode)
461
469
  size: int
462
470
 
@@ -527,53 +535,102 @@ class FileUploadSpec2:
527
535
  source_fp.seek(0, os.SEEK_END)
528
536
  size = source_fp.tell()
529
537
 
530
- blocks_sha256 = await hash_blocks_sha256(source, size, hash_semaphore)
538
+ blocks = await _gather_blocks(source, size, hash_semaphore)
531
539
 
532
540
  return FileUploadSpec2(
533
541
  source=source,
534
542
  source_description=source_description,
535
543
  path=mount_filename.as_posix(),
536
- blocks_sha256=blocks_sha256,
544
+ blocks=blocks,
537
545
  mode=mode & 0o7777,
538
546
  size=size,
539
547
  )
540
548
 
541
549
 
542
- async def hash_blocks_sha256(
550
+ async def _gather_blocks(
543
551
  source: _FileUploadSource2,
544
552
  size: int,
545
553
  hash_semaphore: asyncio.Semaphore,
546
- ) -> list[bytes]:
554
+ ) -> list[FileUploadBlock]:
547
555
  def ceildiv(a: int, b: int) -> int:
548
556
  return -(a // -b)
549
557
 
550
558
  num_blocks = ceildiv(size, BLOCK_SIZE)
551
559
 
552
- def blocking_hash_block_sha256(block_idx: int) -> bytes:
553
- sha256_hash = hashlib.sha256()
554
- block_start = block_idx * BLOCK_SIZE
560
+ async def gather_block(block_idx: int) -> FileUploadBlock:
561
+ async with hash_semaphore:
562
+ return await asyncio.to_thread(_gather_block, source, block_idx)
563
+
564
+ tasks = (gather_block(idx) for idx in range(num_blocks))
565
+ return await asyncio.gather(*tasks)
555
566
 
556
- with source() as block_fp:
557
- block_fp.seek(block_start)
558
567
 
559
- num_bytes_read = 0
560
- while num_bytes_read < BLOCK_SIZE:
561
- chunk = block_fp.read(BLOCK_SIZE - num_bytes_read)
568
+ def _gather_block(source: _FileUploadSource2, block_idx: int) -> FileUploadBlock:
569
+ start = block_idx * BLOCK_SIZE
570
+ end = _find_end_of_block(source, start, start + BLOCK_SIZE)
571
+ contents_sha256 = _hash_range_sha256(source, start, end)
572
+ return FileUploadBlock(start=start, end=end, contents_sha256=contents_sha256)
562
573
 
563
- if not chunk:
564
- break
565
574
 
566
- num_bytes_read += len(chunk)
567
- sha256_hash.update(chunk)
575
+ def _hash_range_sha256(source: _FileUploadSource2, start, end):
576
+ sha256_hash = hashlib.sha256()
577
+ range_size = end - start
568
578
 
569
- return sha256_hash.digest()
579
+ with source() as fp:
580
+ fp.seek(start)
581
+
582
+ num_bytes_read = 0
583
+ while num_bytes_read < range_size:
584
+ chunk = fp.read(range_size - num_bytes_read)
585
+
586
+ if not chunk:
587
+ break
588
+
589
+ num_bytes_read += len(chunk)
590
+ sha256_hash.update(chunk)
591
+
592
+ return sha256_hash.digest()
593
+
594
+
595
+ def _find_end_of_block(source: _FileUploadSource2, start: int, end: int) -> Optional[int]:
596
+ """Finds the appropriate end of a block, which is the index of the byte just past the last non-zero byte in the
597
+ block.
598
+
599
+ >>> _find_end_of_block(lambda: BytesIO(b"abc123\0\0\0"), 0, 1024)
600
+ 6
601
+ >>> _find_end_of_block(lambda: BytesIO(b"abc123\0\0\0"), 3, 1024)
602
+ 6
603
+ >>> _find_end_of_block(lambda: BytesIO(b"abc123\0\0\0"), 0, 3)
604
+ 4
605
+ >>> _find_end_of_block(lambda: BytesIO(b"abc123\0\0\0a"), 0, 9)
606
+ 6
607
+ >>> _find_end_of_block(lambda: BytesIO(b"\0\0\0"), 0, 3)
608
+ 0
609
+ >>> _find_end_of_block(lambda: BytesIO(b"\0\0\0\0\0\0"), 3, 6)
610
+ 3
611
+ >>> _find_end_of_block(lambda: BytesIO(b""), 0, 1024)
612
+ 0
613
+ """
614
+ size = end - start
615
+ new_end = start
570
616
 
571
- async def hash_block_sha256(block_idx: int) -> bytes:
572
- async with hash_semaphore:
573
- return await asyncio.to_thread(blocking_hash_block_sha256, block_idx)
617
+ with source() as block_fp:
618
+ block_fp.seek(start)
574
619
 
575
- tasks = (hash_block_sha256(idx) for idx in range(num_blocks))
576
- return await asyncio.gather(*tasks)
620
+ num_bytes_read = 0
621
+ while num_bytes_read < size:
622
+ chunk = block_fp.read(size - num_bytes_read)
623
+
624
+ if not chunk:
625
+ break
626
+
627
+ stripped_chunk = chunk.rstrip(b"\0")
628
+ if stripped_chunk:
629
+ new_end = start + num_bytes_read + len(stripped_chunk)
630
+
631
+ num_bytes_read += len(chunk)
632
+
633
+ return new_end
577
634
 
578
635
 
579
636
  def use_md5(url: str) -> bool:
@@ -1,8 +1,8 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
- import enum
4
3
  import inspect
5
4
  import os
5
+ import typing
6
6
  from collections.abc import AsyncGenerator
7
7
  from enum import Enum
8
8
  from pathlib import Path, PurePosixPath
@@ -18,7 +18,9 @@ from modal_proto.modal_api_grpc import ModalClientModal
18
18
  from .._serialization import (
19
19
  deserialize,
20
20
  deserialize_data_format,
21
+ get_preferred_payload_format,
21
22
  serialize,
23
+ serialize_data_format as _serialize_data_format,
22
24
  signature_to_parameter_specs,
23
25
  )
24
26
  from .._traceback import append_modal_tb
@@ -39,6 +41,9 @@ from .blob_utils import (
39
41
  )
40
42
  from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
41
43
 
44
+ if typing.TYPE_CHECKING:
45
+ import modal._functions
46
+
42
47
 
43
48
  class FunctionInfoType(Enum):
44
49
  PACKAGE = "package"
@@ -70,6 +75,10 @@ def is_global_object(object_qual_name: str):
70
75
  return "<locals>" not in object_qual_name.split(".")
71
76
 
72
77
 
78
+ def is_flash_object(experimental_options: Optional[dict[str, Any]]) -> bool:
79
+ return experimental_options.get("flash", False) if experimental_options else False
80
+
81
+
73
82
  def is_method_fn(object_qual_name: str):
74
83
  # methods have names like Cls.foo.
75
84
  if "<locals>" in object_qual_name:
@@ -386,9 +395,16 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
386
395
 
387
396
 
388
397
  async def _stream_function_call_data(
389
- client, stub, function_call_id: str, variant: Literal["data_in", "data_out"]
398
+ client,
399
+ stub,
400
+ function_call_id: Optional[str],
401
+ variant: Literal["data_in", "data_out"],
402
+ attempt_token: Optional[str] = None,
390
403
  ) -> AsyncGenerator[Any, None]:
391
404
  """Read from the `data_in` or `data_out` stream of a function call."""
405
+ if not function_call_id and not attempt_token:
406
+ raise ValueError("function_call_id or attempt_token is required to read from a data stream")
407
+
392
408
  if stub is None:
393
409
  stub = client.stub
394
410
 
@@ -406,7 +422,12 @@ async def _stream_function_call_data(
406
422
  raise ValueError(f"Invalid variant {variant}")
407
423
 
408
424
  while True:
409
- req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
425
+ req = api_pb2.FunctionCallGetDataRequest(
426
+ function_call_id=function_call_id,
427
+ last_index=last_index,
428
+ )
429
+ if attempt_token:
430
+ req.attempt_token = attempt_token # oneof clears function_call_id.
410
431
  try:
411
432
  async for chunk in stub_fn.unary_stream(req):
412
433
  if chunk.index <= last_index:
@@ -475,7 +496,12 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,
475
496
  elif result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
476
497
  raise InternalFailure(result.exception)
477
498
  elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
478
- if data:
499
+ if data and data_format in (api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_UNSPECIFIED):
500
+ # *Unspecified data format here but data present usually means that the exception
501
+ # was created by the server representing an exception that occurred during container
502
+ # startup (crash looping) that eventually got escalated to input failures.
503
+ # TaskResult doesn't specify data format, so these results don't have that metadata
504
+ # at the moment.
479
505
  try:
480
506
  exc = deserialize(data, client)
481
507
  except DeserializationError as deser_exc:
@@ -532,43 +558,52 @@ def should_upload(
532
558
  )
533
559
 
534
560
 
561
+ # This must be called against the client stub, not the input-plane stub.
535
562
  async def _create_input(
536
563
  args,
537
564
  kwargs,
538
565
  stub: ModalClientModal,
539
566
  *,
540
- max_object_size_bytes: int,
567
+ function: "modal._functions._Function",
541
568
  idx: Optional[int] = None,
542
- method_name: Optional[str] = None,
543
569
  function_call_invocation_type: Optional["api_pb2.FunctionCallInvocationType.ValueType"] = None,
544
570
  ) -> api_pb2.FunctionPutInputsItem:
545
571
  """Serialize function arguments and create a FunctionInput protobuf,
546
572
  uploading to blob storage if needed.
547
573
  """
574
+ method_name = function._use_method_name
575
+ max_object_size_bytes = function._max_object_size_bytes
576
+
548
577
  if idx is None:
549
578
  idx = 0
550
- if method_name is None:
551
- method_name = "" # proto compatible
552
579
 
553
- args_serialized = serialize((args, kwargs))
580
+ data_format = get_preferred_payload_format()
581
+ if not function._metadata:
582
+ raise ExecutionError("Attempted to call function that has not been hydrated with metadata")
583
+
584
+ supported_input_formats = function._metadata.supported_input_formats or [api_pb2.DATA_FORMAT_PICKLE]
585
+ if data_format not in supported_input_formats:
586
+ data_format = supported_input_formats[0]
587
+
588
+ args_serialized = _serialize_data_format((args, kwargs), data_format)
554
589
 
555
590
  if should_upload(len(args_serialized), max_object_size_bytes, function_call_invocation_type):
556
- args_blob_id, r2_failed, r2_latency_ms = await blob_upload_with_r2_failure_info(args_serialized, stub)
591
+ args_blob_id, r2_failed, r2_throughput_bytes_s = await blob_upload_with_r2_failure_info(args_serialized, stub)
557
592
  return api_pb2.FunctionPutInputsItem(
558
593
  input=api_pb2.FunctionInput(
559
594
  args_blob_id=args_blob_id,
560
- data_format=api_pb2.DATA_FORMAT_PICKLE,
595
+ data_format=data_format,
561
596
  method_name=method_name,
562
597
  ),
563
598
  idx=idx,
564
599
  r2_failed=r2_failed,
565
- r2_latency_ms=r2_latency_ms,
600
+ r2_throughput_bytes_s=r2_throughput_bytes_s,
566
601
  )
567
602
  else:
568
603
  return api_pb2.FunctionPutInputsItem(
569
604
  input=api_pb2.FunctionInput(
570
605
  args=args_serialized,
571
- data_format=api_pb2.DATA_FORMAT_PICKLE,
606
+ data_format=data_format,
572
607
  method_name=method_name,
573
608
  ),
574
609
  idx=idx,
@@ -610,14 +645,13 @@ class FunctionCreationStatus:
610
645
  if not self.response:
611
646
  self.status_row.finish(f"Unknown error when creating function {self.tag}")
612
647
 
613
- elif self.response.function.web_url:
648
+ elif web_url := self.response.handle_metadata.web_url:
614
649
  url_info = self.response.function.web_url_info
615
650
  requires_proxy_auth = self.response.function.webhook_config.requires_proxy_auth
616
651
  proxy_auth_suffix = " 🔑" if requires_proxy_auth else ""
617
652
  # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
618
653
  suffix = _get_suffix_from_web_url_info(url_info)
619
654
  # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
620
- web_url = self.response.handle_metadata.web_url
621
655
  for warning in self.response.server_warnings:
622
656
  self.status_row.warning(warning)
623
657
  self.status_row.finish(
@@ -660,30 +694,3 @@ class FunctionCreationStatus:
660
694
  f"Custom domain for {method_definition.function_name} => [magenta underline]"
661
695
  f"{custom_domain.url}[/magenta underline]"
662
696
  )
663
-
664
-
665
- class IncludeSourceMode(enum.Enum):
666
- INCLUDE_NOTHING = False # can only be set in source, can't be set in config
667
- INCLUDE_MAIN_PACKAGE = True # Default behavior
668
-
669
-
670
- def get_include_source_mode(function_or_app_specific) -> IncludeSourceMode:
671
- """Which "automount" behavior should a function use
672
-
673
- function_or_app_specific: explicit value given in the @function or @cls decorator, in an App constructor, or None
674
-
675
- If function_or_app_specific is specified, validate and return the IncludeSourceMode
676
- If function_or_app_specific is None, infer it from config
677
- """
678
- if function_or_app_specific is not None:
679
- if not isinstance(function_or_app_specific, bool):
680
- raise ValueError(
681
- f"Invalid `include_source` value: {function_or_app_specific}. Use one of:\n"
682
- f"True - include function's package source\n"
683
- f"False - include no Python source (module expected to be present in Image)\n"
684
- )
685
-
686
- # explicitly set in app/function
687
- return IncludeSourceMode(function_or_app_specific)
688
-
689
- return IncludeSourceMode.INCLUDE_MAIN_PACKAGE