fal 0.14.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; the original page linked to more details, but that link was not preserved here.

fal/api.py CHANGED
@@ -23,22 +23,25 @@ from typing import (
23
23
  overload,
24
24
  )
25
25
 
26
- import dill
27
- import dill.detect
26
+ import cloudpickle
28
27
  import grpc
29
28
  import isolate
30
29
  import tblib
30
+ import uvicorn
31
31
  import yaml
32
32
  from fastapi import FastAPI
33
+ from fastapi import __version__ as fastapi_version
33
34
  from isolate.backends.common import active_python
34
35
  from isolate.backends.settings import DEFAULT_SETTINGS
35
36
  from isolate.connections import PythonIPC
36
37
  from packaging.requirements import Requirement
37
38
  from packaging.utils import canonicalize_name
39
+ from pydantic import __version__ as pydantic_version
38
40
  from typing_extensions import Concatenate, ParamSpec
39
41
 
40
42
  import fal.flags as flags
41
- from fal._serialization import add_serialization_listeners_for, patch_dill, patch_pickle
43
+ from fal._serialization import include_modules_from, patch_pickle
44
+ from fal.exceptions import FalServerlessException
42
45
  from fal.logging.isolate import IsolateLogPrinter
43
46
  from fal.sdk import (
44
47
  FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
@@ -52,24 +55,28 @@ from fal.sdk import (
52
55
  get_agent_credentials,
53
56
  get_default_credentials,
54
57
  )
55
- from fal.toolkit import mainify
56
58
 
57
59
  ArgsT = ParamSpec("ArgsT")
58
- ReturnT = TypeVar("ReturnT", covariant=True)
60
+ ReturnT = TypeVar("ReturnT", covariant=True) # noqa: PLC0105
59
61
 
60
62
  BasicConfig = Dict[str, Any]
61
63
  _UNSET = object()
62
64
 
63
- SERVE_REQUIREMENTS = ["fastapi==0.99.1", "uvicorn"]
65
+ SERVE_REQUIREMENTS = [
66
+ f"fastapi=={fastapi_version}",
67
+ f"pydantic=={pydantic_version}",
68
+ "uvicorn",
69
+ "starlette_exporter",
70
+ ]
64
71
 
65
72
 
66
73
  @dataclass
67
- class FalServerlessError(Exception):
74
+ class FalServerlessError(FalServerlessException):
68
75
  message: str
69
76
 
70
77
 
71
78
  @dataclass
72
- class InternalFalServerlessError(Exception):
79
+ class InternalFalServerlessError(FalServerlessException):
73
80
  message: str
74
81
 
75
82
 
@@ -106,8 +113,8 @@ class Host(Generic[ArgsT, ReturnT]):
106
113
  environment options."""
107
114
 
108
115
  options = Options()
109
- for key, value in config.items():
110
- key, value = cls.parse_key(key, value)
116
+ for item in config.items():
117
+ key, value = cls.parse_key(*item)
111
118
  if key in cls._SUPPORTED_KEYS:
112
119
  options.host[key] = value
113
120
  elif key in cls._GATEWAY_KEYS:
@@ -177,24 +184,12 @@ def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
177
184
  return wrapper
178
185
 
179
186
 
180
- @mainify
181
- class UserFunctionException(Exception):
187
+ class UserFunctionException(FalServerlessException):
182
188
  pass
183
189
 
184
190
 
185
- def match_class(obj, cls):
186
- # NOTE: Can't use isinstance because we are not using dill's byref setting when
187
- # loading/dumping objects in RPC, which means that our exceptions from remote
188
- # server are created by value and are actually a separate class that only looks
189
- # like original one.
190
- #
191
- # See https://github.com/fal-ai/fal/issues/142
192
- return type(obj).__name__ == cls.__name__
193
-
194
-
195
191
  def _prepare_partial_func(
196
192
  func: Callable[ArgsT, ReturnT],
197
- patch_func: Callable[[], None],
198
193
  *args: ArgsT.args,
199
194
  **kwargs: ArgsT.kwargs,
200
195
  ) -> Callable[ArgsT, ReturnT]:
@@ -204,6 +199,8 @@ def _prepare_partial_func(
204
199
  def wrapper(*remote_args: ArgsT.args, **remote_kwargs: ArgsT.kwargs) -> ReturnT:
205
200
  try:
206
201
  result = func(*remote_args, *args, **remote_kwargs, **kwargs)
202
+ except FalServerlessException:
203
+ raise
207
204
  except Exception as exc:
208
205
  tb = exc.__traceback__
209
206
  if tb is not None and tb.tb_next is not None:
@@ -214,37 +211,22 @@ def _prepare_partial_func(
214
211
  ) from exc.with_traceback(tb)
215
212
  finally:
216
213
  with suppress(Exception):
217
- patch_func()
214
+ patch_pickle()
218
215
  return result
219
216
 
220
217
  return wrapper
221
218
 
222
219
 
223
- def _prepare_local_partial_func(
224
- func: Callable[ArgsT, ReturnT],
225
- *args: ArgsT.args,
226
- **kwargs: ArgsT.kwargs,
227
- ) -> Callable[ArgsT, ReturnT]:
228
-
229
- return _prepare_partial_func(func, patch_pickle, *args, **kwargs)
230
-
231
-
232
- def _prepare_remote_partial_func(
233
- func: Callable[ArgsT, ReturnT],
234
- *args: ArgsT.args,
235
- **kwargs: ArgsT.kwargs,
236
- ) -> Callable[ArgsT, ReturnT]:
237
-
238
- return _prepare_partial_func(func, patch_dill, *args, **kwargs)
239
-
240
-
241
220
  @dataclass
242
221
  class LocalHost(Host):
243
222
  # The environment which provides the default set of
244
223
  # packages for isolate agent to run.
245
224
  _AGENT_ENVIRONMENT = isolate.prepare_environment(
246
225
  "virtualenv",
247
- requirements=[f"dill=={dill.__version__}", f"tblib=={tblib.__version__}"],
226
+ requirements=[
227
+ f"cloudpickle=={cloudpickle.__version__}",
228
+ f"tblib=={tblib.__version__}",
229
+ ],
248
230
  )
249
231
  _log_printer = IsolateLogPrinter(debug=flags.DEBUG)
250
232
 
@@ -255,7 +237,11 @@ class LocalHost(Host):
255
237
  args: tuple[Any, ...],
256
238
  kwargs: dict[str, Any],
257
239
  ) -> ReturnT:
258
- settings = replace(DEFAULT_SETTINGS, serialization_method="dill", log_hook=self._log_printer.print)
240
+ settings = replace(
241
+ DEFAULT_SETTINGS,
242
+ serialization_method="cloudpickle",
243
+ log_hook=self._log_printer.print,
244
+ )
259
245
  environment = isolate.prepare_environment(
260
246
  **options.environment,
261
247
  context=settings,
@@ -265,7 +251,7 @@ class LocalHost(Host):
265
251
  environment.create(),
266
252
  extra_inheritance_paths=[self._AGENT_ENVIRONMENT.create()],
267
253
  ) as connection:
268
- executable = _prepare_local_partial_func(func, *args, **kwargs)
254
+ executable = _prepare_partial_func(func, *args, **kwargs)
269
255
  return connection.run(executable)
270
256
 
271
257
 
@@ -311,6 +297,8 @@ def _handle_grpc_error():
311
297
  def find_missing_dependencies(
312
298
  func: Callable, env: dict
313
299
  ) -> Iterator[tuple[str, list[str]]]:
300
+ import dill
301
+
314
302
  if env["kind"] != "virtualenv":
315
303
  return
316
304
 
@@ -422,7 +410,7 @@ class FalServerlessHost(Host):
422
410
  min_concurrency=min_concurrency,
423
411
  )
424
412
 
425
- partial_func = _prepare_remote_partial_func(func)
413
+ partial_func = _prepare_partial_func(func)
426
414
 
427
415
  if metadata is None:
428
416
  metadata = {}
@@ -452,6 +440,8 @@ class FalServerlessHost(Host):
452
440
  if partial_result.result:
453
441
  return partial_result.result.application_id
454
442
 
443
+ return None
444
+
455
445
  @_handle_grpc_error()
456
446
  def run(
457
447
  self,
@@ -492,7 +482,7 @@ class FalServerlessHost(Host):
492
482
  return_value = _UNSET
493
483
  # Allow isolate provided arguments (such as setup function) to take
494
484
  # precedence over the ones provided by the user.
495
- partial_func = _prepare_remote_partial_func(func, *args, **kwargs)
485
+ partial_func = _prepare_partial_func(func, *args, **kwargs)
496
486
  for partial_result in self._connection.run(
497
487
  partial_func,
498
488
  environments,
@@ -555,7 +545,8 @@ _DEFAULT_HOST = FalServerlessHost()
555
545
  _SERVE_PORT = 8080
556
546
 
557
547
  # Overload @function to help users identify the correct signature.
558
- # NOTE: This is both in sync with host options and with environment configs from `isolate` package.
548
+ # NOTE: This is both in sync with host options and with environment configs from
549
+ # `isolate` package.
559
550
 
560
551
 
561
552
  ## virtualenv
@@ -766,7 +757,7 @@ def function( # type: ignore
766
757
  options = host.parse_options(kind=kind, **config)
767
758
 
768
759
  def wrapper(func: Callable[ArgsT, ReturnT]):
769
- add_serialization_listeners_for(func)
760
+ include_modules_from(func)
770
761
  proxy = IsolatedFunction(
771
762
  host=host,
772
763
  raw_func=func, # type: ignore
@@ -777,7 +768,6 @@ def function( # type: ignore
777
768
  return wrapper
778
769
 
779
770
 
780
- @mainify
781
771
  class FalFastAPI(FastAPI):
782
772
  """
783
773
  A subclass of FastAPI that adds some fal-specific functionality.
@@ -796,7 +786,8 @@ class FalFastAPI(FastAPI):
796
786
  """
797
787
  Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
798
788
 
799
- NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
789
+ NOTE: We rely on the fact that fastapi and Python dicts keep the order of
790
+ properties.
800
791
  """
801
792
 
802
793
  def mark_order(obj: dict[str, Any], key: str):
@@ -821,7 +812,6 @@ class FalFastAPI(FastAPI):
821
812
  return spec
822
813
 
823
814
 
824
- @mainify
825
815
  class RouteSignature(NamedTuple):
826
816
  path: str
827
817
  is_websocket: bool = False
@@ -832,7 +822,6 @@ class RouteSignature(NamedTuple):
832
822
  emit_timings: bool = False
833
823
 
834
824
 
835
- @mainify
836
825
  class BaseServable:
837
826
  def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
838
827
  raise NotImplementedError
@@ -851,6 +840,7 @@ class BaseServable:
851
840
  from fastapi import HTTPException, Request
852
841
  from fastapi.middleware.cors import CORSMiddleware
853
842
  from fastapi.responses import JSONResponse
843
+ from starlette_exporter import PrometheusMiddleware
854
844
 
855
845
  _app = FalFastAPI(lifespan=self.lifespan)
856
846
 
@@ -861,6 +851,13 @@ class BaseServable:
861
851
  allow_methods=("*"),
862
852
  allow_origins=("*"),
863
853
  )
854
+ _app.add_middleware(
855
+ PrometheusMiddleware,
856
+ prefix="http",
857
+ group_paths=True,
858
+ filter_unhandled_paths=True,
859
+ app_name="fal",
860
+ )
864
861
 
865
862
  self._add_extra_middlewares(_app)
866
863
 
@@ -906,13 +903,45 @@ class BaseServable:
906
903
  return self._build_app().openapi()
907
904
 
908
905
  def serve(self) -> None:
909
- import uvicorn
906
+ import asyncio
907
+
908
+ from starlette_exporter import handle_metrics
909
+ from uvicorn import Config
910
910
 
911
911
  app = self._build_app()
912
- uvicorn.run(app, host="0.0.0.0", port=8080)
912
+ server = Server(config=Config(app, host="0.0.0.0", port=8080))
913
+ metrics_app = FastAPI()
914
+ metrics_app.add_route("/metrics", handle_metrics)
915
+ metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))
916
+
917
+ async def _serve() -> None:
918
+ tasks = {
919
+ asyncio.create_task(server.serve()): server,
920
+ asyncio.create_task(metrics_server.serve()): metrics_server,
921
+ }
922
+
923
+ _, pending = await asyncio.wait(
924
+ tasks.keys(), return_when=asyncio.FIRST_COMPLETED,
925
+ )
926
+ if not pending:
927
+ return
928
+
929
+ # try graceful shutdown
930
+ for task in pending:
931
+ tasks[task].should_exit = True
932
+ _, pending = await asyncio.wait(pending, timeout=2)
933
+ if not pending:
934
+ return
935
+
936
+ for task in pending:
937
+ task.cancel()
938
+ await asyncio.wait(pending)
939
+
940
+ with suppress(asyncio.CancelledError):
941
+ asyncio.set_event_loop(asyncio.new_event_loop())
942
+ asyncio.run(_serve())
913
943
 
914
944
 
915
- @mainify
916
945
  class ServeWrapper(BaseServable):
917
946
  _func: Callable
918
947
 
@@ -982,18 +1011,20 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
982
1011
  lines = []
983
1012
  for used_modules, references in pairs:
984
1013
  lines.append(
985
- f"\t- {used_modules!r} (accessed through {', '.join(map(repr, references))})"
1014
+ f"\t- {used_modules!r} "
1015
+ f"(accessed through {', '.join(map(repr, references))})"
986
1016
  )
987
1017
 
988
1018
  function_name = self.func.__name__
989
1019
  raise FalServerlessError(
990
- f"Couldn't deserialize your function on the remote server. \n\n[Hint] {function_name!r} "
991
- f"function uses the following modules which weren't present in the environment definition:\n"
1020
+ f"Couldn't deserialize your function on the remote server. \n\n"
1021
+ f"[Hint] {function_name!r} function uses the following modules "
1022
+ "which weren't present in the environment definition:\n"
992
1023
  + "\n".join(lines)
993
1024
  ) from None
994
1025
  except Exception as exc:
995
1026
  cause = exc.__cause__
996
- if self.reraise and match_class(exc, UserFunctionException) and cause:
1027
+ if self.reraise and isinstance(exc, UserFunctionException) and cause:
997
1028
  # re-raise original exception without our wrappers
998
1029
  raise cause
999
1030
  raise
@@ -1040,7 +1071,8 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
1040
1071
  def func(self) -> Callable[ArgsT, ReturnT]:
1041
1072
  serve_mode = self.options.gateway.get("serve")
1042
1073
  if serve_mode:
1043
- # This type can be safely ignored because this case only happens when it is a ServedIsolatedFunction
1074
+ # This type can be safely ignored because this case only happens when it
1075
+ # is a ServedIsolatedFunction
1044
1076
  serve_func: Callable[[], None] = ServeWrapper(self.raw_func)
1045
1077
  return serve_func # type: ignore
1046
1078
  else:
@@ -1069,3 +1101,16 @@ class ServedIsolatedFunction(
1069
1101
  self, host: Host | None = None, *, serve: Literal[False], **config: Any
1070
1102
  ) -> IsolatedFunction[ArgsT, ReturnT]:
1071
1103
  ...
1104
+
1105
+
1106
+ class Server(uvicorn.Server):
1107
+ """Server is a uvicorn.Server that actually plays nicely with signals.
1108
+ By default, uvicorn's Server class overwrites the signal handler for SIGINT,
1109
+ swallowing the signal and preventing other tasks from cancelling.
1110
+ This class allows the task to be gracefully cancelled using asyncio's built-in task
1111
+ cancellation or with an event, like aiohttp.
1112
+ """
1113
+
1114
+ def install_signal_handlers(self) -> None:
1115
+ pass
1116
+
fal/app.py CHANGED
@@ -10,10 +10,10 @@ from typing import Any, Callable, ClassVar, TypeVar
10
10
  from fastapi import FastAPI
11
11
 
12
12
  import fal.api
13
- from fal._serialization import add_serialization_listeners_for
13
+ from fal._serialization import include_modules_from
14
14
  from fal.api import RouteSignature
15
15
  from fal.logging import get_logger
16
- from fal.toolkit import mainify
16
+ from fal.toolkit.file.providers import fal as fal_provider_module
17
17
 
18
18
  REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
19
19
 
@@ -29,7 +29,7 @@ async def _call_any_fn(fn, *args, **kwargs):
29
29
 
30
30
 
31
31
  def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
32
- add_serialization_listeners_for(cls)
32
+ include_modules_from(cls)
33
33
 
34
34
  def initialize_and_serve():
35
35
  app = cls()
@@ -64,7 +64,6 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
64
64
  return fn
65
65
 
66
66
 
67
- @mainify
68
67
  class App(fal.api.BaseServable):
69
68
  requirements: ClassVar[list[str]] = []
70
69
  machine_type: ClassVar[str] = "S"
@@ -126,12 +125,27 @@ class App(fal.api.BaseServable):
126
125
  )
127
126
  return response
128
127
 
128
+ @app.middleware("http")
129
+ async def set_global_object_preference(request, call_next):
130
+ response = await call_next(request)
131
+ try:
132
+ fal_provider_module.GLOBAL_LIFECYCLE_PREFERENCE = request.headers.get(
133
+ "X-Fal-Object-Lifecycle-Preference"
134
+ )
135
+ except Exception:
136
+ from fastapi.logger import logger
137
+
138
+ logger.exception(
139
+ "Failed set a global lifecycle preference %s",
140
+ self.__class__.__name__,
141
+ )
142
+ return response
143
+
129
144
  def provide_hints(self) -> list[str]:
130
145
  """Provide hints for routing the application."""
131
146
  raise NotImplementedError
132
147
 
133
148
 
134
- @mainify
135
149
  def endpoint(
136
150
  path: str, *, is_websocket: bool = False
137
151
  ) -> Callable[[EndpointT], EndpointT]:
@@ -241,14 +255,16 @@ def _fal_websocket_template(
241
255
  output = output.dict()
242
256
  else:
243
257
  raise TypeError(
244
- f"Expected a dict or pydantic model as output, got {type(output)}"
258
+ "Expected a dict or pydantic model as output, got "
259
+ f"{type(output)}"
245
260
  )
246
261
 
247
262
  messages = [
248
263
  msgpack.packb(output, use_bin_type=True),
249
264
  ]
250
265
  if route_signature.emit_timings:
251
- # We emit x-fal messages in JSON, no matter what the input/output format is.
266
+ # We emit x-fal messages in JSON, no matter what the
267
+ # input/output format is.
252
268
  timings = {
253
269
  "type": "x-fal-message",
254
270
  "action": "timings",
@@ -343,7 +359,6 @@ def _fal_websocket_template(
343
359
  _SENTINEL = object()
344
360
 
345
361
 
346
- @mainify
347
362
  def realtime(
348
363
  path: str,
349
364
  *,
@@ -359,7 +374,8 @@ def realtime(
359
374
 
360
375
  if hasattr(original_func, "route_signature"):
361
376
  raise ValueError(
362
- f"Can't set multiple routes for the same function: {original_func.__name__}"
377
+ "Can't set multiple routes for the same function: "
378
+ f"{original_func.__name__}"
363
379
  )
364
380
 
365
381
  if input_modal is _SENTINEL:
fal/auth/__init__.py CHANGED
@@ -9,10 +9,8 @@ from fal.auth import auth0, local
9
9
  from fal.console import console
10
10
  from fal.console.icons import CHECK_ICON
11
11
  from fal.exceptions.auth import UnauthenticatedException
12
- from fal.toolkit.mainify import mainify
13
12
 
14
13
 
15
- @mainify
16
14
  def key_credentials() -> tuple[str, str] | None:
17
15
  # Ignore key credentials when the user forces auth by user.
18
16
  if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
@@ -53,7 +51,8 @@ def _fetch_access_token() -> str:
53
51
  Load the refresh token, request a new access_token (refreshing the refresh token)
54
52
  and return the access_token.
55
53
  """
56
- # We need to lock both read and write access because we could be reading a soon invalid refresh_token
54
+ # We need to lock both read and write access because we could be reading a soon
55
+ # invalid refresh_token
57
56
  with local.lock_token():
58
57
  refresh_token, access_token = local.load_token()
59
58
 
fal/auth/auth0.py CHANGED
@@ -30,7 +30,8 @@ def _open_browser(url: str, code: str | None) -> None:
30
30
  maybe_open_browser_tab(url)
31
31
 
32
32
  console.print(
33
- "If browser didn't open automatically, on your computer or mobile device navigate to"
33
+ "If browser didn't open automatically, "
34
+ "on your computer or mobile device navigate to"
34
35
  )
35
36
  console.print(url)
36
37
 
@@ -155,7 +156,8 @@ def build_jwk_client():
155
156
 
156
157
  def validate_id_token(token: str):
157
158
  """
158
- id_token is intended for the client (this sdk) only. Never send one to another service.
159
+ id_token is intended for the client (this sdk) only.
160
+ Never send one to another service.
159
161
  """
160
162
  from jwt import decode
161
163
 
fal/auth/local.py CHANGED
@@ -62,7 +62,8 @@ def delete_token() -> None:
62
62
  @contextmanager
63
63
  def lock_token():
64
64
  """
65
- Lock the access to the token file to avoid race conditions when running multiple scripts at the same time.
65
+ Lock the access to the token file to avoid race conditions when running multiple
66
+ scripts at the same time.
66
67
  """
67
68
  lock_file = _check_dir_exist() / _LOCK_FILE
68
69
  with portalocker.utils.TemporaryFileLock(
fal/cli.py CHANGED
@@ -89,12 +89,13 @@ class MainGroup(RichGroup):
89
89
  except Exception as exception:
90
90
  logger.error(exception)
91
91
  if state.debug:
92
- # Here we supress detailed errors on click lines because
93
- # they're mostly decorator calls, irrelevant to the dev's error tracing
92
+ # Here we supress detailed errors on click lines because they're
93
+ # mostly decorator calls, irrelevant to the dev's error tracing
94
94
  console.print_exception(suppress=[click])
95
95
  console.print()
96
96
  console.print(
97
- f"The [markdown.code]invocation_id[/] for this operation is: [white]{state.invocation_id}[/]"
97
+ "The [markdown.code]invocation_id[/] for this operation is: "
98
+ f"[white]{state.invocation_id}[/]"
98
99
  )
99
100
  else:
100
101
  self._exception_handler.handle(exception)
@@ -207,7 +208,8 @@ def key_generate(client: sdk.FalServerlessClient, scope: str, alias: str | None)
207
208
  print(
208
209
  f"Generated key id and key secret, with the scope `{scope}`.\n"
209
210
  "This is the only time the secret will be visible.\n"
210
- "You will need to generate a new key pair if you lose access to this secret."
211
+ "You will need to generate a new key pair if you lose access to this "
212
+ "secret."
211
213
  )
212
214
  print(f"FAL_KEY='{result[1]}:{result[0]}'")
213
215
 
@@ -267,8 +269,8 @@ def load_function_from(
267
269
  raise api.FalServerlessError(f"Function '{function_name}' not found in module")
268
270
 
269
271
  # The module for the function is set to <run_path> when runpy is used, in which
270
- # case we want to manually include the packages it is defined in.
271
- _serialization.include_packages_from_path(file_path)
272
+ # case we want to manually include the package it is defined in.
273
+ _serialization.include_package_from_path(file_path)
272
274
 
273
275
  target = module[function_name]
274
276
  if isinstance(target, type) and issubclass(target, fal.App):
@@ -306,7 +308,8 @@ def register_application(
306
308
  gateway_options = isolated_function.options.gateway
307
309
  if "serve" not in gateway_options and "exposed_port" not in gateway_options:
308
310
  raise api.FalServerlessError(
309
- "One of `serve` or `exposed_port` options needs to be specified in the isolated annotation to register a function"
311
+ "One of `serve` or `exposed_port` options needs to be specified "
312
+ "in the isolated annotation to register a function"
310
313
  )
311
314
  elif (
312
315
  "exposed_port" in gateway_options
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from ._base import FalServerlessException # noqa: F401
3
4
  from .handlers import (
4
5
  BaseExceptionHandler,
5
- FalServerlessExceptionHandler,
6
6
  GrpcExceptionHandler,
7
7
  UserFunctionExceptionHandler,
8
8
  )
@@ -13,14 +13,14 @@ class ApplicationExceptionHandler:
13
13
 
14
14
  This exception handler is capable of handling, i.e. customize the output
15
15
  and add behavior, of any type of exception. Click handles all `ClickException`
16
- types by default, but prints the stack for other exception not wrapped in ClickException.
16
+ types by default, but prints the stack for other exception not wrapped in
17
+ ClickException.
17
18
 
18
19
  The handler also allows for central metrics and logging collection.
19
20
  """
20
21
 
21
22
  _handlers: list[BaseExceptionHandler] = [
22
23
  GrpcExceptionHandler(),
23
- FalServerlessExceptionHandler(),
24
24
  UserFunctionExceptionHandler(),
25
25
  ]
26
26
 
fal/exceptions/_base.py CHANGED
@@ -3,15 +3,4 @@ from __future__ import annotations
3
3
 
4
4
  class FalServerlessException(Exception):
5
5
  """Base exception type for fal Serverless related flows and APIs."""
6
-
7
- message: str
8
-
9
- hint: str | None
10
-
11
- def __init__(self, message: str, hint: str | None = None) -> None:
12
- self.message = message
13
- self.hint = hint
14
- super().__init__(message)
15
-
16
- def __str__(self) -> str:
17
- return self.message + (f"\nHint: {self.hint}" if self.hint else "")
6
+ pass
fal/exceptions/auth.py CHANGED
@@ -4,10 +4,8 @@ from ._base import FalServerlessException
4
4
 
5
5
 
6
6
  class UnauthenticatedException(FalServerlessException):
7
- """Exception that indicates that"""
8
-
9
7
  def __init__(self) -> None:
10
8
  super().__init__(
11
- message="You must be authenticated.",
12
- hint="Login via `fal auth login` or make sure to setup fal keys correctly.",
9
+ "You must be authenticated. "
10
+ "Login via `fal auth login` or make sure to setup fal keys correctly."
13
11
  )
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Generic, TypeVar
4
4
 
5
5
  from grpc import Call as RpcCall
6
- from rich.markdown import Markdown
7
6
 
8
7
  from fal.console import console
9
8
  from fal.console.icons import CROSS_ICON
@@ -11,9 +10,8 @@ from fal.console.icons import CROSS_ICON
11
10
  if TYPE_CHECKING:
12
11
  from fal.api import UserFunctionException
13
12
 
14
- from ._base import FalServerlessException
15
13
 
16
- ExceptionType = TypeVar("ExceptionType")
14
+ ExceptionType = TypeVar("ExceptionType", bound=BaseException)
17
15
 
18
16
 
19
17
  class BaseExceptionHandler(Generic[ExceptionType]):
@@ -23,20 +21,11 @@ class BaseExceptionHandler(Generic[ExceptionType]):
23
21
  return True
24
22
 
25
23
  def handle(self, exception: ExceptionType):
26
- console.print(str(exception))
27
-
28
-
29
- class FalServerlessExceptionHandler(BaseExceptionHandler[FalServerlessException]):
30
- """Handle fal Serverless exceptions"""
31
-
32
- def should_handle(self, exception: Exception) -> bool:
33
- return isinstance(exception, FalServerlessException)
34
-
35
- def handle(self, exception: FalServerlessException):
36
- console.print(f"{CROSS_ICON} {exception.message}")
37
- if exception.hint is not None:
38
- console.print(Markdown(f"**Hint:** {exception.hint}"))
39
- console.print()
24
+ msg = f"{CROSS_ICON} {str(exception)}"
25
+ cause = exception.__cause__
26
+ if cause is not None:
27
+ msg += f": {str(cause)}"
28
+ console.print(msg)
40
29
 
41
30
 
42
31
  class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
@@ -51,9 +40,9 @@ class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
51
40
 
52
41
  class UserFunctionExceptionHandler(BaseExceptionHandler["UserFunctionException"]):
53
42
  def should_handle(self, exception: Exception) -> bool:
54
- from fal.api import UserFunctionException, match_class
43
+ from fal.api import UserFunctionException
55
44
 
56
- return match_class(exception, UserFunctionException)
45
+ return isinstance(exception, UserFunctionException)
57
46
 
58
47
  def handle(self, exception: UserFunctionException):
59
48
  import rich