fal 0.12.2__py3-none-any.whl → 0.12.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

Files changed (46)
  1. fal/__init__.py +11 -2
  2. fal/api.py +130 -50
  3. fal/app.py +81 -134
  4. fal/apps.py +24 -6
  5. fal/auth/__init__.py +14 -2
  6. fal/auth/auth0.py +34 -25
  7. fal/cli.py +9 -4
  8. fal/env.py +0 -4
  9. fal/flags.py +1 -0
  10. fal/logging/__init__.py +0 -2
  11. fal/logging/trace.py +8 -1
  12. fal/sdk.py +33 -6
  13. fal/toolkit/__init__.py +16 -0
  14. fal/workflows.py +481 -0
  15. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/METADATA +4 -7
  16. fal-0.12.4.dist-info/RECORD +88 -0
  17. openapi_fal_rest/__init__.py +1 -0
  18. openapi_fal_rest/api/workflows/__init__.py +0 -0
  19. openapi_fal_rest/api/workflows/create_or_update_workflow_workflows_post.py +172 -0
  20. openapi_fal_rest/api/workflows/delete_workflow_workflows_user_id_workflow_name_delete.py +175 -0
  21. openapi_fal_rest/api/workflows/execute_workflow_workflows_user_id_workflow_name_post.py +268 -0
  22. openapi_fal_rest/api/workflows/get_workflow_workflows_user_id_workflow_name_get.py +181 -0
  23. openapi_fal_rest/api/workflows/get_workflows_workflows_get.py +189 -0
  24. openapi_fal_rest/models/__init__.py +34 -0
  25. openapi_fal_rest/models/app_metadata_response_app_metadata.py +1 -0
  26. openapi_fal_rest/models/customer_details.py +15 -14
  27. openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_json_body_type_0.py +44 -0
  28. openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_response_200_type_0.py +44 -0
  29. openapi_fal_rest/models/page_workflow_item.py +107 -0
  30. openapi_fal_rest/models/typed_workflow.py +85 -0
  31. openapi_fal_rest/models/workflow_contents.py +98 -0
  32. openapi_fal_rest/models/workflow_contents_nodes.py +59 -0
  33. openapi_fal_rest/models/workflow_contents_output.py +44 -0
  34. openapi_fal_rest/models/workflow_detail.py +149 -0
  35. openapi_fal_rest/models/workflow_detail_contents_type_0.py +44 -0
  36. openapi_fal_rest/models/workflow_item.py +80 -0
  37. openapi_fal_rest/models/workflow_node.py +74 -0
  38. openapi_fal_rest/models/workflow_node_type.py +9 -0
  39. openapi_fal_rest/models/workflow_schema.py +73 -0
  40. openapi_fal_rest/models/workflow_schema_input.py +44 -0
  41. openapi_fal_rest/models/workflow_schema_output.py +44 -0
  42. openapi_fal_rest/types.py +1 -0
  43. fal/logging/datadog.py +0 -78
  44. fal-0.12.2.dist-info/RECORD +0 -67
  45. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/WHEEL +0 -0
  46. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/entry_points.txt +0 -0
fal/auth/__init__.py CHANGED
@@ -62,7 +62,7 @@ def _fetch_access_token() -> str:
62
62
 
63
63
  if access_token is not None:
64
64
  try:
65
- auth0.validate_access_token(access_token)
65
+ auth0.verify_access_token_expiration(access_token)
66
66
  return access_token
67
67
  except:
68
68
  # access_token expired, will refresh
@@ -85,10 +85,12 @@ def _fetch_access_token() -> str:
85
85
  class UserAccess:
86
86
  _access_token: str | None = field(repr=False, default=None)
87
87
  _user_info: dict | None = field(repr=False, default=None)
88
+ _exc: Exception | None = field(repr=False, default=None)
88
89
 
89
90
  def invalidate(self) -> None:
90
91
  self._access_token = None
91
92
  self._user_info = None
93
+ self._exc = None
92
94
 
93
95
  @property
94
96
  def info(self) -> dict:
@@ -99,8 +101,18 @@ class UserAccess:
99
101
 
100
102
  @property
101
103
  def access_token(self) -> str:
104
+ if self._exc is not None:
105
+ # We access this several times, so we want to raise the
106
+ # original exception instead of the newer exceptions we
107
+ # would get from the effects of the original exception.
108
+ raise self._exc
109
+
102
110
  if self._access_token is None:
103
- self._access_token = _fetch_access_token()
111
+ try:
112
+ self._access_token = _fetch_access_token()
113
+ except Exception as e:
114
+ self._exc = e
115
+ raise
104
116
 
105
117
  return self._access_token
106
118
 
fal/auth/auth0.py CHANGED
@@ -1,14 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import functools
3
4
  import time
4
5
  import warnings
5
6
 
6
7
  import click
7
- import requests
8
- from auth0.authentication.token_verifier import (
9
- AsymmetricSignatureVerifier,
10
- TokenVerifier,
11
- )
8
+ import httpx
12
9
 
13
10
  from fal.console import console
14
11
  from fal.console.icons import CHECK_ICON
@@ -54,7 +51,7 @@ def login() -> dict:
54
51
  "client_id": AUTH0_CLIENT_ID,
55
52
  "scope": AUTH0_SCOPE,
56
53
  }
57
- device_code_response = requests.post(
54
+ device_code_response = httpx.post(
58
55
  f"https://{AUTH0_DOMAIN}/oauth/device/code", data=device_code_payload
59
56
  )
60
57
 
@@ -81,7 +78,7 @@ def login() -> dict:
81
78
 
82
79
  with console.status("Waiting for confirmation...") as status:
83
80
  while True:
84
- token_response = requests.post(
81
+ token_response = httpx.post(
85
82
  f"https://{AUTH0_DOMAIN}/oauth/token", data=token_payload
86
83
  )
87
84
 
@@ -109,14 +106,12 @@ def refresh(token: str) -> dict:
109
106
  "refresh_token": token,
110
107
  }
111
108
 
112
- token_response = requests.post(
109
+ token_response = httpx.post(
113
110
  f"https://{AUTH0_DOMAIN}/oauth/token", data=token_payload
114
111
  )
115
112
 
116
113
  token_data = token_response.json()
117
114
  if token_response.status_code == 200:
118
- # DEBUG: print("Authenticated!")
119
-
120
115
  validate_id_token(token_data["id_token"])
121
116
 
122
117
  return token_data
@@ -130,7 +125,7 @@ def revoke(token: str):
130
125
  "token": token,
131
126
  }
132
127
 
133
- token_response = requests.post(
128
+ token_response = httpx.post(
134
129
  f"https://{AUTH0_DOMAIN}/oauth/revoke", data=token_payload
135
130
  )
136
131
 
@@ -142,7 +137,7 @@ def revoke(token: str):
142
137
 
143
138
 
144
139
  def get_user_info(bearer_token: str) -> dict:
145
- userinfo_response = requests.post(
140
+ userinfo_response = httpx.post(
146
141
  f"https://{AUTH0_DOMAIN}/userinfo",
147
142
  headers={"Authorization": bearer_token},
148
143
  )
@@ -153,30 +148,44 @@ def get_user_info(bearer_token: str) -> dict:
153
148
  return userinfo_response.json()
154
149
 
155
150
 
151
+ @functools.lru_cache
152
+ def build_jwk_client():
153
+ from jwt import PyJWKClient
154
+
155
+ return PyJWKClient(AUTH0_JWKS_URL, cache_keys=True)
156
+
157
+
156
158
  def validate_id_token(token: str):
157
159
  """
158
- Verify the token and its precedence.
159
- `id_token`s are intended for the client (this sdk) only.
160
- Never send one to another service.
161
-
162
- :param id_token:
160
+ id_token is intended for the client (this sdk) only. Never send one to another service.
163
161
  """
164
- sv = AsymmetricSignatureVerifier(AUTH0_JWKS_URL)
165
- tv = TokenVerifier(
166
- signature_verifier=sv,
162
+ from jwt import decode
163
+
164
+ jwk_client = build_jwk_client()
165
+
166
+ decode(
167
+ token,
168
+ key=jwk_client.get_signing_key_from_jwt(token).key,
169
+ algorithms=AUTH0_ALGORITHMS,
167
170
  issuer=AUTH0_ISSUER,
168
171
  audience=AUTH0_CLIENT_ID,
172
+ leeway=60, # 1 minute, to account for clock skew
173
+ options={
174
+ "verify_signature": True,
175
+ "verify_exp": True,
176
+ "verify_iat": True,
177
+ "verify_aud": True,
178
+ "verify_iss": True,
179
+ },
169
180
  )
170
- tv.verify(token)
171
-
172
181
 
173
- def validate_access_token(token: str):
174
- from datetime import timedelta
175
182
 
183
+ def verify_access_token_expiration(token: str):
176
184
  from jwt import decode
177
185
 
186
+ leeway = 60 * 30 * 60 # 30 minutes
178
187
  decode(
179
188
  token,
180
- leeway=timedelta(minutes=-30), # Mark as expired some time before it expires
189
+ leeway=-leeway, # negative to consider expired before actual expiration
181
190
  options={"verify_exp": True, "verify_signature": False},
182
191
  )
fal/cli.py CHANGED
@@ -29,7 +29,7 @@ PORT_ENVVAR = "FAL_PORT"
29
29
  DEBUG_ENABLED = False
30
30
 
31
31
 
32
- log = get_logger(__name__)
32
+ logger = get_logger(__name__)
33
33
 
34
34
 
35
35
  class ExecutionInfo:
@@ -63,13 +63,13 @@ class MainGroup(click.Group):
63
63
  qualified_name, attributes={"invocation_id": invocation_id}
64
64
  ):
65
65
  try:
66
- log.debug(
66
+ logger.debug(
67
67
  f"Executing command: {qualified_name}",
68
68
  command=qualified_name,
69
69
  )
70
70
  return super().invoke(ctx)
71
71
  except Exception as exception:
72
- log.error(exception)
72
+ logger.error(exception)
73
73
  if execution_info.debug:
74
74
  # Here we supress detailed errors on click lines because
75
75
  # they're mostly decorator calls, irrelevant to the dev's error tracing
@@ -123,7 +123,10 @@ class AliasCommand(click.Command):
123
123
  return self._wrapped.__getattribute__(__name)
124
124
 
125
125
 
126
- @click.group(cls=MainGroup)
126
+ @click.group(
127
+ cls=MainGroup,
128
+ context_settings={"allow_interspersed_args": True},
129
+ )
127
130
  @click.option(
128
131
  "--debug", is_flag=True, help="Enable detailed errors and verbose logging."
129
132
  )
@@ -468,6 +471,7 @@ def alias_list_runners(
468
471
  table.add_column("Runner ID")
469
472
  table.add_column("In Flight Requests")
470
473
  table.add_column("Expires in")
474
+ table.add_column("Uptime")
471
475
 
472
476
  for runner in runners:
473
477
  table.add_row(
@@ -478,6 +482,7 @@ def alias_list_runners(
478
482
  if not runner.expiration_countdown
479
483
  else f"{runner.expiration_countdown}s"
480
484
  ),
485
+ f"{runner.uptime} ({runner.uptime.total_seconds()}s)",
481
486
  )
482
487
 
483
488
  console.print(table)
fal/env.py CHANGED
@@ -1,7 +1,3 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  CLI_ENV = "prod"
4
-
5
- DATADOG_API_KEY = "pub4cd6a1c4763c93ad5af2740b2d931145"
6
- DATADOG_APP_KEY = "4981bae640864a409dcfaddd69c2a157523b1585"
7
-
fal/flags.py CHANGED
@@ -30,3 +30,4 @@ FAL_RUN_HOST = (
30
30
  )
31
31
 
32
32
  FORCE_SETUP = bool_envvar("FAL_FORCE_SETUP")
33
+ DONT_OPEN_LINKS = bool_envvar("FAL_DONT_OPEN_LINKS")
fal/logging/__init__.py CHANGED
@@ -5,7 +5,6 @@ from typing import Any
5
5
  import structlog
6
6
  from structlog.typing import EventDict, WrappedLogger
7
7
 
8
- from .datadog import submit_to_datadog
9
8
  from .style import LEVEL_STYLES
10
9
  from .user import add_user_id
11
10
 
@@ -45,7 +44,6 @@ structlog.configure(
45
44
  structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S"),
46
45
  structlog.processors.StackInfoRenderer(),
47
46
  add_user_id,
48
- submit_to_datadog,
49
47
  _console_log_output,
50
48
  ],
51
49
  wrapper_class=structlog.stdlib.BoundLogger,
fal/logging/trace.py CHANGED
@@ -7,6 +7,10 @@ from grpc_interceptor import ClientCallDetails, ClientInterceptor
7
7
  from opentelemetry import trace
8
8
  from opentelemetry.sdk.trace import TracerProvider
9
9
 
10
+ from fal.logging import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
10
14
  provider = TracerProvider()
11
15
  # The line below can be used in dev to inspect opentelemetry result
12
16
  # It must be imported from opentelemetry.sdk.trace.export
@@ -41,6 +45,7 @@ class TraceContextInterceptor(ClientInterceptor):
41
45
  call_details: ClientCallDetails,
42
46
  ):
43
47
  current_span = get_current_span_context()
48
+
44
49
  if current_span is not None:
45
50
  new_details = call_details._replace(
46
51
  metadata=(
@@ -50,5 +55,7 @@ class TraceContextInterceptor(ClientInterceptor):
50
55
  ("x-fal-invocation-id", current_span.invocation_id),
51
56
  )
52
57
  )
53
- return method(request_or_iterator, new_details)
58
+ call_details = new_details
59
+
60
+ logger.debug("Calling %s", call_details)
54
61
  return method(request_or_iterator, call_details)
fal/sdk.py CHANGED
@@ -29,7 +29,7 @@ FAL_SERVERLESS_DEFAULT_KEEP_ALIVE = 10
29
29
  FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING = 1
30
30
  FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY = 0
31
31
 
32
- log = get_logger(__name__)
32
+ logger = get_logger(__name__)
33
33
 
34
34
  patch_dill()
35
35
 
@@ -39,8 +39,29 @@ class ServerCredentials:
39
39
  raise NotImplementedError
40
40
 
41
41
  @property
42
- def extra_options(self) -> list[tuple[str, str]]:
43
- return GRPC_OPTIONS
42
+ def base_options(self) -> dict[str, str | int]:
43
+ import json
44
+
45
+ grpc_ops: dict[str, str | int] = dict(GRPC_OPTIONS)
46
+ grpc_ops["grpc.enable_retries"] = 1
47
+ grpc_ops["grpc.service_config"] = json.dumps(
48
+ {
49
+ "methodConfig": [
50
+ {
51
+ "name": [{}],
52
+ "retryPolicy": {
53
+ "maxAttempts": 5,
54
+ "initialBackoff": "0.1s",
55
+ "maxBackoff": "5s",
56
+ "backoffMultiplier": 2,
57
+ "retryableStatusCodes": ["UNAVAILABLE"],
58
+ },
59
+ }
60
+ ]
61
+ }
62
+ )
63
+
64
+ return grpc_ops
44
65
 
45
66
 
46
67
  class LocalCredentials(ServerCredentials):
@@ -140,7 +161,7 @@ def get_default_credentials() -> Credentials:
140
161
 
141
162
  key_creds = key_credentials()
142
163
  if key_creds:
143
- log.debug("Using key credentials")
164
+ logger.debug("Using key credentials")
144
165
  return FalServerlessKeyCredentials(key_creds[0], key_creds[1])
145
166
  else:
146
167
  return AuthenticatedCredentials()
@@ -183,6 +204,7 @@ class RunnerInfo:
183
204
  runner_id: str
184
205
  in_flight_requests: int
185
206
  expiration_countdown: int
207
+ uptime: timedelta
186
208
 
187
209
 
188
210
  @dataclass
@@ -270,6 +292,7 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
270
292
  runner_id=message.runner_id,
271
293
  in_flight_requests=message.in_flight_requests,
272
294
  expiration_countdown=message.expiration_countdown,
295
+ uptime=timedelta(seconds=message.uptime),
273
296
  )
274
297
 
275
298
 
@@ -346,10 +369,14 @@ class FalServerlessConnection:
346
369
  if self._stub:
347
370
  return self._stub
348
371
 
349
- options = self.credentials.server_credentials.extra_options
372
+ options = self.credentials.server_credentials.base_options
350
373
  channel_creds = self.credentials.to_grpc()
351
374
  channel = self._stack.enter_context(
352
- grpc.secure_channel(self.hostname, channel_creds, options)
375
+ grpc.secure_channel(
376
+ target=self.hostname,
377
+ credentials=channel_creds,
378
+ options=list(options.items()),
379
+ )
353
380
  )
354
381
  channel = grpc.intercept_channel(channel, TraceContextInterceptor())
355
382
  self._stub = isolate_proto.IsolateControllerStub(channel)
fal/toolkit/__init__.py CHANGED
@@ -12,3 +12,19 @@ from fal.toolkit.utils import (
12
12
  download_file,
13
13
  download_model_weights,
14
14
  )
15
+
16
+ __all__ = [
17
+ "CompressedFile",
18
+ "File",
19
+ "Image",
20
+ "ImageSizeInput",
21
+ "get_image_size",
22
+ "mainify",
23
+ "optimize",
24
+ "FAL_MODEL_WEIGHTS_DIR",
25
+ "FAL_PERSISTENT_DIR",
26
+ "FAL_REPOSITORY_DIR",
27
+ "clone_repository",
28
+ "download_file",
29
+ "download_model_weights",
30
+ ]