fxn 0.0.42__tar.gz → 0.0.43__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {fxn-0.0.42 → fxn-0.0.43}/PKG-INFO +1 -1
  2. {fxn-0.0.42 → fxn-0.0.43}/fxn/__init__.py +2 -1
  3. {fxn-0.0.42 → fxn-0.0.43}/fxn/cli/__init__.py +8 -6
  4. fxn-0.0.43/fxn/cli/compile.py +141 -0
  5. {fxn-0.0.42 → fxn-0.0.43}/fxn/cli/predictions.py +16 -13
  6. fxn-0.0.43/fxn/cli/predictors.py +49 -0
  7. fxn-0.0.43/fxn/client.py +119 -0
  8. {fxn-0.0.42/fxn/compile → fxn-0.0.43/fxn}/compile.py +4 -8
  9. fxn-0.0.43/fxn/logging.py +137 -0
  10. {fxn-0.0.42/fxn/compile → fxn-0.0.43/fxn}/sandbox.py +62 -33
  11. {fxn-0.0.42 → fxn-0.0.43}/fxn/services/prediction.py +27 -28
  12. {fxn-0.0.42 → fxn-0.0.43}/fxn/types/predictor.py +1 -1
  13. {fxn-0.0.42 → fxn-0.0.43}/fxn/version.py +1 -1
  14. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/PKG-INFO +1 -1
  15. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/SOURCES.txt +4 -4
  16. fxn-0.0.42/fxn/cli/predictors.py +0 -18
  17. fxn-0.0.42/fxn/client.py +0 -58
  18. fxn-0.0.42/fxn/compile/__init__.py +0 -7
  19. fxn-0.0.42/fxn/compile/signature.py +0 -183
  20. {fxn-0.0.42 → fxn-0.0.43}/LICENSE +0 -0
  21. {fxn-0.0.42 → fxn-0.0.43}/README.md +0 -0
  22. {fxn-0.0.42 → fxn-0.0.43}/fxn/beta/__init__.py +0 -0
  23. {fxn-0.0.42 → fxn-0.0.43}/fxn/beta/client.py +0 -0
  24. {fxn-0.0.42 → fxn-0.0.43}/fxn/beta/prediction.py +0 -0
  25. {fxn-0.0.42 → fxn-0.0.43}/fxn/beta/remote.py +0 -0
  26. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/__init__.py +0 -0
  27. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/configuration.py +0 -0
  28. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/fxnc.py +0 -0
  29. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/map.py +0 -0
  30. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/prediction.py +0 -0
  31. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/predictor.py +0 -0
  32. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/stream.py +0 -0
  33. {fxn-0.0.42 → fxn-0.0.43}/fxn/c/value.py +0 -0
  34. {fxn-0.0.42 → fxn-0.0.43}/fxn/cli/auth.py +0 -0
  35. {fxn-0.0.42 → fxn-0.0.43}/fxn/cli/misc.py +0 -0
  36. {fxn-0.0.42 → fxn-0.0.43}/fxn/function.py +0 -0
  37. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/__init__.py +0 -0
  38. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/linux/arm64/libFunction.so +0 -0
  39. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/linux/x86_64/libFunction.so +0 -0
  40. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/macos/arm64/Function.dylib +0 -0
  41. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/macos/x86_64/Function.dylib +0 -0
  42. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/windows/arm64/Function.dll +0 -0
  43. {fxn-0.0.42 → fxn-0.0.43}/fxn/lib/windows/x86_64/Function.dll +0 -0
  44. {fxn-0.0.42 → fxn-0.0.43}/fxn/services/__init__.py +0 -0
  45. {fxn-0.0.42 → fxn-0.0.43}/fxn/services/predictor.py +0 -0
  46. {fxn-0.0.42 → fxn-0.0.43}/fxn/services/user.py +0 -0
  47. {fxn-0.0.42 → fxn-0.0.43}/fxn/types/__init__.py +0 -0
  48. {fxn-0.0.42 → fxn-0.0.43}/fxn/types/dtype.py +0 -0
  49. {fxn-0.0.42 → fxn-0.0.43}/fxn/types/prediction.py +0 -0
  50. {fxn-0.0.42 → fxn-0.0.43}/fxn/types/user.py +0 -0
  51. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/dependency_links.txt +0 -0
  52. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/entry_points.txt +0 -0
  53. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/requires.txt +0 -0
  54. {fxn-0.0.42 → fxn-0.0.43}/fxn.egg-info/top_level.txt +0 -0
  55. {fxn-0.0.42 → fxn-0.0.43}/pyproject.toml +0 -0
  56. {fxn-0.0.42 → fxn-0.0.43}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: fxn
3
- Version: 0.0.42
3
+ Version: 0.0.43
4
4
  Summary: Run prediction functions locally in Python. Register at https://fxn.ai.
5
5
  Author-email: "NatML Inc." <hi@fxn.ai>
6
6
  License: Apache License
@@ -4,7 +4,8 @@
4
4
  #
5
5
 
6
6
  from .client import FunctionAPIError
7
- from .compile import *
7
+ from .compile import compile
8
8
  from .function import Function
9
+ from .sandbox import Sandbox
9
10
  from .types import *
10
11
  from .version import *
@@ -6,10 +6,10 @@
6
6
  from typer import Typer
7
7
 
8
8
  from .auth import app as auth_app
9
- #from .compile import compile_predictor
9
+ from .compile import compile_predictor
10
10
  from .misc import cli_options
11
11
  from .predictions import create_prediction
12
- from .predictors import retrieve_predictor
12
+ from .predictors import archive_predictor, delete_predictor, retrieve_predictor
13
13
  from ..version import __version__
14
14
 
15
15
  # Define CLI
@@ -33,11 +33,13 @@ app.command(
33
33
  help="Make a prediction.",
34
34
  context_settings={ "allow_extra_args": True, "ignore_unknown_options": True }
35
35
  )(create_prediction)
36
- # app.command(
37
- # name="compile",
38
- # help="Create a predictor by compiling a Python function."
39
- # )(compile_predictor)
36
+ app.command(
37
+ name="compile",
38
+ help="Create a predictor by compiling a Python function."
39
+ )(compile_predictor)
40
40
  app.command(name="retrieve", help="Retrieve a predictor.")(retrieve_predictor)
41
+ app.command(name="archive", help="Archive a predictor.")(archive_predictor)
42
+ app.command(name="delete", help="Delete a predictor.")(delete_predictor)
41
43
 
42
44
  # Run
43
45
  if __name__ == "__main__":
@@ -0,0 +1,141 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from asyncio import run as run_async
7
+ from importlib.util import module_from_spec, spec_from_file_location
8
+ from inspect import getmembers, getmodulename, isfunction
9
+ from pathlib import Path
10
+ from pydantic import BaseModel
11
+ from re import sub
12
+ from rich import print as print_rich
13
+ from rich.progress import SpinnerColumn, TextColumn
14
+ import sys
15
+ from typer import Argument, Option
16
+ from typing import Callable, Literal
17
+ from urllib.parse import urlparse, urlunparse
18
+
19
+ from ..compile import PredictorSpec
20
+ from ..function import Function
21
+ from ..sandbox import EntrypointCommand
22
+ from ..logging import CustomProgress, CustomProgressTask
23
+ from .auth import get_access_key
24
+
25
def compile_predictor (
    path: str=Argument(..., help="Predictor path.")
):
    """
    CLI entry point that compiles the predictor defined at `path`.

    Parameters:
        path (str): Path to the Python file containing the prediction function.
    """
    # The actual work is asynchronous, so drive it to completion on a fresh event loop
    compile_task = _compile_predictor_async(path)
    run_async(compile_task)
29
+
30
+ async def _compile_predictor_async (path: str):
31
+ fxn = Function(get_access_key())
32
+ path: Path = Path(path).resolve()
33
+ with CustomProgress(
34
+ SpinnerColumn(spinner_name="dots", finished_text="[bold green]✔[/bold green]"),
35
+ TextColumn("[progress.description]{task.description}"),
36
+ ):
37
+ # Load
38
+ with CustomProgressTask(loading_text="Loading predictor...") as task:
39
+ func = _load_predictor_func(path)
40
+ entrypoint = EntrypointCommand(from_path=str(path), to_path="./", name=func.__name__)
41
+ spec: PredictorSpec = func.__predictor_spec
42
+ task.finish(f"Loaded prediction function: [bold cyan]{spec.tag}[/bold cyan]")
43
+ # Populate
44
+ sandbox = spec.sandbox
45
+ sandbox.commands.append(entrypoint)
46
+ with CustomProgressTask(loading_text="Uploading sandbox...", done_text="Uploaded sandbox"):
47
+ sandbox.populate(fxn=fxn)
48
+ # Compile
49
+ with CustomProgressTask(loading_text="Running codegen...", done_text="Completed codegen"):
50
+ with CustomProgressTask(loading_text="Creating predictor..."):
51
+ predictor = fxn.client.request(
52
+ method="POST",
53
+ path="/predictors",
54
+ body=spec.model_dump(mode="json"),
55
+ response_type=_Predictor
56
+ )
57
+ with ProgressLogQueue() as task_queue:
58
+ async for event in fxn.client.stream(
59
+ method="POST",
60
+ path=f"/predictors/{predictor.tag}/compile",
61
+ body={ },
62
+ response_type=_LogEvent | _ErrorEvent
63
+ ):
64
+ if isinstance(event, _LogEvent):
65
+ task_queue.push_log(event)
66
+ elif isinstance(event, _ErrorEvent):
67
+ task_queue.push_error(event)
68
+ raise RuntimeError(event.data.error)
69
+ predictor_url = _compute_predictor_url(fxn.client.api_url, spec.tag)
70
+ print_rich(f"\n[bold spring_green3]🎉 Predictor is now being compiled.[/bold spring_green3] Check it out at {predictor_url}")
71
+
72
+ def _load_predictor_func (path: str) -> Callable[...,object]:
73
+ if "" not in sys.path:
74
+ sys.path.insert(0, "")
75
+ path: Path = Path(path).resolve()
76
+ sys.path.insert(0, str(path.parent))
77
+ name = getmodulename(path)
78
+ spec = spec_from_file_location(name, path)
79
+ module = module_from_spec(spec)
80
+ sys.modules[name] = module
81
+ spec.loader.exec_module(module)
82
+ main_func = next(func for _, func in getmembers(module, isfunction) if hasattr(func, "__predictor_spec"))
83
+ return main_func
84
+
85
+ def _compute_predictor_url (api_url: str, tag: str) -> str:
86
+ parsed_url = urlparse(api_url)
87
+ hostname_parts = parsed_url.hostname.split(".")
88
+ if hostname_parts[0] == "api":
89
+ hostname_parts.pop(0)
90
+ hostname = ".".join(hostname_parts)
91
+ netloc = hostname if not parsed_url.port else f"{hostname}:{parsed_url.port}"
92
+ predictor_url = urlunparse(parsed_url._replace(netloc=netloc, path=f"{tag}"))
93
+ return predictor_url
94
+
95
class _Predictor (BaseModel):
    """
    Response model for the `POST /predictors` request made by `_compile_predictor_async`.
    """
    # Predictor tag, used to build the `/predictors/{tag}/compile` path
    tag: str
97
+
98
class _LogData (BaseModel):
    """
    Payload of a `log` event received from the compile stream.
    """
    # Log text. Backtick-quoted spans are restyled by `ProgressLogQueue.push_log`
    message: str
    # Nesting level that `ProgressLogQueue.push_log` compares against open tasks
    level: int = 0
101
+
102
class _LogEvent (BaseModel):
    """
    Server-sent event carrying a log message.
    """
    # Event discriminator
    event: Literal["log"]
    # Log payload
    data: _LogData
105
+
106
class _ErrorData (BaseModel):
    """
    Payload of an `error` event received from the compile stream.
    """
    # Error message, raised as a `RuntimeError` by `_compile_predictor_async`
    error: str
108
+
109
class _ErrorEvent (BaseModel):
    """
    Server-sent event carrying an error.
    """
    # Event discriminator
    event: Literal["error"]
    # Error payload
    data: _ErrorData
112
+
113
class ProgressLogQueue:
    """
    Stack of open progress tasks driven by leveled log events.

    Each log event opens a new `CustomProgressTask`. Before doing so, every open
    task whose level is greater than or equal to the event's level is closed, so
    only strictly shallower tasks stay open above the new one.
    Use as a context manager so that any tasks still open are closed on exit.
    """

    def __init__ (self):
        # Open tasks as `(level, task)` pairs, innermost last
        self.queue: list[tuple[int, CustomProgressTask]] = []

    def push_log (self, event: _LogEvent):
        """
        Close tasks at the same or a deeper level, then open a task for the log message.

        Parameters:
            event (_LogEvent): Log event.
        """
        while self.queue:
            current_level, current_task = self.queue[-1]
            if event.data.level > current_level:
                break
            current_task.__exit__(None, None, None)
            self.queue.pop()
        # Render backtick-quoted spans with rich markup
        message = sub(r"`([^`]+)`", r"[hot_pink italic]\1[/hot_pink italic]", event.data.message)
        task = CustomProgressTask(loading_text=message)
        task.__enter__()
        self.queue.append((event.data.level, task))

    def push_error (self, error: _ErrorEvent):
        """
        Close every open task as failed.

        Parameters:
            error (_ErrorEvent): Error event.
        """
        self.__close_all(RuntimeError, None, None)

    def __enter__ (self):
        return self

    def __exit__ (self, exc_type, exc_value, traceback):
        # Forward the exception info so tasks interrupted by an exception
        # are closed as failed instead of being reported as completed
        self.__close_all(exc_type, exc_value, traceback)

    def __close_all (self, exc_type, exc_value, traceback):
        # Close every open task, innermost first
        while self.queue:
            _, current_task = self.queue.pop()
            current_task.__exit__(exc_type, exc_value, traceback)
@@ -9,11 +9,11 @@ from numpy import array_repr, ndarray
9
9
  from pathlib import Path, PurePath
10
10
  from PIL import Image
11
11
  from rich import print_json
12
- from rich.progress import Progress, SpinnerColumn, TextColumn
13
12
  from tempfile import mkstemp
14
13
  from typer import Argument, Context, Option
15
14
 
16
15
  from ..function import Function
16
+ from ..logging import CustomProgress, CustomProgressTask
17
17
  from ..types import Prediction
18
18
  from .auth import get_access_key
19
19
 
@@ -26,18 +26,21 @@ def create_prediction (
26
26
 
27
27
  async def _predict_async (tag: str, quiet: bool, context: Context):
28
28
  # Preload
29
- fxn = Function(get_access_key())
30
- fxn.predictions.create(tag, inputs={ }, verbose=not quiet)
31
- # Predict
32
- with Progress(
33
- SpinnerColumn(spinner_name="dots"),
34
- TextColumn("[progress.description]{task.description}"),
35
- transient=True
36
- ) as progress:
37
- progress.add_task(description="Running Function...", total=None)
38
- inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) }
39
- prediction = fxn.predictions.create(tag, inputs=inputs)
40
- _log_prediction(prediction)
29
+ with CustomProgress(transient=True, disable=quiet):
30
+ fxn = Function(get_access_key())
31
+ with CustomProgressTask(
32
+ loading_text="Preloading predictor...",
33
+ done_text="Preloaded predictor"
34
+ ):
35
+ fxn.predictions.create(tag, inputs={ })
36
+ with CustomProgressTask(loading_text="Making prediction..."):
37
+ inputs = { }
38
+ for i in range(0, len(context.args), 2):
39
+ name = context.args[i].replace("-", "")
40
+ value = _parse_value(context.args[i+1])
41
+ inputs[name] = value
42
+ prediction = fxn.predictions.create(tag, inputs=inputs)
43
+ _log_prediction(prediction)
41
44
 
42
45
  def _parse_value (value: str):
43
46
  """
@@ -0,0 +1,49 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from rich import print_json
7
+ from typer import Argument
8
+
9
+ from ..function import Function
10
+ from ..logging import CustomProgress, CustomProgressTask
11
+ from .auth import get_access_key
12
+
13
+ def retrieve_predictor (
14
+ tag: str=Argument(..., help="Predictor tag.")
15
+ ):
16
+ with CustomProgress(transient=True):
17
+ with CustomProgressTask(loading_text="Retrieving predictor..."):
18
+ fxn = Function(get_access_key())
19
+ predictor = fxn.predictors.retrieve(tag)
20
+ predictor = predictor.model_dump() if predictor else None
21
+ print_json(data=predictor)
22
+
23
def archive_predictor (
    tag: str=Argument(..., help="Predictor tag.")
):
    """
    CLI entry point that archives a predictor.

    Parameters:
        tag (str): Predictor tag.
    """
    archived_text = f"Archived predictor: [bold dark_orange]{tag}[/bold dark_orange]"
    with CustomProgress(), CustomProgressTask(
        loading_text="Archiving predictor...",
        done_text=archived_text
    ):
        fxn = Function(get_access_key())
        endpoint = f"/predictors/{tag}/archive"
        fxn.client.request(method="POST", path=endpoint)
36
+
37
def delete_predictor (
    tag: str=Argument(..., help="Predictor tag.")
):
    """
    CLI entry point that deletes a predictor.

    Parameters:
        tag (str): Predictor tag.
    """
    deleted_text = f"Deleted predictor: [bold red]{tag}[/bold red]"
    with CustomProgress(), CustomProgressTask(
        loading_text="Deleting predictor...",
        done_text=deleted_text
    ):
        fxn = Function(get_access_key())
        endpoint = f"/predictors/{tag}"
        fxn.client.request(method="DELETE", path=endpoint)
@@ -0,0 +1,119 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from json import loads, JSONDecodeError
7
+ from pydantic import BaseModel, TypeAdapter
8
+ from requests import request
9
+ from typing import AsyncGenerator, Literal, Type, TypeVar
10
+
11
+ T = TypeVar("T", bound=BaseModel)
12
+
13
class FunctionClient:
    """
    HTTP client for the Function REST API.

    Every request is authenticated with the access key as a bearer token.
    """

    def __init__(self, access_key: str, api_url: str | None) -> None:
        self.access_key = access_key
        self.api_url = api_url or "https://api.fxn.ai/v1"

    def request (
        self,
        *,
        method: Literal["GET", "POST", "PATCH", "DELETE"],
        path: str,
        body: dict[str, object]=None,
        response_type: Type[T]=None
    ) -> T:
        """
        Make a request to a REST endpoint.

        Parameters:
            method (str): Request method.
            path (str): Endpoint path.
            body (dict): Request JSON body.
            response_type (Type): Response type.

        Returns:
            The response parsed as `response_type`, or `None` when no response type is given.

        Raises:
            FunctionAPIError: If the response status is not successful.
        """
        response = request(
            method=method,
            url=f"{self.api_url}{path}",
            json=body,
            headers={ "Authorization": f"Bearer {self.access_key}" }
        )
        data = self._decode_body(response)
        if response.ok:
            return response_type(**data) if response_type is not None else None
        raise FunctionAPIError(self._error_message(data), response.status_code)

    async def stream (
        self,
        *,
        method: Literal["GET", "POST", "PATCH", "DELETE"],
        path: str,
        body: dict[str, object]=None,
        response_type: Type[T]=None
    ) -> AsyncGenerator[T, None]:
        """
        Make a request to a REST endpoint and consume the response as a server-sent events stream.

        Parameters:
            method (str): Request method.
            path (str): Endpoint path.
            body (dict): Request JSON body.
            response_type (Type): Response type.

        Raises:
            FunctionAPIError: If the response status is not successful.
        """
        # NOTE: The request and the line iteration below are blocking calls
        response = request(
            method=method,
            url=f"{self.api_url}{path}",
            json=body,
            headers={
                "Accept": "text/event-stream",
                "Authorization": f"Bearer {self.access_key}"
            },
            stream=True
        )
        # Release the connection once the stream is exhausted, closed, or fails
        with response:
            # An error response is not an event stream, so surface it like `request` does
            if not response.ok:
                error = self._error_message(self._decode_body(response))
                raise FunctionAPIError(error, response.status_code)
            # Event streams are always UTF-8, but `requests` assumes
            # ISO-8859-1 for `text/*` responses that declare no charset
            response.encoding = "utf-8"
            event = None
            data_lines: list[str] = []
            for line in response.iter_lines(decode_unicode=True):
                line: str = line.strip()
                if line:
                    if line.startswith("event:"):
                        event = line[len("event:"):].strip()
                    elif line.startswith("data:"):
                        data_lines.append(line[len("data:"):].strip())
                    continue
                # A blank line terminates the current event
                if event is not None:
                    yield _parse_sse_event(event, "\n".join(data_lines), response_type)
                event = None
                data_lines = []
            # Flush an event that was not terminated by a blank line
            if event or data_lines:
                yield _parse_sse_event(event, "\n".join(data_lines), response_type)

    @staticmethod
    def _decode_body (response) -> object:
        # Decode the response body as JSON, falling back to the raw text
        try:
            return response.json()
        except ValueError: # every JSON decode error `requests` can raise is a `ValueError`
            return response.text

    @staticmethod
    def _error_message (data: object) -> str:
        # Extract the first API error message from a decoded error response body
        if not isinstance(data, dict):
            return data
        try:
            return _ErrorResponse(**data).errors[0].message
        except (ValueError, IndexError):
            # Unexpected error payload shape: report the payload itself
            # instead of masking the API error with a parsing error
            return str(data)
99
+
100
class FunctionAPIError (Exception):
    """
    Error reported by the Function API.

    Members:
        message (str): Error message.
        status_code (int): HTTP status code of the failed response.
    """

    def __init__(self, message: str, status_code: int):
        self.status_code = status_code
        self.message = message
        # Only the message is forwarded, so `args` is `(message,)`
        super().__init__(message)

    def __str__(self):
        message, status_code = self.message, self.status_code
        return f"FunctionAPIError: {message} (Status Code: {status_code})"
109
+
110
class _APIError (BaseModel):
    """
    Single error entry in an API error response.
    """
    # Error message, surfaced through `FunctionAPIError`
    message: str
112
+
113
class _ErrorResponse (BaseModel):
    """
    JSON body of an unsuccessful API response.
    """
    # Reported errors. `FunctionClient.request` only reads the first entry
    errors: list[_APIError]
115
+
116
def _parse_sse_event (event: str, data: str, type: Type[T]=None) -> T:
    """
    Build an event object from the fields of a server-sent event.

    Parameters:
        event (str): Event name.
        data (str): Event data as a JSON string. Blank data is parsed as `None`.
        type (Type): Type to validate the event against. If `None`, a plain dictionary is returned.

    Returns:
        The validated event, or a dictionary with `event` and `data` keys.
    """
    # An event without any `data:` lines has no JSON to decode
    payload = data.strip() if data else ""
    result = { "event": event, "data": loads(payload) if payload else None }
    result = TypeAdapter(type).validate_python(result) if type is not None else result
    return result
@@ -5,12 +5,12 @@
5
5
 
6
6
  from collections.abc import Callable
7
7
  from functools import wraps
8
+ from inspect import isasyncgenfunction, iscoroutinefunction
8
9
  from pathlib import Path
9
10
  from pydantic import BaseModel, Field
10
11
 
11
- from ..types import AccessMode, Signature
12
12
  from .sandbox import Sandbox
13
- from .signature import get_function_type, infer_function_signature, FunctionType
13
+ from .types import AccessMode
14
14
 
15
15
  class PredictorSpec (BaseModel):
16
16
  """
@@ -20,7 +20,6 @@ class PredictorSpec (BaseModel):
20
20
  description: str = Field(description="Predictor description. MUST be less than 100 characters long.", min_length=4, max_length=100)
21
21
  sandbox: Sandbox = Field(description="Sandbox to compile the function.")
22
22
  access: AccessMode = Field(description="Predictor access.")
23
- signature: Signature = Field(description="Predictor signature.")
24
23
  card: str | None = Field(default=None, description="Predictor card (markdown).")
25
24
  media: str | None = Field(default=None, description="Predictor media URL.")
26
25
  license: str | None = Field(default=None, description="Predictor license URL. This is required for public predictors.")
@@ -51,11 +50,9 @@ def compile (
51
50
  # Check type
52
51
  if not callable(func):
53
52
  raise TypeError("Cannot compile non-function objects")
54
- func_type = get_function_type(func)
55
- if func_type not in { FunctionType.Function, FunctionType.Generator }:
56
- raise TypeError(f"Function '{func.__name__}' must be a regular function or generator")
53
+ if isasyncgenfunction(func) or iscoroutinefunction(func):
54
+ raise TypeError(f"Function '{func.__name__}' must be a regular function or generator")
57
55
  # Gather metadata
58
- signature = infer_function_signature(func) # throws
59
56
  if isinstance(card, Path):
60
57
  with open(card_content, "r") as f:
61
58
  card_content = f.read()
@@ -66,7 +63,6 @@ def compile (
66
63
  description=description,
67
64
  sandbox=sandbox if sandbox is not None else Sandbox(),
68
65
  access=access,
69
- signature=signature,
70
66
  card=card_content,
71
67
  media=None, # INCOMPLETE
72
68
  license=license
@@ -0,0 +1,137 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from contextvars import ContextVar
7
+ from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
8
+ from typing import Literal
9
+
10
+ current_progress = ContextVar("current_progress", default=None)
11
+ progress_task_stack = ContextVar("progress_task_stack", default=[])
12
+
13
class CustomProgress(Progress):
    """
    Progress display in which every task is rendered as its own table.

    Each task is rendered with the default columns plus any columns stored in the
    task's `columns` field. While the context is active, the instance is published
    through the `current_progress` context variable and the task stack is reset.
    """

    def __init__ (
        self,
        *columns: ProgressColumn,
        console=None,
        auto_refresh=True,
        refresh_per_second = 10,
        speed_estimate_period=30,
        transient=False,
        redirect_stdout=True,
        redirect_stderr=True,
        get_time=None,
        disable=False,
        expand=False
    ):
        # Fall back to a spinner followed by the task description
        if columns:
            base_columns = list(columns)
        else:
            base_columns = [
                SpinnerColumn(spinner_name="dots", finished_text="[bold green]✔[/bold green]"),
                TextColumn("[progress.description]{task.description}"),
            ]
        super().__init__(
            *base_columns,
            console=console,
            auto_refresh=auto_refresh,
            refresh_per_second=refresh_per_second,
            speed_estimate_period=speed_estimate_period,
            transient=transient,
            redirect_stdout=redirect_stdout,
            redirect_stderr=redirect_stderr,
            get_time=get_time,
            disable=disable,
            expand=expand
        )
        self.default_columns = base_columns

    def __enter__ (self):
        # Publish this progress and start from an empty task stack
        self._token = current_progress.set(self)
        self._stack_token = progress_task_stack.set([])
        return super().__enter__()

    def __exit__ (self, exc_type, exc_val, exc_tb):
        # Restore the previous progress and task stack before stopping the display
        current_progress.reset(self._token)
        progress_task_stack.reset(self._stack_token)
        return super().__exit__(exc_type, exc_val, exc_tb)

    def get_renderables (self):
        # Render one table per task so that each task can append its own columns
        for task in self.tasks:
            extra_columns = task.fields.get("columns") or []
            self.columns = self.default_columns + extra_columns
            yield self.make_tasks_table([task])
63
+
64
+ class CustomProgressTask:
65
+
66
+ def __init__ (
67
+ self,
68
+ *,
69
+ loading_text: str,
70
+ done_text: str=None,
71
+ columns: list[ProgressColumn]=None
72
+ ):
73
+ self.loading_text = loading_text
74
+ self.done_text = done_text if done_text is not None else loading_text
75
+ self.task_id = None
76
+ self.columns = columns
77
+
78
+ def __enter__ (self):
79
+ progress = current_progress.get()
80
+ indent_level = len(progress_task_stack.get())
81
+ indent = self.__get_indent(indent_level)
82
+ if progress is not None:
83
+ self.task_id = progress.add_task(
84
+ f"{indent}{self.loading_text}",
85
+ total=1,
86
+ columns=self.columns
87
+ )
88
+ current_stack = progress_task_stack.get()
89
+ progress_task_stack.set(current_stack + [self.task_id])
90
+ return self
91
+
92
+ def __exit__ (self, exc_type, exc_val, exc_tb):
93
+ progress = current_progress.get()
94
+ if progress is not None and self.task_id is not None:
95
+ indent_level = len(progress_task_stack.get()) - 1
96
+ indent = self.__get_indent(indent_level)
97
+ if exc_type is None:
98
+ total = progress._tasks[self.task_id].total
99
+ progress.update(
100
+ self.task_id,
101
+ description=f"{indent}{self.done_text}",
102
+ completed=total
103
+ )
104
+ else:
105
+ progress.update(
106
+ self.task_id,
107
+ description=f"{indent}[bright_red]✘ {self.loading_text}[/bright_red]",
108
+ )
109
+ current_stack = progress_task_stack.get()
110
+ if current_stack:
111
+ progress_task_stack.set(current_stack[:-1])
112
+ self.task_id = None
113
+ return False
114
+
115
+ def update (self, **kwargs):
116
+ progress = current_progress.get()
117
+ if progress is None or self.task_id is None:
118
+ return
119
+ if "description" in kwargs:
120
+ stack = progress_task_stack.get()
121
+ try:
122
+ index = stack.index(self.task_id)
123
+ except ValueError:
124
+ index = len(stack) - 1
125
+ indent = self.__get_indent(index)
126
+ description = kwargs["description"]
127
+ kwargs["description"] = f"{indent}{description}"
128
+ progress.update(self.task_id, **kwargs)
129
+
130
+ def finish (self, message: str):
131
+ self.done_text = message
132
+
133
+ def __get_indent (self, level: int) -> str:
134
+ if level == 0:
135
+ return ""
136
+ indicator = "└── "
137
+ return " " * len(indicator) * (level - 1) + indicator