fal 0.15.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; see the registry's advisory page for more details.

fal/cli/main.py ADDED
@@ -0,0 +1,82 @@
1
+ import argparse
2
+
3
+ import rich
4
+
5
+ from fal import __version__
6
+ from fal.console import console
7
+ from fal.console.icons import CROSS_ICON
8
+
9
+ from . import apps, auth, deploy, keys, run, secrets
10
+ from .debug import debugtools, get_debug_parser
11
+ from .parser import FalParser, FalParserExit
12
+
13
+
14
def _get_main_parser() -> argparse.ArgumentParser:
    """Build the top-level ``fal`` argument parser.

    The debug parser is used as a parent of the root parser and is also
    handed to every command module's ``add_parser`` so the same options are
    available on each sub-command.
    """
    shared_parents = [get_debug_parser()]
    root = FalParser(prog="fal", parents=shared_parents)
    root.add_argument(
        "--version",
        action="version",
        version=__version__,
        help="Show fal version.",
    )

    commands = root.add_subparsers(
        title="Commands",
        metavar="command",
        required=True,
    )
    # Each command module registers its own sub-parser(s), in this order.
    command_modules = (auth, apps, deploy, run, keys, secrets)
    for module in command_modules:
        module.add_parser(commands, shared_parents)

    return root
38
+
39
+
40
def parse_args(argv=None):
    """Parse *argv* (``None`` means ``sys.argv[1:]``) for the ``fal`` CLI.

    The returned namespace additionally carries the shared ``console`` and
    the ``parser`` that produced it, so command handlers can use both.
    """
    main_parser = _get_main_parser()
    namespace = main_parser.parse_args(argv)
    namespace.console = console
    namespace.parser = main_parser
    return namespace
46
+
47
+
48
def main(argv=None) -> int:
    """Run the ``fal`` command line interface.

    Parses *argv* (``None`` means ``sys.argv[1:]``), runs the selected
    command's handler inside ``debugtools`` and reports any exception that
    escapes instead of letting it propagate.

    Returns:
        The handler's return value (``0`` when the handler returns ``None``),
        the status carried by ``FalParserExit`` when argument parsing exits
        (e.g. ``--help``, ``--version`` or a usage error), or ``1`` when an
        exception was reported.
    """
    import grpc

    from fal.api import UserFunctionException

    ret = 1
    try:
        args = parse_args(argv)

        with debugtools(args):
            ret = args.func(args)
            # Handlers such as ``_run``/``_set`` return ``None`` on success;
            # map that to ``0`` so the declared ``-> int`` holds.
            if ret is None:
                ret = 0
    except UserFunctionException as _exc:
        # ``import rich`` alone does not load the ``rich.traceback``
        # submodule, so import it explicitly before use.
        from rich.traceback import Traceback

        cause = _exc.__cause__
        exc: BaseException = cause or _exc
        tb = Traceback.from_exception(
            type(exc),
            exc,
            exc.__traceback__,
        )
        console.print(tb)
        console.print("Unhandled user exception")
    except KeyboardInterrupt:
        console.print("Aborted.")
    except grpc.RpcError as exc:
        # ``details()`` is part of ``grpc.Call``; not every ``RpcError``
        # implements it, so fall back to the exception text.
        details = getattr(exc, "details", None)
        console.print(details() if callable(details) else str(exc))
    except FalParserExit as exc:
        ret = exc.status
    except Exception as exc:
        msg = f"{CROSS_ICON} {str(exc)}"
        cause = exc.__cause__
        if cause is not None:
            msg += f": {str(cause)}"
        console.print(msg)

    return ret
fal/cli/parser.py ADDED
@@ -0,0 +1,74 @@
1
+ import argparse
2
+ import sys
3
+
4
+ import rich_argparse
5
+
6
+
7
class FalParserExit(Exception):
    """Raised by the fal parser instead of terminating the interpreter.

    Attributes:
        status: exit status requested by argparse (defaults to ``0``).
    """

    def __init__(self, status=0):
        self.status = status
10
+
11
+
12
class RefAction(argparse.Action):
    """Parse a ``path/to/file.py[::object]`` reference into a 2-tuple.

    The destination becomes ``(file_path, obj_path)``. ``obj_path`` is
    ``None`` when no ``::`` separator follows a non-empty file part. Only the
    first ``::`` splits, so the object part may itself contain ``::``. If
    the argument is never given, the destination defaults to
    ``(None, None)``.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("default", (None, None))
        super().__init__(*args, **kwargs)

    def __call__(self, parser, args, values, option_string=None):  # noqa: ARG002
        if isinstance(values, tuple):
            # Already a (file_path, obj_path) pair; use it as-is.
            file_path, obj_path = values
        elif values.find("::") > 0:
            # Any non-empty file part before the first ``::`` is accepted,
            # including single-character paths such as ``a::func``.
            file_path, obj_path = values.split("::", 1)
        else:
            file_path, obj_path = values, None

        setattr(args, self.dest, (file_path, obj_path))
26
+
27
+
28
class DictAction(argparse.Action):
    """Collect ``NAME=VALUE`` arguments into a dict.

    Accepts a single value or a list of values (``nargs``). Repeated uses of
    the option merge into the same dict. Only the first ``=`` separates
    name from value, so values may themselves contain ``=``.

    Raises:
        argparse.ArgumentError: an item has no ``=``.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("metavar", "<name>=<value>")
        super().__init__(*args, **kwargs)

    def __call__(self, parser, args, values, option_string=None):  # noqa: ARG002
        # Copy rather than mutate: argparse places the ``default`` object
        # itself on the namespace, so updating it in place would leak parsed
        # values into the default and across later parses.
        d = dict(getattr(args, self.dest, None) or {})

        kvs = values if isinstance(values, list) else [values]
        for kv in kvs:
            key, sep, value = kv.partition("=")
            if not sep:
                # Report the offending item, not the whole value list.
                raise argparse.ArgumentError(
                    self,
                    f'Could not parse argument "{kv}" as k1=v1 k2=v2 ... format',
                )
            d[key] = value

        setattr(args, self.dest, d)
52
+
53
+
54
class FalParser(argparse.ArgumentParser):
    """``ArgumentParser`` with rich help formatting that never exits.

    ``exit`` raises ``FalParserExit`` carrying the status instead of
    calling ``sys.exit``, so the caller decides how to terminate.
    """

    def __init__(self, *args, **kwargs):
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = rich_argparse.RawTextRichHelpFormatter
        super().__init__(*args, **kwargs)

    def exit(self, status=0, message=None):
        # Like argparse: print any message to stderr first.
        if message:
            self._print_message(message, sys.stderr)
        raise FalParserExit(status)
63
+
64
+
65
class FalClientParser(FalParser):
    """``FalParser`` that also accepts a hidden ``--host`` option.

    ``--host`` defaults to ``fal.flags.GRPC_HOST`` and is suppressed from
    help output.
    """

    def __init__(self, *args, **kwargs):
        # Imported inside the constructor rather than at module level —
        # presumably to defer fal.flags' import-time checks; confirm.
        from fal.flags import GRPC_HOST

        super().__init__(*args, **kwargs)
        self.add_argument("--host", default=GRPC_HOST, help=argparse.SUPPRESS)
fal/cli/run.py ADDED
@@ -0,0 +1,33 @@
1
+ from .parser import FalClientParser, RefAction
2
+
3
+
4
def _run(args):
    """Handle ``fal run``: load the referenced function and call it once."""
    from fal.api import FalServerlessHost
    from fal.utils import load_function_from

    target_host = FalServerlessHost(args.host)
    function, _app_name = load_function_from(target_host, *args.func_ref)
    # Disable re-raising so a UserFunctionException is left to the CLI's
    # top-level exception handlers.
    function.reraise = False
    function()
13
+
14
+
15
def add_parser(main_subparsers, parents):
    """Register the ``run`` command on *main_subparsers*."""
    run_help = "Run fal function."
    run_parser = main_subparsers.add_parser(
        "run",
        description=run_help,
        # FalClientParser contributes the hidden ``--host`` option.
        parents=[*parents, FalClientParser(add_help=False)],
        help=run_help,
        epilog=("Examples:\n" " fal run path/to/myfile.py::myfunc"),
    )
    run_parser.add_argument("func_ref", action=RefAction, help="Function reference.")
    run_parser.set_defaults(func=_run)
fal/cli/secrets.py ADDED
@@ -0,0 +1,107 @@
1
+
2
+ from .parser import DictAction, FalClientParser
3
+
4
+
5
def _set(args):
    """Handle ``fal secrets set``: store every NAME=VALUE pair given."""
    from fal.sdk import FalServerlessClient

    with FalServerlessClient(args.host).connect() as connection:
        for secret_name, secret_value in args.secrets.items():
            connection.set_secret(secret_name, secret_value)
11
+
12
+
13
def _add_set_parser(subparsers, parents):
    """Register ``secrets set``, which takes one or more NAME=VALUE pairs."""
    set_help = "Set a secret."
    set_parser = subparsers.add_parser(
        "set",
        description=set_help,
        help=set_help,
        parents=parents,
        epilog=("Examples:\n" " fal secrets set HF_TOKEN=hf_***"),
    )
    # DictAction folds all pairs into a single ``args.secrets`` dict.
    set_parser.add_argument(
        "secrets",
        metavar="NAME=VALUE",
        nargs="+",
        action=DictAction,
        help="Secret NAME=VALUE pairs.",
    )
    set_parser.set_defaults(func=_set)
35
+
36
+
37
def _list(args):
    """Handle ``fal secrets list``: print secret names and creation times."""
    from rich.table import Table

    from fal.sdk import FalServerlessClient

    table = Table()
    for heading in ("Secret Name", "Created At"):
        table.add_column(heading)

    with FalServerlessClient(args.host).connect() as connection:
        for entry in connection.list_secrets():
            table.add_row(entry.name, str(entry.created_at))

    args.console.print(table)
52
+
53
+
54
def _add_list_parser(subparsers, parents):
    """Register ``secrets list``, which takes no positional arguments."""
    list_help = "List secrets."
    list_parser = subparsers.add_parser(
        "list", description=list_help, help=list_help, parents=parents
    )
    list_parser.set_defaults(func=_list)
63
+
64
+
65
def _unset(args):
    """Handle ``fal secrets unset``: delete the secret named ``args.secret``."""
    from fal.sdk import FalServerlessClient

    with FalServerlessClient(args.host).connect() as connection:
        connection.delete_secret(args.secret)
70
+
71
+
72
def _add_unset_parser(subparsers, parents):
    """Register ``secrets unset``, which takes a single secret NAME."""
    unset_help = "Unset a secret."
    unset_parser = subparsers.add_parser(
        "unset",
        description=unset_help,
        help=unset_help,
        parents=parents,
    )
    unset_parser.add_argument("secret", metavar="NAME", help="Secret's name.")
    unset_parser.set_defaults(func=_unset)
86
+
87
+
88
def add_parser(main_subparsers, parents):
    """Register the ``secrets`` command group (alias ``secret``).

    Its sub-commands are created with ``FalClientParser``, so each one also
    accepts the hidden ``--host`` option.
    """
    secrets_help = "Manage fal secrets."
    group = main_subparsers.add_parser(
        "secrets",
        aliases=["secret"],
        parents=parents,
        description=secrets_help,
        help=secrets_help,
    )
    commands = group.add_subparsers(
        title="Commands",
        metavar="command",
        required=True,
        parser_class=FalClientParser,
    )
    for register in (_add_set_parser, _add_list_parser, _add_unset_parser):
        register(commands, parents)
@@ -1,31 +1,3 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from ._base import FalServerlessException # noqa: F401
4
- from .handlers import (
5
- BaseExceptionHandler,
6
- GrpcExceptionHandler,
7
- UserFunctionExceptionHandler,
8
- )
9
-
10
-
11
- class ApplicationExceptionHandler:
12
- """Handle exceptions top-level exceptions.
13
-
14
- This exception handler is capable of handling, i.e. customize the output
15
- and add behavior, of any type of exception. Click handles all `ClickException`
16
- types by default, but prints the stack for other exception not wrapped in ClickException.
17
-
18
- The handler also allows for central metrics and logging collection.
19
- """
20
-
21
- _handlers: list[BaseExceptionHandler] = [
22
- GrpcExceptionHandler(),
23
- UserFunctionExceptionHandler(),
24
- ]
25
-
26
- def handle(self, exception):
27
- match_handler: BaseExceptionHandler = next(
28
- (h for h in self._handlers if h.should_handle(exception)),
29
- BaseExceptionHandler(),
30
- )
31
- match_handler.handle(exception)
fal/flags.py CHANGED
@@ -18,8 +18,6 @@ GRPC_HOST = os.getenv("FAL_HOST", "api.alpha.fal.ai")
18
18
  if not TEST_MODE:
19
19
  assert GRPC_HOST.startswith("api"), "FAL_HOST must start with 'api'"
20
20
 
21
- GATEWAY_HOST = GRPC_HOST.replace("api", "gateway", 1)
22
-
23
21
  REST_HOST = GRPC_HOST.replace("api", "rest", 1)
24
22
  REST_SCHEME = "http" if TEST_MODE or AUTH_DISABLED else "https"
25
23
  REST_URL = f"{REST_SCHEME}://{REST_HOST}"
@@ -29,5 +27,4 @@ FAL_RUN_HOST = (
29
27
  GRPC_HOST.replace("api.", "", 1).replace("alpha.", "", 1).replace(".ai", ".run", 1)
30
28
  )
31
29
 
32
- FORCE_SETUP = bool_envvar("FAL_FORCE_SETUP")
33
30
  DONT_OPEN_LINKS = bool_envvar("FAL_DONT_OPEN_LINKS")
fal/logging/isolate.py CHANGED
@@ -34,10 +34,10 @@ class IsolateLogPrinter:
34
34
  timestamp = log.timestamp
35
35
  else:
36
36
  # Default value for timestamp if user has old `isolate` version.
37
- # Even if the controller version is controller by us, which means that the timestamp
38
- # is being sent in the gRPC message.
39
- # The `isolate` version users interpret that message with is out of our control.
40
- # So we need to handle this case.
37
+ # Even if the controller version is controlled by us, which means that
38
+ # the timestamp is being sent in the gRPC message.
39
+ # The `isolate` version users interpret that message with is out of our
40
+ # control. So we need to handle this case.
41
41
  timestamp = datetime.now(timezone.utc)
42
42
 
43
43
  event: EventDict = {
fal/sdk.py CHANGED
@@ -8,17 +8,17 @@ from enum import Enum
8
8
  from typing import Any, Callable, Generic, Iterator, Literal, TypeVar
9
9
 
10
10
  import grpc
11
+ import isolate_proto
11
12
  from isolate.connections.common import is_agent
12
13
  from isolate.logs import Log
13
14
  from isolate.server.interface import from_grpc, to_serialized_object, to_struct
15
+ from isolate_proto.configuration import GRPC_OPTIONS
14
16
 
15
- import isolate_proto
16
17
  from fal import flags
17
18
  from fal._serialization import patch_pickle
18
19
  from fal.auth import USER, key_credentials
19
20
  from fal.logging import get_logger
20
21
  from fal.logging.trace import TraceContextInterceptor
21
- from isolate_proto.configuration import GRPC_OPTIONS
22
22
 
23
23
  ResultT = TypeVar("ResultT")
24
24
  InputT = TypeVar("InputT")
@@ -28,6 +28,7 @@ _DEFAULT_SERIALIZATION_METHOD = "cloudpickle"
28
28
  FAL_SERVERLESS_DEFAULT_KEEP_ALIVE = 10
29
29
  FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING = 1
30
30
  FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY = 0
31
+ ALIAS_AUTH_MODES = ["public", "private", "shared"]
31
32
 
32
33
  logger = get_logger(__name__)
33
34
 
@@ -187,6 +188,16 @@ class HostedRunStatus:
187
188
  state: HostedRunState
188
189
 
189
190
 
191
@dataclass
class ApplicationInfo:
    """Settings and runtime counters of one application.

    Instances are built from ``isolate_proto.ApplicationInfo`` messages by the
    registered ``from_grpc`` converter. Field order matters for positional
    construction.
    """

    application_id: str
    # NOTE(review): unit not visible here (presumably seconds) — confirm.
    keep_alive: int
    max_concurrency: int
    max_multiplexing: int
    # Presumably the number of runners currently up for the app — confirm.
    active_runners: int
    min_concurrency: int
+
200
+
190
201
  @dataclass
191
202
  class AliasInfo:
192
203
  alias: str
@@ -263,6 +274,20 @@ class KeyScope(enum.Enum):
263
274
  raise ValueError(f"Unknown KeyScope: {proto}")
264
275
 
265
276
 
277
@from_grpc.register(isolate_proto.ApplicationInfo)
def _from_grpc_application_info(
    message: isolate_proto.ApplicationInfo,
) -> ApplicationInfo:
    """Convert an ``isolate_proto.ApplicationInfo`` message field by field."""
    copied_fields = (
        "application_id",
        "keep_alive",
        "max_concurrency",
        "max_multiplexing",
        "active_runners",
        "min_concurrency",
    )
    return ApplicationInfo(**{name: getattr(message, name) for name in copied_fields})
289
+
290
+
266
291
  @from_grpc.register(isolate_proto.AliasInfo)
267
292
  def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
268
293
  if message.auth_mode is isolate_proto.ApplicationAuthMode.PUBLIC:
@@ -496,6 +521,18 @@ class FalServerlessConnection:
496
521
  )
497
522
  return from_grpc(res.alias_info)
498
523
 
524
def list_applications(self) -> list[ApplicationInfo]:
    """Call the ``ListApplications`` RPC and convert each returned entry."""
    response: isolate_proto.ListApplicationsResult = self.stub.ListApplications(
        isolate_proto.ListApplicationsRequest()
    )
    return list(map(from_grpc, response.applications))
528
+
529
def delete_application(
    self,
    application_id: str,
) -> None:
    """Call the ``DeleteApplication`` RPC for *application_id*."""
    self.stub.DeleteApplication(
        isolate_proto.DeleteApplicationRequest(application_id=application_id)
    )
535
+
499
536
  def run(
500
537
  self,
501
538
  function: Callable[..., ResultT],
fal/sync.py CHANGED
@@ -31,7 +31,8 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
31
31
  body = upload_file_model.BodyUploadLocalFile(
32
32
  rest_types.File(
33
33
  payload=file_to_upload,
34
- # We need to set a file_name, otherwise the server errors processing the file
34
+ # We need to set a file_name, otherwise the server errors
35
+ # processing the file
35
36
  file_name=os.path.basename(source_path),
36
37
  )
37
38
  )
@@ -45,7 +46,9 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
45
46
 
46
47
  if response.status_code != 200:
47
48
  raise Exception(
48
- f"Failed to upload file. Server returned status code {response.status_code} and message {response.parsed}"
49
+ "Failed to upload file. "
50
+ "Server returned status code "
51
+ f"{response.status_code} and message {response.parsed}"
49
52
  )
50
53
 
51
54
 
@@ -94,7 +97,8 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
94
97
  local_dir_abs = os.path.expanduser(local_dir)
95
98
  if not os.path.isabs(remote_dir) or not remote_dir.startswith("/data"):
96
99
  raise ValueError(
97
- "'remote_dir' must be an absolute path starting with `/data`, e.g. '/data/sync/my_dir'"
100
+ "'remote_dir' must be an absolute path starting with `/data`, "
101
+ "e.g. '/data/sync/my_dir'"
98
102
  )
99
103
 
100
104
  remote_dir = remote_dir.replace("/data/", "", 1)
fal/toolkit/file/file.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import shutil
4
4
  from pathlib import Path
5
5
  from tempfile import NamedTemporaryFile, mkdtemp
6
- from typing import Callable, Optional, Any
6
+ from typing import Any, Callable, Optional
7
7
  from urllib.parse import urlparse
8
8
  from zipfile import ZipFile
9
9
 
@@ -13,8 +13,8 @@ import pydantic
13
13
  if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
14
14
  IS_PYDANTIC_V2 = False
15
15
  else:
16
- from pydantic_core import core_schema, CoreSchema
17
16
  from pydantic import GetCoreSchemaHandler
17
+ from pydantic_core import CoreSchema, core_schema
18
18
  IS_PYDANTIC_V2 = True
19
19
 
20
20
  from pydantic import BaseModel, Field
@@ -57,22 +57,28 @@ class File(BaseModel):
57
57
  description="The URL where the file can be downloaded from.",
58
58
  )
59
59
  content_type: Optional[str] = Field(
60
- None, description="The mime type of the file.",
60
+ None,
61
+ description="The mime type of the file.",
61
62
  examples=["image/png"],
62
63
  )
63
64
  file_name: Optional[str] = Field(
64
- None, description="The name of the file. It will be auto-generated if not provided.",
65
+ None,
66
+ description="The name of the file. It will be auto-generated if not provided.",
65
67
  examples=["z9RV14K95DvU.png"],
66
68
  )
67
69
  file_size: Optional[int] = Field(
68
70
  None, description="The size of the file in bytes.", examples=[4404019]
69
71
  )
70
72
  file_data: Optional[bytes] = Field(
71
- None, description="File data", exclude=True, repr=False,
73
+ None,
74
+ description="File data",
75
+ exclude=True,
76
+ repr=False,
72
77
  )
73
78
 
74
79
  # Pydantic custom validator for input type conversion
75
80
  if IS_PYDANTIC_V2:
81
+
76
82
  @classmethod
77
83
  def __get_pydantic_core_schema__(
78
84
  cls, source_type: Any, handler: GetCoreSchemaHandler
@@ -81,10 +87,12 @@ class File(BaseModel):
81
87
  cls.__convert_from_str,
82
88
  handler(source_type),
83
89
  )
90
+
84
91
  else:
92
+
85
93
  @classmethod
86
94
  def __get_validators__(cls):
87
- yield cls.__convert_from_str
95
+ yield cls.__convert_from_str
88
96
 
89
97
  @classmethod
90
98
  def __convert_from_str(cls, value: Any):
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import dataclasses
3
4
  import json
4
5
  import os
5
6
  from base64 import b64encode
@@ -14,6 +15,14 @@ from fal.toolkit.file.types import FileData, FileRepository
14
15
  _FAL_CDN = "https://fal.media"
15
16
 
16
17
 
18
@dataclass
class ObjectLifecyclePreference:
    """Lifecycle hint attached to CDN uploads.

    ``FalCDNFileRepository.save`` serializes it with ``dataclasses.asdict``
    into the ``X-Fal-Object-Lifecycle-Preference`` request header.
    """

    # NOTE(review): "expriation" is misspelled, but the field name becomes the
    # JSON key sent to the server, so renaming it would change the wire
    # format — confirm what the server expects before touching it.
    expriation_duration_seconds: int


# NOTE(review): a 2-second expiration looks very short for uploaded files —
# confirm this is the intended value.
GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference(expriation_duration_seconds=2)
24
+
25
+
17
26
  @dataclass
18
27
  class FalFileRepository(FileRepository):
19
28
  def save(self, file: FileData) -> str:
@@ -70,19 +79,27 @@ class FalFileRepository(FileRepository):
70
79
 
71
80
  @dataclass
72
81
  class InMemoryRepository(FileRepository):
73
- def save(self, file: FileData) -> str:
82
+ def save(
83
+ self,
84
+ file: FileData,
85
+ ) -> str:
74
86
  return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
75
87
 
76
88
 
77
89
  @dataclass
78
90
  class FalCDNFileRepository(FileRepository):
79
- def save(self, file: FileData) -> str:
91
+ def save(
92
+ self,
93
+ file: FileData,
94
+ ) -> str:
80
95
  headers = {
81
96
  **self.auth_headers,
82
97
  "Accept": "application/json",
83
98
  "Content-Type": file.content_type,
99
+ "X-Fal-Object-Lifecycle-Preference": json.dumps(
100
+ dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
101
+ ),
84
102
  }
85
-
86
103
  url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
87
104
  request = Request(url, headers=headers, method="POST", data=file.data)
88
105
  try:
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import io
4
4
  from tempfile import NamedTemporaryFile
5
- from typing import TYPE_CHECKING, Literal, Union, Optional
5
+ from typing import TYPE_CHECKING, Literal, Optional, Union
6
6
 
7
7
  from pydantic import BaseModel, Field
8
8
 
fal/toolkit/optimize.py CHANGED
@@ -4,7 +4,6 @@ import os
4
4
  import traceback
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
7
-
8
7
  if TYPE_CHECKING:
9
8
  import torch
10
9
 
@@ -17,7 +17,10 @@ FAL_MODEL_WEIGHTS_DIR = FAL_PERSISTENT_DIR / ".fal" / "model_weights"
17
17
 
18
18
  # TODO: how can we randomize the user agent to avoid being blocked?
19
19
  TEMP_HEADERS = {
20
- "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.8; rv:21.0) Gecko/20100101 Firefox/21.0"
20
+ "User-Agent": (
21
+ "Mozilla/5.0 (Macintosh; "
22
+ "Intel Mac OS X 10.8; rv:21.0) Gecko/20100101 Firefox/21.0"
23
+ ),
21
24
  }
22
25
 
23
26
 
@@ -94,8 +97,8 @@ def _file_content_length_matches(url: str, file_path: Path) -> bool:
94
97
  file_path: The local path to the file being compared.
95
98
 
96
99
  Returns:
97
- bool:`True` if the local file's content length matches the remote file's content length,
98
- `False` otherwise.
100
+ bool: `True` if the local file's content length matches the remote file's
101
+ content length, `False` otherwise.
99
102
  """
100
103
  local_file_content_length = file_path.stat().st_size
101
104
  remote_file_content_length = _get_remote_file_properties(url)[1]
fal/utils.py ADDED
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ import fal._serialization
4
+ from fal import App, wrap_app
5
+
6
+ from .api import FalServerlessError, FalServerlessHost, IsolatedFunction
7
+
8
+
9
def load_function_from(
    host: FalServerlessHost,
    file_path: str,
    function_name: str | None = None,
) -> tuple[IsolatedFunction, str | None]:
    """Load a fal function or ``fal.App`` defined in a Python file.

    Args:
        host: host used when wrapping a ``fal.App`` subclass via ``wrap_app``.
        file_path: path of the Python file, executed with ``runpy.run_path``.
        function_name: name of the object to load. When ``None``, the file
            must define exactly one ``fal.App`` subclass that has an
            ``app_name``, and that app is loaded.

    Returns:
        ``(isolated_function, app_name)``. ``app_name`` is ``None`` when the
        loaded object is a plain fal function rather than an App.

    Raises:
        FalServerlessError: no App or several Apps are found, the name is not
            defined in the module, or the object is neither a fal function
            nor a fal.App.
    """
    import runpy

    # Execute the user's file exactly once. Running it a second time would
    # repeat all module-level side effects for no benefit.
    module = runpy.run_path(file_path)

    if function_name is None:
        # Map each App's ``app_name`` to the module-level name it is bound to.
        fal_objects = {
            obj.app_name: obj_name
            for obj_name, obj in module.items()
            if isinstance(obj, type)
            and issubclass(obj, App)
            and hasattr(obj, "app_name")
        }
        if len(fal_objects) == 0:
            raise FalServerlessError("No fal.App found in the module.")
        elif len(fal_objects) > 1:
            raise FalServerlessError(
                "Multiple fal.Apps found in the module. "
                "Please specify the name of the app."
            )

        [(app_name, function_name)] = fal_objects.items()
    else:
        app_name = None

    if function_name not in module:
        raise FalServerlessError(f"Function '{function_name}' not found in module")

    # The module for the function is set to <run_path> when runpy is used, in which
    # case we want to manually include the package it is defined in.
    fal._serialization.include_package_from_path(file_path)

    target = module[function_name]
    if isinstance(target, type) and issubclass(target, App):
        app_name = target.app_name
        target = wrap_app(target, host=host)

    if not isinstance(target, IsolatedFunction):
        raise FalServerlessError(
            f"Function '{function_name}' is not a fal.function or a fal.App"
        )
    return target, app_name
55
+