fal-1.3.3-py3-none-any.whl → fal-1.7.3-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- fal/_fal_version.py +2 -2
- fal/api.py +46 -14
- fal/app.py +157 -17
- fal/apps.py +138 -3
- fal/auth/__init__.py +50 -2
- fal/cli/_utils.py +8 -2
- fal/cli/apps.py +1 -1
- fal/cli/deploy.py +34 -8
- fal/cli/main.py +2 -2
- fal/cli/run.py +1 -1
- fal/cli/runners.py +44 -0
- fal/config.py +23 -0
- fal/container.py +1 -1
- fal/sdk.py +34 -9
- fal/toolkit/file/file.py +92 -19
- fal/toolkit/file/providers/fal.py +571 -83
- fal/toolkit/file/providers/gcp.py +8 -1
- fal/toolkit/file/providers/r2.py +8 -1
- fal/toolkit/file/providers/s3.py +80 -0
- fal/toolkit/file/types.py +11 -4
- fal/toolkit/image/__init__.py +3 -3
- fal/toolkit/image/image.py +25 -2
- fal/toolkit/types.py +140 -0
- fal/toolkit/utils/download_utils.py +4 -0
- fal/toolkit/utils/retry.py +45 -0
- fal/workflows.py +10 -4
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/METADATA +14 -9
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/RECORD +31 -26
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/WHEEL +1 -1
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/entry_points.txt +0 -0
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/top_level.txt +0 -0
fal/_fal_version.py
CHANGED
fal/api.py
CHANGED
@@ -76,6 +76,8 @@ SERVE_REQUIREMENTS = [
     f"pydantic=={pydantic_version}",
     "uvicorn",
     "starlette_exporter",
+    "structlog",
+    "tomli",
 ]


@@ -170,6 +172,7 @@ class Host(Generic[ArgsT, ReturnT]):
         application_name: str | None = None,
         application_auth_mode: Literal["public", "shared", "private"] | None = None,
         metadata: dict[str, Any] | None = None,
+        scale: bool = True,
     ) -> str | None:
         """Register the given function on the host for API call execution."""
         raise NotImplementedError
@@ -389,12 +392,15 @@ class FalServerlessHost(Host):
     _SUPPORTED_KEYS = frozenset(
         {
             "machine_type",
+            "machine_types",
+            "num_gpus",
             "keep_alive",
             "max_concurrency",
             "min_concurrency",
             "max_multiplexing",
             "setup_function",
             "metadata",
+            "request_timeout",
             "_base_image",
             "_scheduler",
             "_scheduler_options",
@@ -426,25 +432,27 @@ class FalServerlessHost(Host):
         application_auth_mode: Literal["public", "shared", "private"] | None = None,
         metadata: dict[str, Any] | None = None,
         deployment_strategy: Literal["recreate", "rolling"] = "recreate",
+        scale: bool = True,
     ) -> str | None:
         environment_options = options.environment.copy()
         environment_options.setdefault("python_version", active_python())
         environments = [self._connection.define_environment(**environment_options)]

-        machine_type = options.host.get(
+        machine_type: list[str] | str = options.host.get(
             "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
         )
         keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
-        max_concurrency = options.host.get("max_concurrency")
-        min_concurrency = options.host.get("min_concurrency")
-        max_multiplexing = options.host.get("max_multiplexing")
         base_image = options.host.get("_base_image", None)
         scheduler = options.host.get("_scheduler", None)
         scheduler_options = options.host.get("_scheduler_options", None)
+        max_concurrency = options.host.get("max_concurrency")
+        min_concurrency = options.host.get("min_concurrency")
+        max_multiplexing = options.host.get("max_multiplexing")
         exposed_port = options.get_exposed_port()
-
+        request_timeout = options.host.get("request_timeout")
         machine_requirements = MachineRequirements(
-
+            machine_types=machine_type,  # type: ignore
+            num_gpus=options.host.get("num_gpus"),
             keep_alive=keep_alive,
             base_image=base_image,
             exposed_port=exposed_port,
@@ -453,6 +461,7 @@ class FalServerlessHost(Host):
             max_multiplexing=max_multiplexing,
             max_concurrency=max_concurrency,
             min_concurrency=min_concurrency,
+            request_timeout=request_timeout,
         )

         partial_func = _prepare_partial_func(func)
@@ -479,6 +488,7 @@ class FalServerlessHost(Host):
             machine_requirements=machine_requirements,
             metadata=metadata,
             deployment_strategy=deployment_strategy,
+            scale=scale,
         ):
             for log in partial_result.logs:
                 self._log_printer.print(log)
@@ -501,7 +511,7 @@ class FalServerlessHost(Host):
         environment_options.setdefault("python_version", active_python())
         environments = [self._connection.define_environment(**environment_options)]

-        machine_type = options.host.get(
+        machine_type: list[str] | str = options.host.get(
             "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
         )
         keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
@@ -513,9 +523,11 @@ class FalServerlessHost(Host):
         scheduler_options = options.host.get("_scheduler_options", None)
         exposed_port = options.get_exposed_port()
         setup_function = options.host.get("setup_function", None)
+        request_timeout = options.host.get("request_timeout")

         machine_requirements = MachineRequirements(
-
+            machine_types=machine_type,  # type: ignore
+            num_gpus=options.host.get("num_gpus"),
             keep_alive=keep_alive,
             base_image=base_image,
             exposed_port=exposed_port,
@@ -524,6 +536,7 @@ class FalServerlessHost(Host):
             max_multiplexing=max_multiplexing,
             max_concurrency=max_concurrency,
             min_concurrency=min_concurrency,
+            request_timeout=request_timeout,
         )

         return_value = _UNSET
@@ -684,10 +697,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
@@ -709,10 +724,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
@@ -784,10 +801,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
@@ -814,10 +833,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
@@ -838,10 +859,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
@@ -862,10 +885,12 @@ def function(
     max_concurrency: int | None = None,
     # FalServerlessHost options
     metadata: dict[str, Any] | None = None,
-    machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
+    num_gpus: int | None = None,
     keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
     max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
     min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
+    request_timeout: int | None = None,
     setup_function: Callable[..., None] | None = None,
     _base_image: str | None = None,
     _scheduler: str | None = None,
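The six `function()` overloads above all gain the same three options: `machine_type` now also accepts a list of machine types, and `num_gpus` and `request_timeout` are new optional parameters. A minimal usage sketch; the kind argument, requirement list, machine-type names, and the assumption that `request_timeout` is in seconds are illustrative, not taken from the diff:

import fal

@fal.function(
    "virtualenv",
    requirements=["torch"],                  # illustrative requirement
    machine_type=["GPU-A100", "GPU-A6000"],  # list form accepted as of 1.7.x (names illustrative)
    num_gpus=2,                              # new: number of GPUs to attach
    request_timeout=600,                     # new: per-request timeout (assumed seconds)
)
def generate(prompt: str) -> str:
    # placeholder body; a real function would do the GPU work here
    return prompt.upper()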
@@ -950,6 +975,8 @@ class RouteSignature(NamedTuple):


 class BaseServable:
+    version: ClassVar[str] = "unknown"
+
     def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
         raise NotImplementedError

@@ -1078,9 +1105,14 @@ class BaseServable:
     def serve(self) -> None:
         import asyncio

+        from prometheus_client import Gauge
         from starlette_exporter import handle_metrics
         from uvicorn import Config

+        # NOTE: this uses the global prometheus registry
+        app_info = Gauge("fal_app_info", "Fal application information", ["version"])
+        app_info.labels(version=self.version).set(1)
+
         app = self._build_app()
         server = Server(
             config=Config(app, host="0.0.0.0", port=8080, timeout_keep_alive=300)
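The `serve()` hunk publishes the servable's version as an "info-style" Prometheus gauge on the global registry, which the `starlette_exporter` metrics endpoint then exposes. The pattern in isolation, as a standalone sketch outside of fal's serving stack:

from prometheus_client import Gauge, generate_latest

# A gauge pinned to 1 whose label carries the interesting value
app_info = Gauge("fal_app_info", "Fal application information", ["version"])
app_info.labels(version="1.7.3").set(1)

# Renders something like: fal_app_info{version="1.7.3"} 1.0
print(generate_latest().decode())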
fal/app.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import inspect
 import json
 import os
@@ -8,20 +9,25 @@ import re
 import threading
 import time
 import typing
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
+from dataclasses import dataclass
 from typing import Any, Callable, ClassVar, Literal, TypeVar

+import fastapi
+import grpc.aio as async_grpc
 import httpx
-from
+from isolate.server import definitions

 import fal.api
 from fal._serialization import include_modules_from
 from fal.api import RouteSignature
-from fal.exceptions import RequestCancelledException
+from fal.exceptions import FalServerlessException, RequestCancelledException
 from fal.logging import get_logger
-from fal.toolkit.file
+from fal.toolkit.file import request_lifecycle_preference
+from fal.toolkit.file.providers.fal import LIFECYCLE_PREFERENCE

 REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
+REQUEST_ID_KEY = "x-fal-request-id"

 EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
 logger = get_logger(__name__)
@@ -34,6 +40,56 @@ async def _call_any_fn(fn, *args, **kwargs):
     return fn(*args, **kwargs)


+async def open_isolate_channel(address: str) -> async_grpc.Channel:
+    _stack = AsyncExitStack()
+    channel = await _stack.enter_async_context(
+        async_grpc.insecure_channel(
+            address,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+                ("grpc.min_reconnect_backoff_ms", 0),
+                ("grpc.max_reconnect_backoff_ms", 100),
+                ("grpc.dns_min_time_between_resolutions_ms", 100),
+            ],
+        )
+    )
+
+    channel_status = channel.channel_ready()
+    try:
+        await asyncio.wait_for(channel_status, timeout=1)
+    except asyncio.TimeoutError:
+        await _stack.aclose()
+        raise Exception("Timed out trying to connect to local isolate")
+
+    return channel
+
+
+async def _set_logger_labels(
+    logger_labels: dict[str, str], channel: async_grpc.Channel
+):
+    try:
+        import sys
+
+        # Flush any prints that were buffered before setting the logger labels
+        sys.stderr.flush()
+        sys.stdout.flush()
+
+        isolate = definitions.IsolateStub(channel)
+        isolate_request = definitions.SetMetadataRequest(
+            # TODO: when submit is shipped, get task_id from an env var
+            task_id="RUN",
+            metadata=definitions.TaskMetadata(logger_labels=logger_labels),
+        )
+        res = isolate.SetMetadata(isolate_request)
+        code = await res.code()
+        assert str(code) == "StatusCode.OK", str(code)
+    except BaseException:
+        # NOTE hiding this for now to not print on every request
+        # logger.debug("Failed to set logger labels", exc_info=True)
+        pass
+
+
 def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
     include_modules_from(cls)

@@ -60,6 +116,7 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
         kind,
         requirements=cls.requirements,
         machine_type=cls.machine_type,
+        num_gpus=cls.num_gpus,
         **cls.host_kwargs,
         **kwargs,
         metadata=metadata,
@@ -74,6 +131,12 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
     return fn


+@dataclass
+class AppClientError(FalServerlessException):
+    message: str
+    status_code: int
+
+
 class EndpointClient:
     def __init__(self, url, endpoint, signature, timeout: int | None = None):
         self.url = url
@@ -86,12 +149,19 @@ class EndpointClient:

     def __call__(self, data):
         with httpx.Client() as client:
+            url = self.url + self.signature.path
             resp = client.post(
                 self.url + self.signature.path,
                 json=data.dict() if hasattr(data, "dict") else dict(data),
                 timeout=self.timeout,
             )
-            resp.
+            if not resp.is_success:
+                # allow logs to be printed before raising the exception
+                time.sleep(1)
+                raise AppClientError(
+                    f"Failed to POST {url}: {resp.status_code} {resp.text}",
+                    status_code=resp.status_code,
+                )
             resp_dict = resp.json()

             if not self.return_type:
@@ -144,12 +214,16 @@ class AppClient:
         with httpx.Client() as client:
             retries = 100
             for _ in range(retries):
-
+                url = info.url + "/health"
+                resp = client.get(url, timeout=60)

                 if resp.is_success:
                     break
                 elif resp.status_code not in (500, 404):
-
+                    raise AppClientError(
+                        f"Failed to GET {url}: {resp.status_code} {resp.text}",
+                        status_code=resp.status_code,
+                    )
                 time.sleep(0.1)

         client = cls(app_cls, info.url)
@@ -174,9 +248,18 @@ def _to_fal_app_name(name: str) -> str:
     return "-".join(part.lower() for part in PART_FINDER_RE.findall(name))


+def _print_python_packages() -> None:
+    from importlib.metadata import distributions
+
+    packages = [f"{dist.metadata['Name']}=={dist.version}" for dist in distributions()]
+
+    print("[debug] Python packages installed:", ", ".join(packages))
+
+
 class App(fal.api.BaseServable):
     requirements: ClassVar[list[str]] = []
     machine_type: ClassVar[str] = "S"
+    num_gpus: ClassVar[int | None] = None
     host_kwargs: ClassVar[dict[str, Any]] = {
         "_scheduler": "nomad",
         "_scheduler_options": {
@@ -187,11 +270,18 @@ class App(fal.api.BaseServable):
     }
     app_name: ClassVar[str]
     app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
+    request_timeout: ClassVar[int | None] = None
+
+    isolate_channel: async_grpc.Channel | None = None

     def __init_subclass__(cls, **kwargs):
         app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
         parent_settings = getattr(cls, "host_kwargs", {})
         cls.host_kwargs = {**parent_settings, **kwargs}
+
+        if cls.request_timeout is not None:
+            cls.host_kwargs["request_timeout"] = cls.request_timeout
+
         cls.app_name = getattr(cls, "app_name", app_name)

         if cls.__init__ is not App.__init__:
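With the class attributes added above, `num_gpus` and `request_timeout` can now be set declaratively on an `App` subclass; `__init_subclass__` folds `request_timeout` into `host_kwargs`, and any extra subclass keyword arguments continue to land there as before. A sketch of how an app might opt in, assuming the public `fal.App`/`fal.endpoint` interface from 1.3.x otherwise still applies (the machine type, timeout value, and endpoint body are illustrative):

import fal

class MyApp(fal.App, keep_alive=300):  # extra kwargs still flow into host_kwargs
    machine_type = "GPU"               # illustrative machine name
    num_gpus = 2                       # new in this release
    request_timeout = 120              # forwarded to host_kwargs by __init_subclass__
    requirements = ["pillow"]

    @fal.endpoint("/")
    def run(self) -> dict:
        # /health now also reports the app version alongside endpoints like this one
        return {"ok": True}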
@@ -222,7 +312,8 @@ class App(fal.api.BaseServable):
    }

    @asynccontextmanager
-    async def lifespan(self, app: FastAPI):
+    async def lifespan(self, app: fastapi.FastAPI):
+        _print_python_packages()
        await _call_any_fn(self.setup)
        try:
            yield
@@ -230,7 +321,7 @@ class App(fal.api.BaseServable):
        await _call_any_fn(self.teardown)

    def health(self):
-        return {}
+        return {"version": self.version}

    def setup(self):
        """Setup the application before serving."""
@@ -238,7 +329,7 @@ class App(fal.api.BaseServable):
    def teardown(self):
        """Teardown the application after serving."""

-    def _add_extra_middlewares(self, app: FastAPI):
+    def _add_extra_middlewares(self, app: fastapi.FastAPI):
        @app.middleware("http")
        async def provide_hints_headers(request, call_next):
            response = await call_next(request)
@@ -259,11 +350,12 @@ class App(fal.api.BaseServable):

        @app.middleware("http")
        async def set_global_object_preference(request, call_next):
-            response = await call_next(request)
            try:
-
-
-
+                preference_dict = request_lifecycle_preference(request)
+                if preference_dict is not None:
+                    # This will not work properly for apps with multiplexing enabled
+                    # we may mix up the preferences between requests
+                    LIFECYCLE_PREFERENCE.set(preference_dict)
            except Exception:
                from fastapi.logger import logger

@@ -271,7 +363,52 @@ class App(fal.api.BaseServable):
                    "Failed set a global lifecycle preference %s",
                    self.__class__.__name__,
                )
-
+
+            try:
+                return await call_next(request)
+            finally:
+                # We may miss the global preference if there are operations
+                # being done in the background that go beyond the request
+                LIFECYCLE_PREFERENCE.set(None)
+
+        @app.middleware("http")
+        async def set_request_id(request, call_next):
+            # NOTE: Setting request_id is not supported for websocket/realtime endpoints
+
+            if self.isolate_channel is None:
+                grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")
+                self.isolate_channel = await open_isolate_channel(
+                    f"localhost:{grpc_port}"
+                )
+
+            request_id = request.headers.get(REQUEST_ID_KEY)
+            if request_id is None:
+                # Cut it short
+                return await call_next(request)
+
+            await _set_logger_labels(
+                {"fal_request_id": request_id}, channel=self.isolate_channel
+            )
+
+            async def _unset_at_end():
+                await _set_logger_labels({}, channel=self.isolate_channel)  # type: ignore
+
+            try:
+                response: fastapi.responses.Response = await call_next(request)
+            except BaseException:
+                await _unset_at_end()
+                raise
+            else:
+                # We need to wait for the entire response to be sent before
+                # we can set the logger labels back to the default.
+                background_tasks = fastapi.BackgroundTasks()
+                background_tasks.add_task(_unset_at_end)
+                if response.background:
+                    # We normally have no background tasks, but we should handle it
+                    background_tasks.add_task(response.background)
+                response.background = background_tasks
+
+                return response

        @app.exception_handler(RequestCancelledException)
        async def value_error_exception_handler(
@@ -284,7 +421,7 @@ class App(fal.api.BaseServable):
            # the connection without receiving a response
            return JSONResponse({"detail": str(exc)}, 499)

-    def _add_extra_routes(self, app: FastAPI):
+    def _add_extra_routes(self, app: fastapi.FastAPI):
        @app.get("/health")
        def health():
            return self.health()
@@ -395,7 +532,10 @@ def _fal_websocket_template(
            batch.append(next_input)

            t0 = loop.time()
-
+            if inspect.iscoroutinefunction(func):
+                output = await func(self, *batch)
+            else:
+                output = await loop.run_in_executor(None, func, self, *batch)  # type: ignore
            total_time = loop.time() - t0
            if not isinstance(output, dict):
                # Handle pydantic output modal