fxn 0.0.42__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fxn/__init__.py CHANGED
@@ -4,7 +4,8 @@
4
4
  #
5
5
 
6
6
  from .client import FunctionAPIError
7
- from .compile import *
7
+ from .compile import compile
8
8
  from .function import Function
9
+ from .sandbox import Sandbox
9
10
  from .types import *
10
11
  from .version import *
fxn/cli/__init__.py CHANGED
@@ -6,10 +6,10 @@
6
6
  from typer import Typer
7
7
 
8
8
  from .auth import app as auth_app
9
- #from .compile import compile_predictor
9
+ from .compile import compile_predictor
10
10
  from .misc import cli_options
11
11
  from .predictions import create_prediction
12
- from .predictors import retrieve_predictor
12
+ from .predictors import archive_predictor, delete_predictor, retrieve_predictor
13
13
  from ..version import __version__
14
14
 
15
15
  # Define CLI
@@ -33,11 +33,13 @@ app.command(
33
33
  help="Make a prediction.",
34
34
  context_settings={ "allow_extra_args": True, "ignore_unknown_options": True }
35
35
  )(create_prediction)
36
- # app.command(
37
- # name="compile",
38
- # help="Create a predictor by compiling a Python function."
39
- # )(compile_predictor)
36
+ app.command(
37
+ name="compile",
38
+ help="Create a predictor by compiling a Python function."
39
+ )(compile_predictor)
40
40
  app.command(name="retrieve", help="Retrieve a predictor.")(retrieve_predictor)
41
+ app.command(name="archive", help="Archive a predictor.")(archive_predictor)
42
+ app.command(name="delete", help="Delete a predictor.")(delete_predictor)
41
43
 
42
44
  # Run
43
45
  if __name__ == "__main__":
fxn/cli/compile.py ADDED
@@ -0,0 +1,141 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from asyncio import run as run_async
7
+ from importlib.util import module_from_spec, spec_from_file_location
8
+ from inspect import getmembers, getmodulename, isfunction
9
+ from pathlib import Path
10
+ from pydantic import BaseModel
11
+ from re import sub
12
+ from rich import print as print_rich
13
+ from rich.progress import SpinnerColumn, TextColumn
14
+ import sys
15
+ from typer import Argument, Option
16
+ from typing import Callable, Literal
17
+ from urllib.parse import urlparse, urlunparse
18
+
19
+ from ..compile import PredictorSpec
20
+ from ..function import Function
21
+ from ..sandbox import EntrypointCommand
22
+ from ..logging import CustomProgress, CustomProgressTask
23
+ from .auth import get_access_key
24
+
25
def compile_predictor (
    path: str=Argument(..., help="Predictor path.")
):
    """
    CLI command that compiles the prediction function defined in a Python file.

    Parameters:
        path (str): Predictor path.
    """
    # The compile flow is a coroutine, so run it to completion on a fresh event loop
    compile_task = _compile_predictor_async(path)
    run_async(compile_task)
29
+
30
async def _compile_predictor_async (path: str):
    """
    Load a prediction function from a Python file, upload its sandbox, then create and compile the predictor.

    Parameters:
        path (str): Path to the Python file that defines the prediction function.

    Raises:
        RuntimeError: If the compile stream reports an error event.
    """
    fxn = Function(get_access_key())
    path: Path = Path(path).resolve()
    progress_columns = (
        SpinnerColumn(spinner_name="dots", finished_text="[bold green]✔[/bold green]"),
        TextColumn("[progress.description]{task.description}"),
    )
    with CustomProgress(*progress_columns):
        # Load the prediction function and the spec attached to it
        with CustomProgressTask(loading_text="Loading predictor...") as load_task:
            func = _load_predictor_func(path)
            entrypoint = EntrypointCommand(
                from_path=str(path),
                to_path="./",
                name=func.__name__
            )
            spec: PredictorSpec = getattr(func, "__predictor_spec")
            load_task.finish(f"Loaded prediction function: [bold cyan]{spec.tag}[/bold cyan]")
        # Register the predictor file itself as the sandbox entrypoint, then upload everything
        sandbox = spec.sandbox
        sandbox.commands.append(entrypoint)
        with CustomProgressTask(loading_text="Uploading sandbox...", done_text="Uploaded sandbox"):
            sandbox.populate(fxn=fxn)
        # Create the predictor, then follow the compile log stream
        with CustomProgressTask(loading_text="Running codegen...", done_text="Completed codegen"):
            with CustomProgressTask(loading_text="Creating predictor..."):
                predictor = fxn.client.request(
                    method="POST",
                    path="/predictors",
                    body=spec.model_dump(mode="json"),
                    response_type=_Predictor
                )
            compile_events = fxn.client.stream(
                method="POST",
                path=f"/predictors/{predictor.tag}/compile",
                body={ },
                response_type=_LogEvent | _ErrorEvent
            )
            with ProgressLogQueue() as task_queue:
                async for event in compile_events:
                    if isinstance(event, _ErrorEvent):
                        # Mark every open log task as failed before aborting
                        task_queue.push_error(event)
                        raise RuntimeError(event.data.error)
                    if isinstance(event, _LogEvent):
                        task_queue.push_log(event)
    predictor_url = _compute_predictor_url(fxn.client.api_url, spec.tag)
    print_rich(f"\n[bold spring_green3]🎉 Predictor is now being compiled.[/bold spring_green3] Check it out at {predictor_url}")
71
+
72
def _load_predictor_func (path: str) -> Callable[...,object]:
    """
    Import a Python file as a module and return its prediction function.

    The prediction function is the first module-level function, in the name order
    returned by `inspect.getmembers`, that carries a `__predictor_spec` attribute.

    Parameters:
        path (str): Path to the Python file to import.

    Returns:
        Callable: Function carrying a `__predictor_spec` attribute.

    Raises:
        ValueError: If the path is not an importable Python module, or the module defines no such function.
    """
    # Let the predictor module import from the working directory...
    if "" not in sys.path:
        sys.path.insert(0, "")
    path: Path = Path(path).resolve()
    # ...and, with priority, from its own directory
    sys.path.insert(0, str(path.parent))
    name = getmodulename(path)
    # `getmodulename` and `spec_from_file_location` both return `None` for non-module files
    spec = spec_from_file_location(name, path) if name is not None else None
    if spec is None or spec.loader is None:
        raise ValueError(f"Cannot load predictor because '{path}' is not a Python module")
    module = module_from_spec(spec)
    # Register before executing so that the module can be resolved by name while it loads
    sys.modules[name] = module
    spec.loader.exec_module(module)
    candidates = (func for _, func in getmembers(module, isfunction) if hasattr(func, "__predictor_spec"))
    # A bare `next()` would leak `StopIteration`, which surfaces as an opaque error inside a coroutine
    main_func = next(candidates, None)
    if main_func is None:
        raise ValueError(f"Cannot load predictor because '{path}' does not define a function with a predictor spec")
    return main_func
84
+
85
def _compute_predictor_url (api_url: str, tag: str) -> str:
    """
    Compute the web URL of a predictor from the API URL.

    The leading `api` label is dropped from the API hostname and the URL path is replaced
    with the predictor tag. The scheme and port of the API URL are kept.

    Parameters:
        api_url (str): Function API URL.
        tag (str): Predictor tag.

    Returns:
        str: Predictor URL.

    Raises:
        ValueError: If the API URL has no hostname (for example when it is missing a scheme).
    """
    parsed_url = urlparse(api_url)
    hostname = parsed_url.hostname
    # Without a scheme `urlparse` puts everything in the path and reports no hostname
    if not hostname:
        raise ValueError(f"Cannot compute predictor URL because API URL '{api_url}' has no hostname")
    hostname_parts = hostname.split(".")
    # Drop the `api` subdomain, but never strip a bare `api` host down to nothing
    if len(hostname_parts) > 1 and hostname_parts[0] == "api":
        hostname_parts = hostname_parts[1:]
    hostname = ".".join(hostname_parts)
    netloc = hostname if not parsed_url.port else f"{hostname}:{parsed_url.port}"
    predictor_url = urlunparse(parsed_url._replace(netloc=netloc, path=f"{tag}"))
    return predictor_url
94
+
95
class _Predictor (BaseModel):
    """
    Response of the predictor creation request. Only the tag is read by the CLI.
    """
    tag: str

class _LogData (BaseModel):
    """
    Payload of a `log` event.
    """
    # Log text; backtick-quoted spans are restyled when the message is shown
    message: str
    # Nesting level; compared against the levels of currently open log tasks
    level: int = 0

class _LogEvent (BaseModel):
    """
    Server-sent event carrying a compile log message.
    """
    event: Literal["log"]
    data: _LogData

class _ErrorData (BaseModel):
    """
    Payload of an `error` event.
    """
    # Error message; raised as a `RuntimeError` by the compile command
    error: str

class _ErrorEvent (BaseModel):
    """
    Server-sent event reporting a compile failure.
    """
    event: Literal["error"]
    data: _ErrorData
112
+
113
class ProgressLogQueue:
    """
    Stack of open progress tasks that mirrors the nesting of streamed log events.

    Each entry pairs the level of a log event with the `CustomProgressTask` opened for it.
    Use as a context manager so that tasks still open at the end are always closed.
    """

    def __init__ (self):
        self.queue: list[tuple[int, CustomProgressTask]] = []

    def push_log (self, event: "_LogEvent"):
        """
        Open a progress task for a log event.

        Every open task whose level is greater than or equal to the event's level is
        closed first, so the new task nests under the remaining ones.

        Parameters:
            event (_LogEvent): Log event.
        """
        while self.queue:
            current_level, current_task = self.queue[-1]
            if event.data.level > current_level:
                break
            current_task.__exit__(None, None, None)
            self.queue.pop()
        # Render `code` spans in the message with rich markup
        message = sub(r"`([^`]+)`", r"[hot_pink italic]\1[/hot_pink italic]", event.data.message)
        task = CustomProgressTask(loading_text=message)
        task.__enter__()
        self.queue.append((event.data.level, task))

    def push_error (self, error: "_ErrorEvent"):
        """
        Close every open task as failed.

        Parameters:
            error (_ErrorEvent): Error event. Its contents are not read here.
        """
        while self.queue:
            _, current_task = self.queue.pop()
            current_task.__exit__(RuntimeError, None, None)

    def __enter__ (self):
        return self

    def __exit__ (self, exc_type, exc_value, traceback):
        # Forward the exception info so that tasks still open when the block fails
        # are closed as failed instead of being reported as completed.
        while self.queue:
            _, current_task = self.queue.pop()
            current_task.__exit__(exc_type, exc_value, traceback)
fxn/cli/predictions.py CHANGED
@@ -9,11 +9,11 @@ from numpy import array_repr, ndarray
9
9
  from pathlib import Path, PurePath
10
10
  from PIL import Image
11
11
  from rich import print_json
12
- from rich.progress import Progress, SpinnerColumn, TextColumn
13
12
  from tempfile import mkstemp
14
13
  from typer import Argument, Context, Option
15
14
 
16
15
  from ..function import Function
16
+ from ..logging import CustomProgress, CustomProgressTask
17
17
  from ..types import Prediction
18
18
  from .auth import get_access_key
19
19
 
@@ -26,18 +26,21 @@ def create_prediction (
26
26
 
27
27
  async def _predict_async (tag: str, quiet: bool, context: Context):
28
28
  # Preload
29
- fxn = Function(get_access_key())
30
- fxn.predictions.create(tag, inputs={ }, verbose=not quiet)
31
- # Predict
32
- with Progress(
33
- SpinnerColumn(spinner_name="dots"),
34
- TextColumn("[progress.description]{task.description}"),
35
- transient=True
36
- ) as progress:
37
- progress.add_task(description="Running Function...", total=None)
38
- inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) }
39
- prediction = fxn.predictions.create(tag, inputs=inputs)
40
- _log_prediction(prediction)
29
+ with CustomProgress(transient=True, disable=quiet):
30
+ fxn = Function(get_access_key())
31
+ with CustomProgressTask(
32
+ loading_text="Preloading predictor...",
33
+ done_text="Preloaded predictor"
34
+ ):
35
+ fxn.predictions.create(tag, inputs={ })
36
+ with CustomProgressTask(loading_text="Making prediction..."):
37
+ inputs = { }
38
+ for i in range(0, len(context.args), 2):
39
+ name = context.args[i].replace("-", "")
40
+ value = _parse_value(context.args[i+1])
41
+ inputs[name] = value
42
+ prediction = fxn.predictions.create(tag, inputs=inputs)
43
+ _log_prediction(prediction)
41
44
 
42
45
  def _parse_value (value: str):
43
46
  """
fxn/cli/predictors.py CHANGED
@@ -7,12 +7,43 @@ from rich import print_json
7
7
  from typer import Argument
8
8
 
9
9
  from ..function import Function
10
+ from ..logging import CustomProgress, CustomProgressTask
10
11
  from .auth import get_access_key
11
12
 
12
13
  def retrieve_predictor (
13
14
  tag: str=Argument(..., help="Predictor tag.")
14
15
  ):
15
- fxn = Function(get_access_key())
16
- predictor = fxn.predictors.retrieve(tag)
17
- predictor = predictor.model_dump() if predictor else None
18
- print_json(data=predictor)
16
+ with CustomProgress(transient=True):
17
+ with CustomProgressTask(loading_text="Retrieving predictor..."):
18
+ fxn = Function(get_access_key())
19
+ predictor = fxn.predictors.retrieve(tag)
20
+ predictor = predictor.model_dump() if predictor else None
21
+ print_json(data=predictor)
22
+
23
def archive_predictor (
    tag: str=Argument(..., help="Predictor tag.")
):
    """
    CLI command that archives a predictor.

    Parameters:
        tag (str): Predictor tag.
    """
    archive_task = dict(
        loading_text="Archiving predictor...",
        done_text=f"Archived predictor: [bold dark_orange]{tag}[/bold dark_orange]"
    )
    # The task is created after the progress display is entered, so it attaches to it
    with CustomProgress(), CustomProgressTask(**archive_task):
        fxn = Function(get_access_key())
        fxn.client.request(method="POST", path=f"/predictors/{tag}/archive")
36
+
37
def delete_predictor (
    tag: str=Argument(..., help="Predictor tag.")
):
    """
    CLI command that deletes a predictor.

    Parameters:
        tag (str): Predictor tag.
    """
    delete_task = dict(
        loading_text="Deleting predictor...",
        done_text=f"Deleted predictor: [bold red]{tag}[/bold red]"
    )
    # The task is created after the progress display is entered, so it attaches to it
    with CustomProgress(), CustomProgressTask(**delete_task):
        fxn = Function(get_access_key())
        fxn.client.request(method="DELETE", path=f"/predictors/{tag}")
fxn/client.py CHANGED
@@ -4,9 +4,9 @@
4
4
  #
5
5
 
6
6
  from json import loads, JSONDecodeError
7
- from pydantic import BaseModel
7
+ from pydantic import BaseModel, TypeAdapter
8
8
  from requests import request
9
- from typing import Any, Literal, Type, TypeVar
9
+ from typing import AsyncGenerator, Literal, Type, TypeVar
10
10
 
11
11
  T = TypeVar("T", bound=BaseModel)
12
12
 
@@ -19,11 +19,20 @@ class FunctionClient:
19
19
  def request (
20
20
  self,
21
21
  *,
22
- method: Literal["GET", "POST", "DELETE"],
22
+ method: Literal["GET", "POST", "PATCH", "DELETE"],
23
23
  path: str,
24
- body: dict[str, Any]=None,
24
+ body: dict[str, object]=None,
25
25
  response_type: Type[T]=None
26
26
  ) -> T:
27
+ """
28
+ Make a request to a REST endpoint.
29
+
30
+ Parameters:
31
+ method (str): Request method.
32
+ path (str): Endpoint path.
33
+ body (dict): Request JSON body.
34
+ response_type (Type): Response type.
35
+ """
27
36
  response = request(
28
37
  method=method,
29
38
  url=f"{self.api_url}{path}",
@@ -40,6 +49,53 @@ class FunctionClient:
40
49
  else:
41
50
  error = _ErrorResponse(**data).errors[0].message if isinstance(data, dict) else data
42
51
  raise FunctionAPIError(error, response.status_code)
52
+
53
async def stream (
    self,
    *,
    method: Literal["GET", "POST", "PATCH", "DELETE"],
    path: str,
    body: dict[str, object]=None,
    response_type: Type[T]=None
) -> AsyncGenerator[T, None]:
    """
    Make a request to a REST endpoint and consume the response as a server-sent events stream.

    Parameters:
        method (str): Request method.
        path (str): Endpoint path.
        body (dict): Request JSON body.
        response_type (Type): Response type.

    Yields:
        Each named event that carries data, validated against `response_type` when one
        is provided, otherwise as an `{ "event", "data" }` dictionary.

    Raises:
        FunctionAPIError: If the server responds with an error status instead of an event stream.
    """
    response = request(
        method=method,
        url=f"{self.api_url}{path}",
        json=body,
        headers={
            "Accept": "text/event-stream",
            "Authorization": f"Bearer {self.access_key}"
        },
        stream=True
    )
    # Release the connection even when the consumer stops iterating early or an error is raised
    with response:
        # An error response is not an event stream; report it instead of yielding nothing
        if not response.ok:
            try:
                payload = response.json()
            except ValueError:
                payload = response.text
            error = _ErrorResponse(**payload).errors[0].message if isinstance(payload, dict) else payload
            raise FunctionAPIError(error, response.status_code)
        # Event streams are always UTF-8, but `requests` assumes ISO-8859-1 for `text/*`
        # responses that do not declare a charset.
        response.encoding = "utf-8"
        event = None
        data_lines: list[str] = []
        for line in response.iter_lines(decode_unicode=True):
            if line is None:
                break
            line: str = line.strip()
            if line:
                if line.startswith("event:"):
                    event = line[len("event:"):].strip()
                elif line.startswith("data:"):
                    data_lines.append(line[len("data:"):].strip())
                continue
            # A blank line terminates the current event. Events without data are not
            # dispatched, so `loads` never sees an empty document.
            if event is not None and data_lines:
                yield _parse_sse_event(event, "\n".join(data_lines), response_type)
            event = None
            data_lines = []
        # Flush a final event that was not followed by a blank line
        if event is not None and data_lines:
            yield _parse_sse_event(event, "\n".join(data_lines), response_type)
43
99
 
44
100
  class FunctionAPIError (Exception):
45
101
 
@@ -55,4 +111,9 @@ class _APIError (BaseModel):
55
111
  message: str
56
112
 
57
113
  class _ErrorResponse (BaseModel):
58
- errors: list[_APIError]
114
+ errors: list[_APIError]
115
+
116
def _parse_sse_event (event: str, data: str, type: Type[T]=None) -> T:
    """
    Build an event from the fields of a server-sent event.

    Parameters:
        event (str): Event name.
        data (str): Event data as a JSON document.
        type (Type): Type to validate the event against. When omitted the raw dictionary is returned.

    Returns:
        The `{ "event", "data" }` dictionary, validated into `type` when provided.
    """
    payload = { "event": event, "data": loads(data) }
    if type is None:
        return payload
    return TypeAdapter(type).validate_python(payload)
+ return result
@@ -5,12 +5,12 @@
5
5
 
6
6
  from collections.abc import Callable
7
7
  from functools import wraps
8
+ from inspect import isasyncgenfunction, iscoroutinefunction
8
9
  from pathlib import Path
9
10
  from pydantic import BaseModel, Field
10
11
 
11
- from ..types import AccessMode, Signature
12
12
  from .sandbox import Sandbox
13
- from .signature import get_function_type, infer_function_signature, FunctionType
13
+ from .types import AccessMode
14
14
 
15
15
  class PredictorSpec (BaseModel):
16
16
  """
@@ -20,7 +20,6 @@ class PredictorSpec (BaseModel):
20
20
  description: str = Field(description="Predictor description. MUST be less than 100 characters long.", min_length=4, max_length=100)
21
21
  sandbox: Sandbox = Field(description="Sandbox to compile the function.")
22
22
  access: AccessMode = Field(description="Predictor access.")
23
- signature: Signature = Field(description="Predictor signature.")
24
23
  card: str | None = Field(default=None, description="Predictor card (markdown).")
25
24
  media: str | None = Field(default=None, description="Predictor media URL.")
26
25
  license: str | None = Field(default=None, description="Predictor license URL. This is required for public predictors.")
@@ -51,11 +50,9 @@ def compile (
51
50
  # Check type
52
51
  if not callable(func):
53
52
  raise TypeError("Cannot compile non-function objects")
54
- func_type = get_function_type(func)
55
- if func_type not in { FunctionType.Function, FunctionType.Generator }:
56
- raise TypeError(f"Function '{func.__name__}' must be a regular function or generator")
53
+ if isasyncgenfunction(func) or iscoroutinefunction(func):
54
+ raise TypeError(f"Function '{func.__name__}' must be a regular function or generator")
57
55
  # Gather metadata
58
- signature = infer_function_signature(func) # throws
59
56
  if isinstance(card, Path):
60
57
  with open(card_content, "r") as f:
61
58
  card_content = f.read()
@@ -66,7 +63,6 @@ def compile (
66
63
  description=description,
67
64
  sandbox=sandbox if sandbox is not None else Sandbox(),
68
65
  access=access,
69
- signature=signature,
70
66
  card=card_content,
71
67
  media=None, # INCOMPLETE
72
68
  license=license
fxn/logging.py ADDED
@@ -0,0 +1,137 @@
1
+ #
2
+ # Function
3
+ # Copyright © 2025 NatML Inc. All Rights Reserved.
4
+ #
5
+
6
+ from contextvars import ContextVar
7
+ from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
8
+ from typing import Literal
9
+
10
+ current_progress = ContextVar("current_progress", default=None)
11
+ progress_task_stack = ContextVar("progress_task_stack", default=[])
12
+
13
class CustomProgress(Progress):
    """
    Progress display that publishes itself as the current progress while it is active
    and renders each task with the default columns plus the task's own columns.
    """

    def __init__ (
        self,
        *columns: ProgressColumn,
        console=None,
        auto_refresh=True,
        refresh_per_second = 10,
        speed_estimate_period=30,
        transient=False,
        redirect_stdout=True,
        redirect_stderr=True,
        get_time=None,
        disable=False,
        expand=False
    ):
        # Fall back to a spinner followed by the task description when no columns are given
        if columns:
            default_columns = list(columns)
        else:
            default_columns = [
                SpinnerColumn(spinner_name="dots", finished_text="[bold green]✔[/bold green]"),
                TextColumn("[progress.description]{task.description}"),
            ]
        options = dict(
            console=console,
            auto_refresh=auto_refresh,
            refresh_per_second=refresh_per_second,
            speed_estimate_period=speed_estimate_period,
            transient=transient,
            redirect_stdout=redirect_stdout,
            redirect_stderr=redirect_stderr,
            get_time=get_time,
            disable=disable,
            expand=expand
        )
        super().__init__(*default_columns, **options)
        self.default_columns = default_columns

    def __enter__ (self):
        # Publish this progress and start from an empty task stack; the tokens restore the previous values on exit
        self._token = current_progress.set(self)
        self._stack_token = progress_task_stack.set([])
        return super().__enter__()

    def __exit__ (self, exc_type, exc_val, exc_tb):
        current_progress.reset(self._token)
        progress_task_stack.reset(self._stack_token)
        return super().__exit__(exc_type, exc_val, exc_tb)

    def get_renderables (self):
        # One table per task, so that every task can carry its own extra columns
        for task in self.tasks:
            extra_columns = task.fields.get("columns") or []
            self.columns = self.default_columns + extra_columns
            yield self.make_tasks_table([task])
63
+
64
class CustomProgressTask:
    """
    Context manager that shows a task in the current progress display.

    Tasks opened while another task is open are indented beneath it. When no progress
    display is current, entering and exiting the task does nothing.
    """

    def __init__ (
        self,
        *,
        loading_text: str,
        done_text: str=None,
        columns: list[ProgressColumn]=None
    ):
        self.loading_text = loading_text
        # The done text falls back to the loading text
        self.done_text = loading_text if done_text is None else done_text
        self.task_id = None
        self.columns = columns

    def __enter__ (self):
        progress = current_progress.get()
        if progress is None:
            return self
        stack = progress_task_stack.get()
        # The number of open tasks is the nesting depth of this one
        indent = self.__get_indent(len(stack))
        self.task_id = progress.add_task(
            f"{indent}{self.loading_text}",
            total=1,
            columns=self.columns
        )
        progress_task_stack.set(stack + [self.task_id])
        return self

    def __exit__ (self, exc_type, exc_val, exc_tb):
        progress = current_progress.get()
        if progress is None or self.task_id is None:
            return False
        stack = progress_task_stack.get()
        # This task is still on the stack, so its own depth is one less than the stack size
        indent = self.__get_indent(len(stack) - 1)
        if exc_type is not None:
            progress.update(
                self.task_id,
                description=f"{indent}[bright_red]✘ {self.loading_text}[/bright_red]"
            )
        else:
            # Mark the task as fully completed and show the done text
            total = progress._tasks[self.task_id].total
            progress.update(
                self.task_id,
                description=f"{indent}{self.done_text}",
                completed=total
            )
        if stack:
            progress_task_stack.set(stack[:-1])
        self.task_id = None
        # Never suppress exceptions
        return False

    def update (self, **kwargs):
        """
        Update the underlying progress task, indenting the description if one is given.
        """
        progress = current_progress.get()
        if progress is None or self.task_id is None:
            return
        if "description" in kwargs:
            stack = progress_task_stack.get()
            # Fall back to the innermost level when this task is not on the stack
            level = stack.index(self.task_id) if self.task_id in stack else len(stack) - 1
            kwargs["description"] = f"{self.__get_indent(level)}{kwargs['description']}"
        progress.update(self.task_id, **kwargs)

    def finish (self, message: str):
        """
        Set the text to show when the task completes.
        """
        self.done_text = message

    def __get_indent (self, level: int) -> str:
        # Top-level tasks are not indented; nested tasks get a tree indicator
        if level == 0:
            return ""
        indicator = "└── "
        padding = " " * len(indicator) * (level - 1)
        return padding + indicator
@@ -4,13 +4,16 @@
4
4
  #
5
5
 
6
6
  from __future__ import annotations
7
+ from abc import ABC, abstractmethod
7
8
  from hashlib import sha256
8
9
  from pathlib import Path
9
10
  from pydantic import BaseModel
10
11
  from requests import put
12
+ from rich.progress import BarColumn, TextColumn
11
13
  from typing import Literal
12
14
 
13
- from ..function import Function
15
+ from .function import Function
16
+ from .logging import CustomProgressTask
14
17
 
15
18
  class WorkdirCommand (BaseModel):
16
19
  kind: Literal["workdir"] = "workdir"
@@ -20,17 +23,35 @@ class EnvCommand (BaseModel):
20
23
  kind: Literal["env"] = "env"
21
24
  env: dict[str, str]
22
25
 
23
- class UploadFileCommand (BaseModel):
24
- kind: Literal["upload_file"] = "upload_file"
26
+ class UploadableCommand (BaseModel, ABC):
25
27
  from_path: str
26
28
  to_path: str
27
29
  manifest: dict[str, str] | None = None
28
30
 
29
- class UploadDirectoryCommand (BaseModel):
31
+ @abstractmethod
32
+ def get_files (self) -> list[Path]:
33
+ pass
34
+
35
+ class UploadFileCommand (UploadableCommand):
36
+ kind: Literal["upload_file"] = "upload_file"
37
+
38
+ def get_files (self) -> list[Path]:
39
+ return [Path(self.from_path).resolve()]
40
+
41
+ class UploadDirectoryCommand (UploadableCommand):
30
42
  kind: Literal["upload_dir"] = "upload_dir"
31
- from_path: str
32
- to_path: str
33
- manifest: dict[str, str] | None = None
43
+
44
+ def get_files (self) -> list[Path]:
45
+ from_path = Path(self.from_path)
46
+ assert from_path.is_absolute(), "Cannot upload directory because directory path must be absolute"
47
+ return [file for file in from_path.rglob("*") if file.is_file()]
48
+
49
+ class EntrypointCommand (UploadableCommand):
50
+ kind: Literal["entrypoint"] = "entrypoint"
51
+ name: str
52
+
53
+ def get_files (self) -> list[Path]:
54
+ return [Path(self.from_path).resolve()]
34
55
 
35
56
  class PipInstallCommand (BaseModel):
36
57
  kind: Literal["pip_install"] = "pip_install"
@@ -40,10 +61,6 @@ class AptInstallCommand (BaseModel):
40
61
  kind: Literal["apt_install"] = "apt_install"
41
62
  packages: list[str]
42
63
 
43
- class EntrypointCommand (BaseModel):
44
- kind: Literal["entrypoint"] = "entrypoint"
45
- path: str
46
-
47
64
  Command = (
48
65
  WorkdirCommand |
49
66
  EnvCommand |
@@ -68,16 +85,14 @@ class Sandbox (BaseModel):
68
85
  path (str | Path): Path to change to.
69
86
  """
70
87
  command = WorkdirCommand(path=str(path))
71
- self.commands.append(command)
72
- return self
88
+ return Sandbox(commands=self.commands + [command])
73
89
 
74
90
  def env (self, **env: str) -> Sandbox:
75
91
  """
76
92
  Set environment variables in the sandbox.
77
93
  """
78
94
  command = EnvCommand(env=env)
79
- self.commands.append(command)
80
- return self
95
+ return Sandbox(commands=self.commands + [command])
81
96
 
82
97
  def upload_file (
83
98
  self,
@@ -92,8 +107,7 @@ class Sandbox (BaseModel):
92
107
  to_path (str | Path): Remote path to upload file to.
93
108
  """
94
109
  command = UploadFileCommand(from_path=str(from_path), to_path=str(to_path))
95
- self.commands.append(command)
96
- return self
110
+ return Sandbox(commands=self.commands + [command])
97
111
 
98
112
  def upload_directory (
99
113
  self,
@@ -108,8 +122,7 @@ class Sandbox (BaseModel):
108
122
  to_path (str | Path): Remote path to upload directory to.
109
123
  """
110
124
  command = UploadDirectoryCommand(from_path=str(from_path), to_path=str(to_path))
111
- self.commands.append(command)
112
- return self
125
+ return Sandbox(commands=self.commands + [command])
113
126
 
114
127
  def pip_install (self, *packages: str) -> Sandbox:
115
128
  """
@@ -119,8 +132,7 @@ class Sandbox (BaseModel):
119
132
  packages (list): Packages to install.
120
133
  """
121
134
  command = PipInstallCommand(packages=packages)
122
- self.commands.append(command)
123
- return self
135
+ return Sandbox(commands=self.commands + [command])
124
136
 
125
137
  def apt_install (self, *packages: str) -> Sandbox:
126
138
  """
@@ -130,24 +142,41 @@ class Sandbox (BaseModel):
130
142
  packages (list): Packages to install.
131
143
  """
132
144
  command = AptInstallCommand(packages=packages)
133
- self.commands.append(command)
134
- return self
135
-
136
- def populate (self, fxn: Function=None) -> Sandbox:
145
+ return Sandbox(commands=self.commands + [command])
146
+
147
+ def populate (self, fxn: Function=None) -> Sandbox: # CHECK # In place
137
148
  """
138
149
  Populate all metadata.
139
150
  """
140
151
  fxn = fxn if fxn is not None else Function()
152
+ entrypoint = next(cmd for cmd in self.commands if isinstance(cmd, EntrypointCommand))
153
+ entry_path = Path(entrypoint.from_path).resolve()
141
154
  for command in self.commands:
142
- if isinstance(command, UploadFileCommand):
143
- from_path = Path(command.from_path)
144
- to_path = Path(command.to_path)
145
- command.manifest = { str(to_path / from_path.name): self.__upload_file(from_path, fxn=fxn) }
146
- elif isinstance(command, UploadDirectoryCommand):
155
+ if isinstance(command, UploadableCommand):
156
+ cwd = Path.cwd()
147
157
  from_path = Path(command.from_path)
148
158
  to_path = Path(command.to_path)
149
- files = [file for file in from_path.rglob("*") if file.is_file()]
150
- command.manifest = { str(to_path / file.relative_to(from_path)): self.__upload_file(file, fxn=fxn) for file in files }
159
+ if not from_path.is_absolute():
160
+ from_path = (entry_path / from_path).resolve()
161
+ command.from_path = str(from_path)
162
+ files = command.get_files()
163
+ name = from_path.relative_to(cwd) if from_path.is_relative_to(cwd) else from_path.resolve()
164
+ with CustomProgressTask(
165
+ loading_text=f"Uploading [light_slate_blue]{name}[/light_slate_blue]...",
166
+ done_text=f"Uploaded [light_slate_blue]{name}[/light_slate_blue]",
167
+ columns=[
168
+ BarColumn(),
169
+ TextColumn("{task.completed}/{task.total}")
170
+ ]
171
+ ) as task:
172
+ manifest = { }
173
+ for idx, file in enumerate(files):
174
+ rel_file_path = file.relative_to(from_path) if from_path.is_dir() else file.name
175
+ dst_path = to_path / rel_file_path
176
+ checksum = self.__upload_file(file, fxn=fxn)
177
+ manifest[str(dst_path)] = checksum
178
+ task.update(total=len(files), completed=idx+1)
179
+ command.manifest = manifest
151
180
  return self
152
181
 
153
182
  def __upload_file (self, path: Path, fxn: Function) -> str:
@@ -172,6 +201,6 @@ class Sandbox (BaseModel):
172
201
  for chunk in iter(lambda: f.read(4096), b""):
173
202
  hash.update(chunk)
174
203
  return hash.hexdigest()
175
-
204
+
176
205
  class _Resource (BaseModel):
177
206
  url: str
@@ -11,13 +11,14 @@ from pathlib import Path
11
11
  from PIL import Image
12
12
  from pydantic import BaseModel
13
13
  from requests import get
14
- from rich.progress import Progress, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
15
- from tempfile import gettempdir
14
+ from rich.progress import BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
15
+ from tempfile import gettempdir, NamedTemporaryFile
16
16
  from typing import Any, AsyncIterator
17
17
  from urllib.parse import urlparse
18
18
 
19
19
  from ..c import Configuration, Predictor, Prediction as CPrediction, Value as CValue, ValueFlags, ValueMap
20
20
  from ..client import FunctionClient
21
+ from ..logging import CustomProgressTask
21
22
  from ..types import Acceleration, Prediction, PredictionResource
22
23
 
23
24
  Value = ndarray | str | float | int | bool | list[Any] | dict[str, Any] | Image.Image | BytesIO | memoryview
@@ -50,8 +51,7 @@ class PredictionService:
50
51
  acceleration: Acceleration=Acceleration.Auto,
51
52
  device=None,
52
53
  client_id: str=None,
53
- configuration_id: str=None,
54
- verbose: bool=False
54
+ configuration_id: str=None
55
55
  ) -> Prediction:
56
56
  """
57
57
  Create a prediction.
@@ -62,7 +62,6 @@ class PredictionService:
62
62
  acceleration (Acceleration): Prediction acceleration.
63
63
  client_id (str): Function client identifier. Specify this to override the current client identifier.
64
64
  configuration_id (str): Configuration identifier. Specify this to override the current client configuration identifier.
65
- verbose (bool): Enable verbose logging.
66
65
 
67
66
  Returns:
68
67
  Prediction: Created prediction.
@@ -78,8 +77,7 @@ class PredictionService:
78
77
  acceleration=acceleration,
79
78
  device=device,
80
79
  client_id=client_id,
81
- configuration_id=configuration_id,
82
- verbose=verbose
80
+ configuration_id=configuration_id
83
81
  )
84
82
  with (
85
83
  self.__to_value_map(inputs) as input_map,
@@ -145,8 +143,7 @@ class PredictionService:
145
143
  acceleration: Acceleration=Acceleration.Auto,
146
144
  device=None,
147
145
  client_id: str=None,
148
- configuration_id: str=None,
149
- verbose: bool=False
146
+ configuration_id: str=None
150
147
  ) -> Predictor:
151
148
  if tag in self.__cache:
152
149
  return self.__cache[tag]
@@ -155,20 +152,13 @@ class PredictionService:
155
152
  client_id=client_id,
156
153
  configuration_id=configuration_id
157
154
  )
158
- with Configuration() as configuration, Progress(
159
- TextColumn("[bold blue]{task.fields[filename]}"),
160
- BarColumn(),
161
- DownloadColumn(),
162
- TransferSpeedColumn(),
163
- TimeRemainingColumn(),
164
- disable=not verbose
165
- ) as progress:
155
+ with Configuration() as configuration:
166
156
  configuration.tag = prediction.tag
167
157
  configuration.token = prediction.configuration
168
158
  configuration.acceleration = acceleration
169
159
  configuration.device = device
170
160
  for resource in prediction.resources:
171
- path = self.__download_resource(resource, progress=progress)
161
+ path = self.__download_resource(resource)
172
162
  configuration.add_resource(resource.type, path)
173
163
  predictor = Predictor(configuration)
174
164
  self.__cache[tag] = predictor
@@ -227,12 +217,7 @@ class PredictionService:
227
217
  )
228
218
  return prediction
229
219
 
230
- def __download_resource (
231
- self,
232
- resource: PredictionResource,
233
- *,
234
- progress: Progress
235
- ) -> Path:
220
+ def __download_resource (self, resource: PredictionResource) -> Path:
236
221
  path = self.__get_resource_path(resource)
237
222
  if path.exists():
238
223
  return path
@@ -241,12 +226,26 @@ class PredictionService:
241
226
  response.raise_for_status()
242
227
  size = int(response.headers.get("content-length", 0))
243
228
  stem = Path(urlparse(resource.url).path).name
244
- task = progress.add_task(f"Downloading", filename=stem, total=size)
245
- with open(path, "wb") as fp:
229
+ completed = 0
230
+ color = "dark_orange" if not resource.type == "dso" else "purple"
231
+ with (
232
+ CustomProgressTask(
233
+ loading_text=f"[{color}]{stem}[/{color}]",
234
+ columns=[
235
+ BarColumn(),
236
+ DownloadColumn(),
237
+ TransferSpeedColumn(),
238
+ TimeRemainingColumn()
239
+ ]
240
+ ) as task,
241
+ NamedTemporaryFile(mode="wb", delete=False) as tmp_file
242
+ ):
246
243
  for chunk in response.iter_content(chunk_size=8192):
247
244
  if chunk:
248
- fp.write(chunk)
249
- progress.update(task, advance=len(chunk))
245
+ tmp_file.write(chunk)
246
+ completed += len(chunk)
247
+ task.update(total=size, completed=completed)
248
+ Path(tmp_file.name).replace(path)
250
249
  return path
251
250
 
252
251
  def __get_resource_path (self, resource: PredictionResource) -> Path:
fxn/types/predictor.py CHANGED
@@ -21,7 +21,7 @@ class PredictorStatus (str, Enum):
21
21
  """
22
22
  Predictor status.
23
23
  """
24
- Provisioning = "PROVISIONING"
24
+ Compiling = "COMPILING"
25
25
  Active = "ACTIVE"
26
26
  Invalid = "INVALID"
27
27
  Archived = "ARCHIVED"
fxn/version.py CHANGED
@@ -3,4 +3,4 @@
3
3
  # Copyright © 2025 NatML Inc. All Rights Reserved.
4
4
  #
5
5
 
6
- __version__ = "0.0.42"
6
+ __version__ = "0.0.43"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: fxn
3
- Version: 0.0.42
3
+ Version: 0.0.43
4
4
  Summary: Run prediction functions locally in Python. Register at https://fxn.ai.
5
5
  Author-email: "NatML Inc." <hi@fxn.ai>
6
6
  License: Apache License
@@ -1,7 +1,10 @@
1
- fxn/__init__.py,sha256=co_h8b85Oz3gyrT8T30RbwFQZ8Zf0i57ynGC2H44VZs,207
2
- fxn/client.py,sha256=i56oIr4f4ODjKhnD2EugBWQULVubIIFHqqBsoix23xw,1666
1
+ fxn/__init__.py,sha256=62Jg67pac85KHx1pd0ipJqTZmdd8bOrzCHw7qrSThVw,242
2
+ fxn/client.py,sha256=Deje8eiS1VOHX85tQnV34viv2CPVx2ljwHSbyVB5Z1o,3790
3
+ fxn/compile.py,sha256=v5ZiXZoBmurdbPPIIPyh6uLA3dgDRQK2epKmP31VNjk,2800
3
4
  fxn/function.py,sha256=JBZnGaLedgoY-wV5Tg-Rxwpd3tK_ADkD0Qve45NBaA0,1342
4
- fxn/version.py,sha256=-M_EgT0O4ScSm10y5QibgzW8y3KbHfvkosDKMCfirOQ,95
5
+ fxn/logging.py,sha256=sDEoRiCoVOWwHaoTep5820dS5IAVt0YRL9UI58lyAx0,4728
6
+ fxn/sandbox.py,sha256=w2dnHMBaKOERxFMpeAP11X6_SPqcvnpd6SmX6b_FOYQ,7000
7
+ fxn/version.py,sha256=2DgGAdo8M0eRSYqVhukqPY23jLiNk1bQDMfLzUqoxmg,95
5
8
  fxn/beta/__init__.py,sha256=gKoDhuXtXCjdhUYUqmF0gDPMhJfg3UwFgbvMtRB5ipo,111
6
9
  fxn/beta/client.py,sha256=x1AVAz0PlX0Wnemi7KXk5XfC8R6HC7vOQaGUviPJLN8,363
7
10
  fxn/beta/prediction.py,sha256=9DTBahNF6m0TicLab2o9e8IKpiSV6K7cUSTYaFju0ZU,356
@@ -14,15 +17,12 @@ fxn/c/prediction.py,sha256=-d-5yreFAaRS-nDHzhfabRNtgYcmJGiY_N2dt09gk84,2689
14
17
  fxn/c/predictor.py,sha256=48poLj1AthzCgU9n6Wv9gL8o4gFucIlOnBO2wdor6r0,1925
15
18
  fxn/c/stream.py,sha256=Y1Xv1Bt3_qlnWg9rCn7NWESpouF1eKMzDiQjhZWbXTg,1105
16
19
  fxn/c/value.py,sha256=zkmuKb153UUEnenPO560gXFzxJ_ATvs8_HM-j3OLrJU,7362
17
- fxn/cli/__init__.py,sha256=OBwaKLyHBqUUqwJD6waGzRRosMwbisLxCPgki0Ye_lU,1126
20
+ fxn/cli/__init__.py,sha256=6HY2vwoR6W6bT3tXYyeVlp1Ap4F_8BdS0APomiwTG98,1303
18
21
  fxn/cli/auth.py,sha256=6iGbNbjxfCr8OZT3_neLThXdWeKRBZATwru8vU0XmRw,1688
22
+ fxn/cli/compile.py,sha256=g8u2J58fKL9pyDkMUQe7sX4FMg27npISe0k5tG6GeEY,5307
19
23
  fxn/cli/misc.py,sha256=LcJbCj_GAgtGraTRva2zHHOPpNwI6SOFntRksxwlqvM,843
20
- fxn/cli/predictions.py,sha256=HN7-2BLTgwSn_4LYJQ7Ez9TpTKAZfiEGB62eMF_USaA,3084
21
- fxn/cli/predictors.py,sha256=t4DYwGTw_3z0dNDSqLmGmWeksPExGVPHyhHsxmVZk48,447
22
- fxn/compile/__init__.py,sha256=BXIcV2Ghp8YkfYZyhHuZjjV3BCLSuaoiTgu9YeFb-w0,130
23
- fxn/compile/compile.py,sha256=7PzPEjmHu_mQMoLcRrFOe1TJ11iZMgxcGt4rESqk6fg,3040
24
- fxn/compile/sandbox.py,sha256=3ewg1C3lST0KETxkx1qmqMuqcj6YJHBMdCVq4ne-2l8,5425
25
- fxn/compile/signature.py,sha256=QiB546g5p_MfiGt8hRi5BZ-_cmGaZm8vuYs47m2W-XM,6436
24
+ fxn/cli/predictions.py,sha256=ma7wbsKD5CFCRTU_TtJ8N0nN1fgFX2BZPGG8qm8HlNI,3182
25
+ fxn/cli/predictors.py,sha256=bVQAuBue_Jxb79X85RTCzOerWRRT2Ny1oF5DNYAsx4M,1545
26
26
  fxn/lib/__init__.py,sha256=-w1ikmmki5NMpzJjERW-O4SwOfBNkimej_0jL8ujYRk,71
27
27
  fxn/lib/linux/arm64/libFunction.so,sha256=NU9PEuQNObqtWPr5vXrWeQuYhzBfmX_Z4guaregFjrI,207632
28
28
  fxn/lib/linux/x86_64/libFunction.so,sha256=qZNlczayaaHIP_tJ9eeZ1TVpV1Os-ztvSWOoBuY9yWE,236272
@@ -31,17 +31,17 @@ fxn/lib/macos/x86_64/Function.dylib,sha256=qIu4dhx0Xk5dQHgTnZTcm2IpoMYJwRPmKRi9J
31
31
  fxn/lib/windows/arm64/Function.dll,sha256=FyL-oipK9wSxXdbD9frc8QFbUKTPMCdtmCkCT8ooIIM,419328
32
32
  fxn/lib/windows/x86_64/Function.dll,sha256=iL6w1FwDgBkHlNhQmhE7XgfoeHsiYQgpVGzeGDdHGUw,454656
33
33
  fxn/services/__init__.py,sha256=Bif8IttwJ089mSRsd3MFdob7z2eF-MKigKu4ZQFZBCQ,190
34
- fxn/services/prediction.py,sha256=IKudi7n3bm-RAW26dg188ItDgGnJcW4rcN48LD37CEg,10185
34
+ fxn/services/prediction.py,sha256=zWk1Y35m1a0849xVammTaxrkOfSucB2qEhG-r9cUgTQ,10243
35
35
  fxn/services/predictor.py,sha256=Wl_7YKiD5mTpC5x2Zaq4BpatRjwRUX8Th9GIrwd38MA,791
36
36
  fxn/services/user.py,sha256=ADl5MFLsk4K0altgKHnI-i64E3g1wU3e56Noq_ciRuk,685
37
37
  fxn/types/__init__.py,sha256=MEg71rzbGgoWfgB4Yi5QvxbnovHTZRIzCUZLtWtWP1E,292
38
38
  fxn/types/dtype.py,sha256=b0V91aknED2Ql0_BOG_vEg__YJVJByxsgB0cR4rtotE,617
39
39
  fxn/types/prediction.py,sha256=YVnRcqm6IPEx0796OuT3dqn_jOPjhWblzuM2lkk-Vzo,2173
40
- fxn/types/predictor.py,sha256=51hhb1rCYFt_r86pbOIVeV_tXYE6BVhcNP27_xmQG1Q,4006
40
+ fxn/types/predictor.py,sha256=KRGZEuDt7WPMCyRcZvQq4y2FMocfVrLEUNJCJgfDY9Y,4000
41
41
  fxn/types/user.py,sha256=Z44TwEocyxSrfKyzcNfmAXUrpX_Ry8fJ7MffSxRn4oU,1071
42
- fxn-0.0.42.dist-info/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
43
- fxn-0.0.42.dist-info/METADATA,sha256=RM5c7vePEj9cP3OSysoHclvklaV5-byVneeQnhVSKuU,16122
44
- fxn-0.0.42.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
45
- fxn-0.0.42.dist-info/entry_points.txt,sha256=O_AwD5dYaeB-YT1F9hPAPuDYCkw_W0tdNGYbc5RVR2k,45
46
- fxn-0.0.42.dist-info/top_level.txt,sha256=1ULIEGrnMlhId8nYAkjmRn9g3KEFuHKboq193SEKQkA,4
47
- fxn-0.0.42.dist-info/RECORD,,
42
+ fxn-0.0.43.dist-info/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
43
+ fxn-0.0.43.dist-info/METADATA,sha256=u2o6KhRYg6lkAeKgHVBnT1-x9pp-0uUK5xtma4CI8rg,16122
44
+ fxn-0.0.43.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
45
+ fxn-0.0.43.dist-info/entry_points.txt,sha256=O_AwD5dYaeB-YT1F9hPAPuDYCkw_W0tdNGYbc5RVR2k,45
46
+ fxn-0.0.43.dist-info/top_level.txt,sha256=1ULIEGrnMlhId8nYAkjmRn9g3KEFuHKboq193SEKQkA,4
47
+ fxn-0.0.43.dist-info/RECORD,,
fxn/compile/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- #
2
- # Function
3
- # Copyright © 2025 NatML Inc. All Rights Reserved.
4
- #
5
-
6
- from .compile import compile
7
- from .sandbox import Sandbox
fxn/compile/signature.py DELETED
@@ -1,183 +0,0 @@
1
- #
2
- # Function
3
- # Copyright © 2025 NatML Inc. All Rights Reserved.
4
- #
5
-
6
- from collections.abc import Mapping, Sequence
7
- from enum import Enum
8
- from inspect import isasyncgenfunction, iscoroutinefunction, isgeneratorfunction, signature
9
- from io import BytesIO
10
- import numpy as np
11
- from PIL import Image
12
- from pydantic import BaseModel, TypeAdapter
13
- from typing import get_type_hints, get_origin, get_args, Any, Dict, List, Union
14
-
15
- from ..types import Dtype, EnumerationMember, Parameter, Signature
16
-
17
- class FunctionType (str, Enum):
18
- Coroutine = "ASYNC_FUNCTION"
19
- Function = "FUNCTION"
20
- Generator = "GENERATOR"
21
- AsyncGenerator = "ASYNC_GENERATOR"
22
-
23
- def get_function_type (func) -> FunctionType:
24
- if isasyncgenfunction(func):
25
- return FunctionType.AsyncGenerator
26
- elif iscoroutinefunction(func):
27
- return FunctionType.Coroutine
28
- elif isgeneratorfunction(func):
29
- return FunctionType.Generator
30
- else:
31
- return FunctionType.Function
32
-
33
- def infer_function_signature (func) -> Signature:
34
- inputs = _get_input_parameters(func)
35
- outputs = _get_output_parameters(func)
36
- signature = Signature(inputs=inputs, outputs=outputs)
37
- return signature
38
-
39
- def _get_input_parameters (func) -> list[Parameter]:
40
- sig = signature(func)
41
- type_hints = get_type_hints(func)
42
- parameters = []
43
- for name, param in sig.parameters.items():
44
- param_type = type_hints.get(name)
45
- if param_type is None:
46
- raise TypeError(f"Missing type annotation for parameter '{name}' in function '{func.__name__}'")
47
- dtype = _infer_dtype(param_type)
48
- enumeration = [EnumerationMember(
49
- name=member.name,
50
- value=member.value
51
- ) for member in param_type] if _is_enum_subclass(param_type) else None
52
- value_schema = _get_type_schema(param_type) if dtype in { Dtype.list, Dtype.dict } else None
53
- input_param = Parameter(
54
- name=name,
55
- type=dtype,
56
- description=None,
57
- optional=param.default != param.empty,
58
- range=None,
59
- enumeration=enumeration,
60
- value_schema=value_schema
61
- )
62
- parameters.append(input_param)
63
- return parameters
64
-
65
- def _get_output_parameters (func) -> list[Parameter]:
66
- # Check for return annotation
67
- sig = signature(func)
68
- if sig.return_annotation is sig.empty:
69
- raise TypeError(f"Missing return type annotation for function '{func.__name__}'")
70
- # Gather return types
71
- return_types = []
72
- if _is_tuple_type(sig.return_annotation):
73
- return_types = get_args(sig.return_annotation)
74
- if not return_types or Ellipsis in return_types:
75
- raise TypeError(f"Return type of function '{func.__name__}' must be fully typed with generic type arguments.")
76
- else:
77
- return_types = [sig.return_annotation]
78
- # Create parameters
79
- parameters = [_get_output_parameter(f"output{idx}", output_type) for idx, output_type in enumerate(return_types)]
80
- return parameters
81
-
82
- def _get_output_parameter (name: str, return_type) -> Parameter:
83
- dtype = _infer_dtype(return_type)
84
- enumeration = [EnumerationMember(
85
- name=member.name,
86
- value=member.value
87
- ) for member in return_type] if _is_enum_subclass(return_type) else None
88
- value_schema = _get_type_schema(return_type) if dtype in { Dtype.list, Dtype.dict } else None
89
- parameter = Parameter(
90
- name=name,
91
- type=dtype,
92
- description=None,
93
- optional=False,
94
- range=None,
95
- enumeration=enumeration,
96
- value_schema=value_schema
97
- )
98
- return parameter
99
-
100
- def _infer_dtype (param_type) -> Dtype:
101
- param_type = _strip_optional(param_type)
102
- origin = get_origin(param_type)
103
- args = get_args(param_type)
104
- if origin is None:
105
- if param_type is np.ndarray:
106
- return Dtype.float32
107
- elif param_type is Image.Image:
108
- return Dtype.image
109
- elif param_type in { bytes, bytearray, memoryview, BytesIO }:
110
- return Dtype.binary
111
- elif param_type is int:
112
- return Dtype.int32
113
- elif param_type is float:
114
- return Dtype.float32
115
- elif param_type is bool:
116
- return Dtype.bool
117
- elif param_type is str:
118
- return Dtype.string
119
- elif _is_enum_subclass(param_type):
120
- return Dtype.string
121
- elif param_type is list:
122
- return Dtype.list
123
- elif param_type is dict:
124
- return Dtype.dict
125
- elif _is_pydantic_model(param_type):
126
- return Dtype.dict
127
- else:
128
- raise TypeError(f"Unsupported parameter type: {param_type}")
129
- else:
130
- if origin in { list, List, Sequence }:
131
- return Dtype.list
132
- elif origin in { dict, Dict, Mapping }:
133
- return Dtype.dict
134
- elif origin is np.ndarray:
135
- if args:
136
- dtype_arg = args[0]
137
- dtype = _numpy_to_fxn_dtype(dtype_arg)
138
- if dtype is not None:
139
- return dtype
140
- return Dtype.float32
141
- else:
142
- raise TypeError(f"Unsupported parameter type: {param_type}")
143
-
144
- def _is_enum_subclass (cls) -> bool:
145
- return isinstance(cls, type) and issubclass(cls, Enum)
146
-
147
- def _is_pydantic_model (cls) -> bool:
148
- return isinstance(cls, type) and issubclass(cls, BaseModel)
149
-
150
- def _is_tuple_type (param_type) -> bool:
151
- origin = get_origin(param_type)
152
- return origin is tuple
153
-
154
- def _strip_optional (param_type):
155
- if get_origin(param_type) is Union:
156
- args = get_args(param_type)
157
- non_none_args = [arg for arg in args if arg is not type(None)]
158
- if len(non_none_args) == 1:
159
- return non_none_args[0]
160
- return param_type
161
-
162
- def _numpy_to_fxn_dtype (dtype) -> Dtype | None:
163
- dtype_mapping = {
164
- np.int8: Dtype.int8,
165
- np.int16: Dtype.int16,
166
- np.int32: Dtype.int32,
167
- np.int64: Dtype.int64,
168
- np.uint8: Dtype.uint8,
169
- np.uint16: Dtype.uint16,
170
- np.uint32: Dtype.uint32,
171
- np.uint64: Dtype.uint64,
172
- np.float16: Dtype.float16,
173
- np.float32: Dtype.float32,
174
- np.float64: Dtype.float64,
175
- np.bool_: Dtype.bool,
176
- }
177
- return dtype_mapping.get(dtype, None)
178
-
179
- def _get_type_schema (param_type) -> dict[str, Any] | None:
180
- try:
181
- return TypeAdapter(param_type).json_schema(mode="serialization")
182
- except Exception:
183
- return None
File without changes
File without changes