modal 1.1.5.dev43__py3-none-any.whl → 1.1.5.dev45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. Click here for more details.

@@ -32,14 +32,14 @@ from modal._partial_function import (
32
32
  _PartialFunctionFlags,
33
33
  )
34
34
  from modal._serialization import deserialize, deserialize_params
35
- from modal._utils.async_utils import TaskContext, synchronizer
35
+ from modal._utils.async_utils import TaskContext, aclosing, synchronizer
36
36
  from modal._utils.function_utils import (
37
37
  callable_has_non_self_params,
38
38
  )
39
39
  from modal.app import App, _App
40
40
  from modal.client import Client, _Client
41
41
  from modal.config import logger
42
- from modal.exception import ExecutionError, InputCancellation, InvalidError
42
+ from modal.exception import ExecutionError, InputCancellation
43
43
  from modal.running_app import RunningApp, running_app_from_layout
44
44
  from modal_proto import api_pb2
45
45
 
@@ -184,17 +184,13 @@ def call_function(
184
184
  batch_wait_ms: int,
185
185
  ):
186
186
  async def run_input_async(io_context: IOContext) -> None:
187
- started_at = time.time()
188
187
  reset_context = execution_context._set_current_context_ids(
189
188
  io_context.input_ids, io_context.function_call_ids, io_context.attempt_tokens
190
189
  )
190
+ started_at = time.time()
191
191
  async with container_io_manager.handle_input_exception.aio(io_context, started_at):
192
- res = io_context.call_finalized_function()
193
192
  # TODO(erikbern): any exception below shouldn't be considered a user exception
194
193
  if io_context.finalized_function.is_generator:
195
- if not inspect.isasyncgen(res):
196
- raise InvalidError(f"Async generator function returned value of type {type(res)}")
197
-
198
194
  # Send up to this many outputs at a time.
199
195
  current_function_call_id = execution_context.current_function_call_id()
200
196
  assert current_function_call_id is not None # Set above.
@@ -204,33 +200,24 @@ def call_function(
204
200
  async with container_io_manager.generator_output_sender(
205
201
  current_function_call_id,
206
202
  current_attempt_token,
207
- io_context.finalized_function.data_format,
203
+ io_context._generator_output_format(),
208
204
  generator_queue,
209
205
  ):
210
206
  item_count = 0
211
- async for value in res:
212
- await container_io_manager._queue_put.aio(generator_queue, value)
213
- item_count += 1
207
+ async with aclosing(io_context.call_generator_async()) as gen:
208
+ async for value in gen:
209
+ await container_io_manager._queue_put.aio(generator_queue, value)
210
+ item_count += 1
214
211
 
215
- message = api_pb2.GeneratorDone(items_total=item_count)
216
- await container_io_manager.push_outputs.aio(
217
- io_context,
218
- started_at,
219
- message,
220
- api_pb2.DATA_FORMAT_GENERATOR_DONE,
212
+ await container_io_manager._send_outputs.aio(
213
+ started_at, io_context.output_items_generator_done(started_at, item_count)
221
214
  )
222
215
  else:
223
- if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
224
- raise InvalidError(
225
- f"Async (non-generator) function returned value of type {type(res)}"
226
- " You might need to use @app.function(..., is_generator=True)."
227
- )
228
- value = await res
216
+ value = await io_context.call_function_async()
229
217
  await container_io_manager.push_outputs.aio(
230
218
  io_context,
231
219
  started_at,
232
220
  value,
233
- io_context.finalized_function.data_format,
234
221
  )
235
222
  reset_context()
236
223
 
@@ -240,13 +227,9 @@ def call_function(
240
227
  io_context.input_ids, io_context.function_call_ids, io_context.attempt_tokens
241
228
  )
242
229
  with container_io_manager.handle_input_exception(io_context, started_at):
243
- res = io_context.call_finalized_function()
244
-
245
230
  # TODO(erikbern): any exception below shouldn't be considered a user exception
246
231
  if io_context.finalized_function.is_generator:
247
- if not inspect.isgenerator(res):
248
- raise InvalidError(f"Generator function returned value of type {type(res)}")
249
-
232
+ gen = io_context.call_generator_sync()
250
233
  # Send up to this many outputs at a time.
251
234
  current_function_call_id = execution_context.current_function_call_id()
252
235
  assert current_function_call_id is not None # Set above.
@@ -256,25 +239,20 @@ def call_function(
256
239
  with container_io_manager.generator_output_sender(
257
240
  current_function_call_id,
258
241
  current_attempt_token,
259
- io_context.finalized_function.data_format,
242
+ io_context._generator_output_format(),
260
243
  generator_queue,
261
244
  ):
262
245
  item_count = 0
263
- for value in res:
246
+ for value in gen:
264
247
  container_io_manager._queue_put(generator_queue, value)
265
248
  item_count += 1
266
249
 
267
- message = api_pb2.GeneratorDone(items_total=item_count)
268
- container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
269
- else:
270
- if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
271
- raise InvalidError(
272
- f"Sync (non-generator) function return value of type {type(res)}."
273
- " You might need to use @app.function(..., is_generator=True)."
274
- )
275
- container_io_manager.push_outputs(
276
- io_context, started_at, res, io_context.finalized_function.data_format
250
+ container_io_manager._send_outputs(
251
+ started_at, io_context.output_items_generator_done(started_at, item_count)
277
252
  )
253
+ else:
254
+ values = io_context.call_function_sync()
255
+ container_io_manager.push_outputs(io_context, started_at, values)
278
256
  reset_context()
279
257
 
280
258
  if container_io_manager.input_concurrency_enabled:
modal/_functions.py CHANGED
@@ -150,8 +150,7 @@ class _Invocation:
150
150
  args,
151
151
  kwargs,
152
152
  stub,
153
- max_object_size_bytes=function._max_object_size_bytes,
154
- method_name=function._use_method_name,
153
+ function=function,
155
154
  function_call_invocation_type=function_call_invocation_type,
156
155
  )
157
156
 
@@ -439,8 +438,7 @@ class _InputPlaneInvocation:
439
438
  args,
440
439
  kwargs,
441
440
  control_plane_stub,
442
- max_object_size_bytes=function._max_object_size_bytes,
443
- method_name=function._use_method_name,
441
+ function=function,
444
442
  )
445
443
 
446
444
  request = api_pb2.AttemptStartRequest(
@@ -698,6 +696,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
698
696
  experimental_options: Optional[dict[str, str]] = None,
699
697
  _experimental_proxy_ip: Optional[str] = None,
700
698
  _experimental_custom_scaling_factor: Optional[float] = None,
699
+ restrict_output: bool = False,
701
700
  ) -> "_Function":
702
701
  """mdmd:hidden
703
702
 
@@ -834,17 +833,23 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
834
833
  is_web_endpoint=is_web_endpoint,
835
834
  ignore_first_argument=True,
836
835
  )
836
+ if is_web_endpoint:
837
+ method_input_formats = [api_pb2.DATA_FORMAT_ASGI]
838
+ method_output_formats = [api_pb2.DATA_FORMAT_ASGI]
839
+ else:
840
+ method_input_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
841
+ if restrict_output:
842
+ method_output_formats = [api_pb2.DATA_FORMAT_CBOR]
843
+ else:
844
+ method_output_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
845
+
837
846
  method_definition = api_pb2.MethodDefinition(
838
847
  webhook_config=partial_function.params.webhook_config,
839
848
  function_type=function_type,
840
849
  function_name=function_name,
841
850
  function_schema=method_schema,
842
- supported_input_formats=[api_pb2.DATA_FORMAT_ASGI]
843
- if is_web_endpoint
844
- else [api_pb2.DATA_FORMAT_PICKLE],
845
- supported_output_formats=[api_pb2.DATA_FORMAT_ASGI]
846
- if is_web_endpoint
847
- else [api_pb2.DATA_FORMAT_PICKLE],
851
+ supported_input_formats=method_input_formats,
852
+ supported_output_formats=method_output_formats,
848
853
  )
849
854
  method_definitions[method_name] = method_definition
850
855
 
@@ -869,16 +874,18 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
869
874
  return deps
870
875
 
871
876
  if info.is_service_class():
872
- # classes don't have data formats themselves - methods do
877
+ # classes don't have data formats themselves - input/output formats are set per method above
873
878
  supported_input_formats = []
874
879
  supported_output_formats = []
875
880
  elif webhook_config is not None:
876
881
  supported_input_formats = [api_pb2.DATA_FORMAT_ASGI]
877
882
  supported_output_formats = [api_pb2.DATA_FORMAT_ASGI]
878
883
  else:
879
- # TODO: add CBOR support
880
- supported_input_formats = [api_pb2.DATA_FORMAT_PICKLE]
881
- supported_output_formats = [api_pb2.DATA_FORMAT_PICKLE]
884
+ supported_input_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
885
+ if restrict_output:
886
+ supported_output_formats = [api_pb2.DATA_FORMAT_CBOR]
887
+ else:
888
+ supported_output_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
882
889
 
883
890
  async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
884
891
  assert resolver.client and resolver.client.stub