fal 1.49.1__py3-none-any.whl → 1.57.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fal/cli/runners.py CHANGED
@@ -1,20 +1,33 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import argparse
4
+ import fcntl
3
5
  import json
6
+ import os
7
+ import signal
8
+ import struct
9
+ import sys
10
+ import termios
11
+ import tty
4
12
  from collections import deque
5
13
  from dataclasses import dataclass
6
14
  from datetime import datetime, timedelta, timezone
7
15
  from http import HTTPStatus
8
- from typing import Iterator, List
16
+ from queue import Empty, Queue
17
+ from threading import Thread
18
+ from typing import TYPE_CHECKING, Iterator, List
9
19
 
20
+ if TYPE_CHECKING:
21
+ from openapi_fal_rest.client import Client
22
+
23
+ import grpc
10
24
  import httpx
11
25
  from httpx_sse import connect_sse
12
26
  from rich.console import Console
13
27
  from structlog.typing import EventDict
14
28
 
15
29
  from fal.api.client import SyncServerlessClient
16
- from fal.rest_client import REST_CLIENT
17
- from fal.sdk import FalServerlessClient, RunnerInfo, RunnerState
30
+ from fal.sdk import RunnerInfo, RunnerState
18
31
 
19
32
  from .parser import FalClientParser, SinceAction, get_output_parser
20
33
 
@@ -95,12 +108,111 @@ def runners_requests_table(runners: list[RunnerInfo]):
95
108
  return table
96
109
 
97
110
 
111
def _get_tty_size(fd=0):
    """Return the terminal dimensions of *fd* as ``(rows, columns)``.

    Uses the ``TIOCGWINSZ`` ioctl on *fd* (stdin by default). Falls back to
    the classic 24x80 when *fd* is not a terminal, or when the terminal
    reports a zero dimension (e.g. a pty whose size was never set). In that
    case forwarding 0x0 to the remote shell would leave it with no usable
    area.
    """
    try:
        # struct winsize begins with ws_row, ws_col (unsigned shorts).
        raw = fcntl.ioctl(fd, termios.TIOCGWINSZ, b"\0" * 4)
        h, w = struct.unpack("HH", raw)[:2]
    except (OSError, ValueError):
        return 24, 80  # Fallback to standard size
    if h == 0 or w == 0:
        return 24, 80
    return h, w
118
+
119
+
120
def _shell(args):
    """Attach the local terminal to an interactive shell inside a runner.

    Opens a bidirectional ``ShellRunner`` gRPC stream. The request side
    sends, in order:

    - the runner id;
    - the current terminal size;
    - raw stdin bytes, read on a background thread;
    - a fresh terminal size on every SIGWINCH.

    Response ``data`` is written straight to stdout.

    The local terminal is put into raw mode for the session. Afterwards the
    saved terminal attributes and the previous SIGWINCH handler are restored
    before any error is reported.

    Returns the exit code reported by the server. Returns 1 if the stream
    ended without one or failed.
    """
    import isolate_proto

    client = SyncServerlessClient(host=args.host, team=args.team)
    stub = client._create_host()._connection.stub
    runner_id = args.id

    fd = sys.stdin.fileno()
    # Saved before entering raw mode so the terminal can always be restored.
    old_settings = termios.tcgetattr(fd)

    # Stdin chunks and resize events, consumed by the request generator.
    messages = Queue()  # type: ignore
    stop_flag = False

    def tty_size_input():
        """Build a request message carrying the current terminal size."""
        msg = isolate_proto.ShellRunnerInput()
        h, w = _get_tty_size()
        msg.tty_size.height = h
        msg.tty_size.width = w
        return msg

    def handle_resize(*_):
        messages.put(("resize", None))

    def read_stdin():
        """Read stdin in a background thread until EOF, error or stop."""
        while not stop_flag:
            try:
                data = os.read(fd, 4096)
            except OSError:
                break
            if not data:
                break  # EOF on stdin
            messages.put(("data", data))

    def stream_inputs():
        """Generate the request stream for ``ShellRunner``."""
        yield isolate_proto.ShellRunnerInput(runner_id=runner_id)
        yield tty_size_input()
        # Finish once the session is over; otherwise this generator never
        # ends and keeps polling the queue after the shell has exited.
        while not stop_flag:
            try:
                msg_type, data = messages.get(timeout=0.1)
            except Empty:
                continue
            if msg_type == "data":
                yield isolate_proto.ShellRunnerInput(data=data)
            elif msg_type == "resize":
                yield tty_size_input()

    exit_code = 1
    error_message = None
    previous_sigwinch = signal.signal(signal.SIGWINCH, handle_resize)
    try:
        # Inside the try so any failure below still restores the terminal.
        tty.setraw(fd)
        Thread(target=read_stdin, daemon=True).start()
        for output in stub.ShellRunner(stream_inputs()):
            if output.HasField("exit_code"):
                exit_code = output.exit_code
                break
            if output.data:
                sys.stdout.buffer.write(output.data)
                sys.stdout.buffer.flush()
            if output.close:
                break
        exit_code = exit_code or 0
    except grpc.RpcError as exc:
        error_message = f"\n[red]Connection error:[/] {exc.details()}"
    except Exception as exc:
        error_message = f"\n[red]Error:[/] {exc}"
    finally:
        stop_flag = True
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
        # signal.signal returns None when the old handler was not installed
        # from Python; it cannot be passed back in that case.
        if previous_sigwinch is not None:
            signal.signal(signal.SIGWINCH, previous_sigwinch)

    # Print only after leaving raw mode: raw mode disables output
    # post-processing, so "\n" would not return the cursor to column 0.
    if error_message is not None:
        args.console.print(error_message)
    return exit_code
206
+
207
+
208
def _stop(args):
    """Ask the serverless client to gracefully stop runner ``args.id``."""
    serverless = SyncServerlessClient(host=args.host, team=args.team)
    serverless.runners.stop(args.id)
211
+
212
+
98
213
  def _kill(args):
99
214
  client = SyncServerlessClient(host=args.host, team=args.team)
100
- with FalServerlessClient(
101
- client._grpc_host, client._credentials
102
- ).connect() as connection:
103
- connection.kill_runner(args.id)
215
+ client.runners.kill(args.id)
104
216
 
105
217
 
106
218
  def _list_json(args, runners: list[RunnerInfo]):
@@ -131,13 +243,24 @@ def _list(args):
131
243
  if args.state:
132
244
  states = set(args.state)
133
245
  if "all" not in states:
134
- runners = [r for r in runners if r.state.value in states]
246
+ runners = [
247
+ r
248
+ for r in runners
249
+ if r.state.value.lower() in states
250
+ or (
251
+ "terminated" in states and r.state.value.lower() == "dead"
252
+ ) # TODO for backwards compatibility. remove later
253
+ ]
135
254
 
136
255
  pending_runners = [
137
256
  runner for runner in runners if runner.state == RunnerState.PENDING
138
257
  ]
139
258
  setup_runners = [runner for runner in runners if runner.state == RunnerState.SETUP]
140
- dead_runners = [runner for runner in runners if runner.state == RunnerState.DEAD]
259
+ terminated_runners = [
260
+ runner
261
+ for runner in runners
262
+ if runner.state == RunnerState.DEAD or runner.state == RunnerState.TERMINATED
263
+ ]
141
264
  if args.output == "pretty":
142
265
  args.console.print(
143
266
  "Runners: "
@@ -145,7 +268,7 @@ def _list(args):
145
268
  len(runners)
146
269
  - len(pending_runners)
147
270
  - len(setup_runners)
148
- - len(dead_runners)
271
+ - len(terminated_runners)
149
272
  )
150
273
  )
151
274
  args.console.print(f"Runners Pending: {len(pending_runners)}")
@@ -161,6 +284,21 @@ def _list(args):
161
284
  raise AssertionError(f"Invalid output format: {args.output}")
162
285
 
163
286
 
287
def _add_stop_parser(subparsers, parents):
    """Register the ``stop`` subcommand, which dispatches to ``_stop``."""
    description = "Stop a runner gracefully."
    stop_parser = subparsers.add_parser(
        "stop", description=description, help=description, parents=parents
    )
    stop_parser.add_argument("id", help="Runner ID.")
    stop_parser.set_defaults(func=_stop)
300
+
301
+
164
302
  def _add_kill_parser(subparsers, parents):
165
303
  kill_help = "Kill a runner."
166
304
  parser = subparsers.add_parser(
@@ -190,14 +328,14 @@ def _add_list_parser(subparsers, parents):
190
328
  action=SinceAction,
191
329
  limit="1 day",
192
330
  help=(
193
- "Show dead runners since the given time. "
331
+ "Show terminated runners since the given time. "
194
332
  "Accepts 'now', relative like '30m', '1h', '1d', "
195
333
  "or an ISO timestamp. Max 24 hours."
196
334
  ),
197
335
  )
198
336
  parser.add_argument(
199
337
  "--state",
200
- choices=["all", "running", "pending", "setup", "dead"],
338
+ choices=["all", "running", "pending", "setup", "terminated"],
201
339
  nargs="+",
202
340
  default=None,
203
341
  help=("Filter by runner state(s). Choose one or more, or 'all'(default)."),
@@ -256,10 +394,10 @@ class RestRunnerInfo:
256
394
  ended_at: datetime | None
257
395
 
258
396
 
259
- def _get_runner_info(runner_id: str) -> RestRunnerInfo:
260
- headers = REST_CLIENT.get_headers()
397
+ def _get_runner_info(rest_client: Client, runner_id: str) -> RestRunnerInfo:
398
+ headers = rest_client.get_headers()
261
399
  with httpx.Client(
262
- base_url=REST_CLIENT.base_url, headers=headers, timeout=30
400
+ base_url=rest_client.base_url, headers=headers, timeout=30
263
401
  ) as client:
264
402
  resp = client.get(f"/runners/{runner_id}")
265
403
  if resp.status_code == HTTPStatus.NOT_FOUND:
@@ -293,16 +431,19 @@ def _get_runner_info(runner_id: str) -> RestRunnerInfo:
293
431
 
294
432
 
295
433
  def _stream_logs(
296
- base_params: dict[str, str], since: datetime | None, until: datetime | None
434
+ rest_client: Client,
435
+ base_params: dict[str, str],
436
+ since: datetime | None,
437
+ until: datetime | None,
297
438
  ) -> Iterator[dict]:
298
- headers = REST_CLIENT.get_headers()
439
+ headers = rest_client.get_headers()
299
440
  params: dict[str, str] = base_params.copy()
300
441
  if since is not None:
301
442
  params["since"] = _to_iso_naive(since)
302
443
  if until is not None:
303
444
  params["until"] = _to_iso_naive(until)
304
445
  with httpx.Client(
305
- base_url=REST_CLIENT.base_url,
446
+ base_url=rest_client.base_url,
306
447
  headers=headers,
307
448
  timeout=None,
308
449
  follow_redirects=True,
@@ -329,11 +470,14 @@ DEFAULT_PAGE_SIZE = 1000
329
470
 
330
471
 
331
472
  def _iter_logs(
332
- base_params: dict[str, str], start: datetime | None, end: datetime | None
473
+ rest_client: Client,
474
+ base_params: dict[str, str],
475
+ start: datetime | None,
476
+ end: datetime | None,
333
477
  ) -> Iterator[dict]:
334
- headers = REST_CLIENT.get_headers()
478
+ headers = rest_client.get_headers()
335
479
  with httpx.Client(
336
- base_url=REST_CLIENT.base_url,
480
+ base_url=rest_client.base_url,
337
481
  headers=headers,
338
482
  timeout=300,
339
483
  follow_redirects=True,
@@ -356,6 +500,7 @@ def _iter_logs(
356
500
 
357
501
 
358
502
  def _get_logs(
503
+ rest_client: Client,
359
504
  params: dict[str, str],
360
505
  since: datetime | None,
361
506
  until: datetime | None,
@@ -364,12 +509,12 @@ def _get_logs(
364
509
  oldest: bool = False,
365
510
  ) -> Iterator[dict]:
366
511
  if lines_count is None:
367
- yield from _iter_logs(params, since, until)
512
+ yield from _iter_logs(rest_client, params, since, until)
368
513
  return
369
514
 
370
515
  if oldest:
371
516
  produced = 0
372
- for log in _iter_logs(params, since, until):
517
+ for log in _iter_logs(rest_client, params, since, until):
373
518
  if produced >= lines_count:
374
519
  break
375
520
  produced += 1
@@ -378,7 +523,7 @@ def _get_logs(
378
523
 
379
524
  # newest tail: collect into a fixed-size deque, then yield
380
525
  tail: deque[dict] = deque(maxlen=lines_count)
381
- for log in _iter_logs(params, since, until):
526
+ for log in _iter_logs(rest_client, params, since, until):
382
527
  tail.append(log)
383
528
  for log in tail:
384
529
  yield log
@@ -421,7 +566,9 @@ def _logs(args):
421
566
  if args.search is not None:
422
567
  params["search"] = args.search
423
568
 
424
- runner_info = _get_runner_info(args.id)
569
+ client = SyncServerlessClient(host=args.host, team=args.team)
570
+ rest_client = client._create_rest_client()
571
+ runner_info = _get_runner_info(rest_client, args.id)
425
572
  follow: bool = args.follow
426
573
  since = args.since
427
574
  if follow:
@@ -467,9 +614,11 @@ def _logs(args):
467
614
  args.parser.error("Invalid -n|--lines value. Use an integer or +integer.")
468
615
 
469
616
  if follow:
470
- logs_gen = _stream_logs(params, since, until)
617
+ logs_gen = _stream_logs(rest_client, params, since, until)
471
618
  else:
472
- logs_gen = _get_logs(params, since, until, lines_count, oldest=lines_oldest)
619
+ logs_gen = _get_logs(
620
+ rest_client, params, since, until, lines_count, oldest=lines_oldest
621
+ )
473
622
 
474
623
  printer = LogPrinter(args.console)
475
624
 
@@ -546,6 +695,17 @@ def _add_logs_parser(subparsers, parents):
546
695
  parser.set_defaults(func=_logs)
547
696
 
548
697
 
698
def _add_shell_parser(subparsers, parents):
    """Register the ``shell`` subcommand; its help entry is suppressed."""
    shell_parser = subparsers.add_parser(
        "shell", help=argparse.SUPPRESS, parents=parents
    )
    shell_parser.add_argument("id", help="Runner ID.")
    shell_parser.set_defaults(func=_shell)
707
+
708
+
549
709
  def add_parser(main_subparsers, parents):
550
710
  runners_help = "Manage fal runners."
551
711
  parser = main_subparsers.add_parser(
@@ -563,6 +723,8 @@ def add_parser(main_subparsers, parents):
563
723
  parser_class=FalClientParser,
564
724
  )
565
725
 
726
+ _add_stop_parser(subparsers, parents)
566
727
  _add_kill_parser(subparsers, parents)
567
728
  _add_list_parser(subparsers, parents)
568
729
  _add_logs_parser(subparsers, parents)
730
+ _add_shell_parser(subparsers, parents)
fal/cli/secrets.py CHANGED
@@ -1,12 +1,12 @@
1
- from ._utils import get_client
1
+ from fal.api.client import SyncServerlessClient
2
+
2
3
  from .parser import DictAction, FalClientParser
3
4
 
4
5
 
5
6
  def _set(args):
6
- client = get_client(args.host, args.team)
7
- with client.connect() as connection:
8
- for name, value in args.secrets.items():
9
- connection.set_secret(name, value)
7
+ client = SyncServerlessClient(host=args.host, team=args.team)
8
+ for name, value in args.secrets.items():
9
+ client.secrets.set(name, value)
10
10
 
11
11
 
12
12
  def _add_set_parser(subparsers, parents):
@@ -33,32 +33,31 @@ def _add_set_parser(subparsers, parents):
33
33
  def _list(args):
34
34
  import json
35
35
 
36
- client = get_client(args.host, args.team)
37
- with client.connect() as connection:
38
- secrets = list(connection.list_secrets())
36
+ client = SyncServerlessClient(host=args.host, team=args.team)
37
+ secrets = client.secrets.list()
39
38
 
40
- if args.output == "json":
41
- json_secrets = [
42
- {
43
- "name": secret.name,
44
- "created_at": str(secret.created_at),
45
- }
46
- for secret in secrets
47
- ]
48
- args.console.print(json.dumps({"secrets": json_secrets}))
49
- elif args.output == "pretty":
50
- from rich.table import Table
39
+ if args.output == "json":
40
+ json_secrets = [
41
+ {
42
+ "name": secret.name,
43
+ "created_at": str(secret.created_at),
44
+ }
45
+ for secret in secrets
46
+ ]
47
+ args.console.print(json.dumps({"secrets": json_secrets}))
48
+ elif args.output == "pretty":
49
+ from rich.table import Table
51
50
 
52
- table = Table()
53
- table.add_column("Secret Name")
54
- table.add_column("Created At")
51
+ table = Table()
52
+ table.add_column("Secret Name")
53
+ table.add_column("Created At")
55
54
 
56
- for secret in secrets:
57
- table.add_row(secret.name, str(secret.created_at))
55
+ for secret in secrets:
56
+ table.add_row(secret.name, str(secret.created_at))
58
57
 
59
- args.console.print(table)
60
- else:
61
- raise AssertionError(f"Invalid output format: {args.output}")
58
+ args.console.print(table)
59
+ else:
60
+ raise AssertionError(f"Invalid output format: {args.output}")
62
61
 
63
62
 
64
63
  def _add_list_parser(subparsers, parents):
@@ -75,9 +74,8 @@ def _add_list_parser(subparsers, parents):
75
74
 
76
75
 
77
76
  def _unset(args):
78
- client = get_client(args.host, args.team)
79
- with client.connect() as connection:
80
- connection.delete_secret(args.secret)
77
+ client = SyncServerlessClient(host=args.host, team=args.team)
78
+ client.secrets.unset(args.secret)
81
79
 
82
80
 
83
81
  def _add_unset_parser(subparsers, parents):
fal/files.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  import posixpath
5
5
  from concurrent.futures import ThreadPoolExecutor
6
6
  from functools import cached_property
7
- from typing import TYPE_CHECKING
7
+ from typing import TYPE_CHECKING, Optional
8
8
 
9
9
  from fsspec import AbstractFileSystem
10
10
 
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
14
14
  USER_AGENT = "fal-sdk/1.14.0 (python)"
15
15
  MULTIPART_THRESHOLD = 10 * 1024 * 1024 # 10MB
16
16
  MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB
17
- MULTIPART_WORKERS = 2 # only 2 because our REST is currently struggling with more
17
+ MULTIPART_WORKERS = 10
18
18
 
19
19
 
20
20
  def _compute_md5(lpath, chunk_size=8192):
@@ -28,20 +28,43 @@ def _compute_md5(lpath, chunk_size=8192):
28
28
 
29
29
 
30
30
  class FalFileSystem(AbstractFileSystem):
31
def __init__(
    self,
    *,
    host: Optional[str] = None,
    team: Optional[str] = None,
    profile: Optional[str] = None,
    **kwargs,
):
    """Create a filesystem bound to an optional fal host, team and profile.

    The three settings are stored on the instance. All other keyword
    arguments go to ``AbstractFileSystem.__init__``.
    """
    self.host, self.team, self.profile = host, team, profile
    super().__init__(**kwargs)
43
+
31
44
  @cached_property
32
45
  def _client(self) -> "httpx.Client":
33
- from httpx import Client
46
+ from httpx import Client, Timeout
47
+
48
+ from fal.api.client import SyncServerlessClient
34
49
 
35
- from fal.flags import REST_URL
36
- from fal.sdk import get_default_credentials
50
+ client = SyncServerlessClient(
51
+ host=self.host,
52
+ team=self.team,
53
+ profile=self.profile,
54
+ )
37
55
 
38
- creds = get_default_credentials()
39
56
  return Client(
40
- base_url=REST_URL,
57
+ base_url=client._rest_url,
41
58
  headers={
42
- **creds.to_headers(),
59
+ **client._credentials.to_headers(),
43
60
  "User-Agent": USER_AGENT,
44
61
  },
62
+ timeout=Timeout(
63
+ connect=30,
64
+ read=4 * 60, # multipart complete can take time
65
+ write=5 * 60, # we could be uploading slowly
66
+ pool=30,
67
+ ),
45
68
  )
46
69
 
47
70
  def _request(self, method, path, **kwargs):
@@ -224,6 +247,7 @@ class FalFileSystem(AbstractFileSystem):
224
247
  "POST",
225
248
  f"/files/file/url/{abs_rpath}",
226
249
  json={"url": url},
250
+ timeout=10 * 60, # 10 minutes in seconds
227
251
  )
228
252
  self.dircache.clear()
229
253
 
fal/logging/__init__.py CHANGED
@@ -7,11 +7,6 @@ from structlog.typing import EventDict, WrappedLogger
7
7
 
8
8
  from .style import LEVEL_STYLES
9
9
 
10
- # Unfortunately structlog console processor does not support
11
- # more general theming as a public API. Consider a PR on the
12
- # structlog repo to add better support for it.
13
- structlog.dev._ColorfulStyles.bright = ""
14
-
15
10
 
16
11
  class DebugConsoleLogProcessor:
17
12
  """