fxn 0.0.41__tar.gz → 0.0.42__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fxn-0.0.41 → fxn-0.0.42}/PKG-INFO +2 -2
- fxn-0.0.42/fxn/__init__.py +10 -0
- fxn-0.0.42/fxn/beta/__init__.py +6 -0
- fxn-0.0.42/fxn/beta/client.py +16 -0
- fxn-0.0.42/fxn/beta/prediction.py +16 -0
- fxn-0.0.42/fxn/beta/remote.py +207 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/__init__.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/configuration.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/fxnc.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/map.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/prediction.py +2 -2
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/predictor.py +2 -3
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/stream.py +2 -3
- {fxn-0.0.41 → fxn-0.0.42}/fxn/c/value.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/cli/__init__.py +8 -10
- {fxn-0.0.41 → fxn-0.0.42}/fxn/cli/auth.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/cli/misc.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/cli/predictions.py +1 -1
- fxn-0.0.42/fxn/cli/predictors.py +18 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/client.py +23 -11
- fxn-0.0.42/fxn/compile/__init__.py +7 -0
- fxn-0.0.42/fxn/compile/compile.py +80 -0
- fxn-0.0.42/fxn/compile/sandbox.py +177 -0
- fxn-0.0.42/fxn/compile/signature.py +183 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/function.py +6 -2
- fxn-0.0.42/fxn/lib/__init__.py +4 -0
- fxn-0.0.42/fxn/lib/linux/arm64/libFunction.so +0 -0
- fxn-0.0.42/fxn/lib/linux/x86_64/libFunction.so +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/lib/macos/arm64/Function.dylib +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/lib/macos/x86_64/Function.dylib +0 -0
- fxn-0.0.42/fxn/lib/windows/arm64/Function.dll +0 -0
- fxn-0.0.42/fxn/lib/windows/x86_64/Function.dll +0 -0
- fxn-0.0.42/fxn/services/__init__.py +8 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/services/prediction.py +5 -4
- {fxn-0.0.41 → fxn-0.0.42}/fxn/services/predictor.py +6 -3
- {fxn-0.0.41 → fxn-0.0.42}/fxn/services/user.py +6 -3
- fxn-0.0.42/fxn/types/__init__.py +9 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn/types/dtype.py +1 -1
- {fxn-0.0.41 → fxn-0.0.42}/fxn/types/prediction.py +12 -2
- {fxn-0.0.41 → fxn-0.0.42}/fxn/types/predictor.py +2 -13
- {fxn-0.0.41 → fxn-0.0.42}/fxn/types/user.py +1 -1
- fxn-0.0.42/fxn/version.py +6 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/PKG-INFO +2 -2
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/SOURCES.txt +8 -1
- {fxn-0.0.41 → fxn-0.0.42}/pyproject.toml +1 -1
- fxn-0.0.41/fxn/__init__.py +0 -8
- fxn-0.0.41/fxn/cli/env.py +0 -40
- fxn-0.0.41/fxn/cli/predictors.py +0 -66
- fxn-0.0.41/fxn/lib/__init__.py +0 -4
- fxn-0.0.41/fxn/lib/linux/arm64/libFunction.so +0 -0
- fxn-0.0.41/fxn/lib/linux/x86_64/libFunction.so +0 -0
- fxn-0.0.41/fxn/lib/windows/arm64/Function.dll +0 -0
- fxn-0.0.41/fxn/lib/windows/x86_64/Function.dll +0 -0
- fxn-0.0.41/fxn/services/__init__.py +0 -8
- fxn-0.0.41/fxn/types/__init__.py +0 -9
- fxn-0.0.41/fxn/version.py +0 -6
- {fxn-0.0.41 → fxn-0.0.42}/LICENSE +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/README.md +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/dependency_links.txt +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/entry_points.txt +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/requires.txt +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/fxn.egg-info/top_level.txt +0 -0
- {fxn-0.0.41 → fxn-0.0.42}/setup.cfg +0 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
#
|
2
|
+
# Function
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
|
6
|
+
from ..client import FunctionClient
|
7
|
+
from .prediction import PredictionService
|
8
|
+
|
9
|
+
class BetaClient:
    """
    Entry point for incubating (beta) features.

    Members:
        predictions (PredictionService): Prediction service built from the client given to the constructor.
    """
    predictions: PredictionService

    def __init__ (self, client: FunctionClient):
        """
        Create the beta client.

        Parameters:
            client (FunctionClient): Client handed to the prediction service.
        """
        prediction_service = PredictionService(client)
        self.predictions = prediction_service
|
@@ -0,0 +1,16 @@
|
|
1
|
+
#
|
2
|
+
# Function
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
|
6
|
+
from ..client import FunctionClient
|
7
|
+
from .remote import RemotePredictionService
|
8
|
+
|
9
|
+
class PredictionService:
    """
    Make predictions (beta).

    Members:
        remote (RemotePredictionService): Remote prediction service built from the client given to the constructor.
    """
    remote: RemotePredictionService

    def __init__ (self, client: FunctionClient):
        """
        Create the prediction service.

        Parameters:
            client (FunctionClient): Client handed to the remote prediction service.
        """
        remote_service = RemotePredictionService(client)
        self.remote = remote_service
|
@@ -0,0 +1,207 @@
|
|
1
|
+
#
|
2
|
+
# Function
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
from base64 import b64encode
|
8
|
+
from dataclasses import asdict, is_dataclass
|
9
|
+
from enum import Enum
|
10
|
+
from io import BytesIO
|
11
|
+
from json import dumps, loads
|
12
|
+
from numpy import array, frombuffer, ndarray
|
13
|
+
from PIL import Image
|
14
|
+
from pydantic import BaseModel, Field
|
15
|
+
from requests import get, put
|
16
|
+
from typing import Any
|
17
|
+
from urllib.request import urlopen
|
18
|
+
|
19
|
+
from ..c import Configuration
|
20
|
+
from ..client import FunctionClient
|
21
|
+
from ..services import Value
|
22
|
+
from ..types import Dtype, Prediction
|
23
|
+
|
24
|
+
class RemoteAcceleration (str, Enum):
    """
    Acceleration options for a remote prediction.

    Every member is also a `str`, so it compares equal to its lowercase string value.
    """
    # Default used by `RemotePredictionService.create` when no acceleration is given.
    Auto = "auto"
    # Explicit acceleration choices.
    CPU = "cpu"
    A40 = "a40"
    A100 = "a100"
|
32
|
+
|
33
|
+
class RemotePredictionService:
    """
    Make remote predictions.

    Input values are serialized into `RemoteValue` descriptors before the request is made,
    and result values are deserialized back into Python objects once the response arrives.
    """

    def __init__ (self, client: FunctionClient):
        """
        Create the remote prediction service.

        Parameters:
            client (FunctionClient): Client used to make API requests.
        """
        self.client = client

    def create (
        self,
        tag: str,
        *,
        inputs: dict[str, Value],
        acceleration: RemoteAcceleration=RemoteAcceleration.Auto
    ) -> Prediction:
        """
        Create a remote prediction.

        Parameters:
            tag (str): Predictor tag.
            inputs (dict): Input values.
            acceleration (RemoteAcceleration): Prediction acceleration.

        Returns:
            Prediction: Created prediction.

        Raises:
            ValueError: If an input value has a type that cannot be serialized.
        """
        # Serialize every input into a JSON-compatible `RemoteValue` dictionary
        input_map = { name: self.__to_value(value, name=name).model_dump(mode="json") for name, value in inputs.items() }
        prediction = self.client.request(
            method="POST",
            path="/predictions/remote",
            body={
                "tag": tag,
                "inputs": input_map,
                "acceleration": acceleration,
                "clientId": Configuration.get_client_id()
            },
            response_type=RemotePrediction
        )
        # Convert the remote result descriptors back into Python objects
        results = list(map(self.__to_object, prediction.results)) if prediction.results is not None else None
        prediction = Prediction(**{ **prediction.model_dump(), "results": results })
        return prediction

    def __to_value (
        self,
        object: Value,
        *,
        name: str,
        max_data_url_size: int=4 * 1024 * 1024
    ) -> RemoteValue:
        """
        Serialize a Python object into a remote value.

        Parameters:
            object (Value): Object to serialize.
            name (str): Value name, used when the data has to be uploaded.
            max_data_url_size (int): Largest payload size in bytes that is inlined as a data URL.

        Returns:
            RemoteValue: Serialized value.

        Raises:
            ValueError: If the object type is not supported.
        """
        object = self.__try_ensure_serializable(object)
        if object is None:
            return RemoteValue(data=None, type=Dtype.null)
        # Scalars are wrapped in zero-dimensional arrays and re-dispatched to the array branch
        elif isinstance(object, float):
            object = array(object, dtype=Dtype.float32)
            return self.__to_value(object, name=name, max_data_url_size=max_data_url_size)
        # `bool` must be tested before `int` because `bool` is a subclass of `int`
        elif isinstance(object, bool):
            object = array(object, dtype=Dtype.bool)
            return self.__to_value(object, name=name, max_data_url_size=max_data_url_size)
        elif isinstance(object, int):
            object = array(object, dtype=Dtype.int32)
            return self.__to_value(object, name=name, max_data_url_size=max_data_url_size)
        elif isinstance(object, ndarray):
            buffer = BytesIO(object.tobytes())
            data = self.__upload(buffer, name=name, max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=object.dtype.name, shape=list(object.shape))
        elif isinstance(object, str):
            buffer = BytesIO(object.encode())
            data = self.__upload(buffer, name=name, mime="text/plain", max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=Dtype.string)
        elif isinstance(object, list):
            buffer = BytesIO(dumps(object).encode())
            data = self.__upload(buffer, name=name, mime="application/json", max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=Dtype.list)
        elif isinstance(object, dict):
            buffer = BytesIO(dumps(object).encode())
            data = self.__upload(buffer, name=name, mime="application/json", max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=Dtype.dict)
        elif isinstance(object, Image.Image):
            buffer = BytesIO()
            # RGBA images are encoded as PNG, every other mode as JPEG
            format = "PNG" if object.mode == "RGBA" else "JPEG"
            mime = f"image/{format.lower()}"
            object.save(buffer, format=format)
            data = self.__upload(buffer, name=name, mime=mime, max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=Dtype.image)
        elif isinstance(object, BytesIO):
            data = self.__upload(object, name=name, max_data_url_size=max_data_url_size)
            return RemoteValue(data=data, type=Dtype.binary)
        else:
            raise ValueError(f"Failed to serialize value '{object}' of type `{type(object)}` because it is not supported")

    def __to_object (self, value: RemoteValue) -> Value:
        """
        Deserialize a remote value into a Python object.

        Parameters:
            value (RemoteValue): Remote value to deserialize.

        Returns:
            Value: Deserialized object. Array values with an empty shape are returned as Python scalars.

        Raises:
            ValueError: If the value type is not supported.
        """
        if value.type == Dtype.null:
            return None
        buffer = self.__download(value.data)
        if value.type in [
            Dtype.int8, Dtype.int16, Dtype.int32, Dtype.int64,
            Dtype.uint8, Dtype.uint16, Dtype.uint32, Dtype.uint64,
            Dtype.float16, Dtype.float32, Dtype.float64, Dtype.bool
        ]:
            assert value.shape is not None, "Array value must have a shape specified"
            # Named `tensor` so the local does not shadow the `array` function imported from numpy
            tensor = frombuffer(buffer.getbuffer(), dtype=value.type).reshape(value.shape)
            return tensor if len(value.shape) > 0 else tensor.item()
        elif value.type == Dtype.string:
            return buffer.getvalue().decode("utf-8")
        elif value.type in [Dtype.list, Dtype.dict]:
            return loads(buffer.getvalue().decode("utf-8"))
        elif value.type == Dtype.image:
            return Image.open(buffer)
        elif value.type == Dtype.binary:
            return buffer
        else:
            raise ValueError(f"Failed to deserialize value with type `{value.type}` because it is not supported")

    def __upload (
        self,
        data: BytesIO,
        *,
        name: str,
        mime: str="application/octet-stream",
        max_data_url_size: int=4 * 1024 * 1024
    ) -> str:
        """
        Make value data available at a URL.

        Payloads of at most `max_data_url_size` bytes are inlined as a base64 data URL.
        Larger payloads are uploaded to a URL obtained from the API.

        Parameters:
            data (BytesIO): Value data. The whole buffer is used regardless of its current stream position.
            name (str): Value name sent to the API when an upload is required.
            mime (str): MIME type of the data.
            max_data_url_size (int): Largest payload size in bytes that is inlined as a data URL.

        Returns:
            str: Data URL or download URL for the value.
        """
        if data.getbuffer().nbytes <= max_data_url_size:
            encoded_data = b64encode(data.getvalue()).decode("ascii")
            return f"data:{mime};base64,{encoded_data}"
        value = self.client.request(
            method="POST",
            path="/values",
            body={ "name": name },
            response_type=CreateValueResponse
        )
        # Upload the full buffer contents. Passing the stream object itself would send only
        # what follows its current position, which is nothing for a buffer that was just written.
        put(
            value.upload_url,
            data=data.getvalue(),
            headers={ "Content-Type": mime }
        ).raise_for_status()
        return value.download_url

    def __download (self, url: str) -> BytesIO:
        """
        Download value data from a URL.

        Parameters:
            url (str): Data URL or remote URL.

        Returns:
            BytesIO: Downloaded data.
        """
        # Data URLs are decoded locally by `urlopen`
        if url.startswith("data:"):
            with urlopen(url) as response:
                return BytesIO(response.read())
        response = get(url)
        response.raise_for_status()
        result = BytesIO(response.content)
        return result

    @classmethod
    def __try_ensure_serializable (cls, object: Any) -> Any:
        """
        Convert dataclass instances and Pydantic models into JSON-serializable objects.

        Lists are converted element by element. Any other object is returned unchanged.

        Parameters:
            object (Any): Object to convert.

        Returns:
            Any: Converted object.
        """
        if object is None:
            return object
        if isinstance(object, list):
            return [cls.__try_ensure_serializable(x) for x in object]
        # `is_dataclass` is also true for dataclass types, so restrict this to instances
        if is_dataclass(object) and not isinstance(object, type):
            return asdict(object)
        if isinstance(object, BaseModel):
            return object.model_dump(mode="json", by_alias=True)
        return object
|
190
|
+
|
191
|
+
class RemoteValue (BaseModel):
    """
    Serialized value exchanged with the remote prediction API.
    """
    # Data URL or download URL of the value data. `None` for null values.
    data: str | None
    # Value data type.
    type: Dtype
    # Array shape. Set when serializing arrays and required when deserializing array types.
    shape: list[int] | None = None
|
195
|
+
|
196
|
+
class RemotePrediction (BaseModel):
    """
    Response model for a remote prediction request.

    `RemotePredictionService.create` copies these fields into a `Prediction`
    after replacing `results` with deserialized objects.
    """
    id: str
    tag: str
    created: str
    # Serialized result values. Can be `None`.
    results: list[RemoteValue] | None
    latency: float | None
    error: str | None
    logs: str | None
|
204
|
+
|
205
|
+
class CreateValueResponse (BaseModel):
    """
    Response model for the `POST /values` request made when uploading value data.
    """
    # URL that the value data is PUT to. Validated from the `uploadUrl` key.
    upload_url: str = Field(validation_alias="uploadUrl")
    # URL returned as the value data location after upload. Validated from the `downloadUrl` key.
    download_url: str = Field(validation_alias="downloadUrl")
|
@@ -1,9 +1,9 @@
|
|
1
1
|
#
|
2
2
|
# Function
|
3
|
-
# Copyright ©
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
4
|
#
|
5
5
|
|
6
|
-
from ctypes import byref, c_double,
|
6
|
+
from ctypes import byref, c_double, c_int32, c_void_p, create_string_buffer
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import final
|
9
9
|
|
@@ -1,10 +1,9 @@
|
|
1
1
|
#
|
2
2
|
# Function
|
3
|
-
# Copyright ©
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
4
|
#
|
5
5
|
|
6
|
-
from ctypes import byref,
|
7
|
-
from pathlib import Path
|
6
|
+
from ctypes import byref, c_void_p
|
8
7
|
from typing import final
|
9
8
|
|
10
9
|
from .configuration import Configuration
|
@@ -1,10 +1,9 @@
|
|
1
1
|
#
|
2
2
|
# Function
|
3
|
-
# Copyright ©
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
4
|
#
|
5
5
|
|
6
|
-
from ctypes import byref,
|
7
|
-
from pathlib import Path
|
6
|
+
from ctypes import byref, c_void_p
|
8
7
|
from typing import final
|
9
8
|
|
10
9
|
from .fxnc import get_fxnc, status_to_error, FXNStatus
|
@@ -1,15 +1,15 @@
|
|
1
1
|
#
|
2
2
|
# Function
|
3
|
-
# Copyright ©
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
4
|
#
|
5
5
|
|
6
6
|
from typer import Typer
|
7
7
|
|
8
8
|
from .auth import app as auth_app
|
9
|
-
from .
|
9
|
+
#from .compile import compile_predictor
|
10
10
|
from .misc import cli_options
|
11
11
|
from .predictions import create_prediction
|
12
|
-
from .predictors import
|
12
|
+
from .predictors import retrieve_predictor
|
13
13
|
from ..version import __version__
|
14
14
|
|
15
15
|
# Define CLI
|
@@ -26,20 +26,18 @@ app.callback()(cli_options)
|
|
26
26
|
|
27
27
|
# Add subcommands
|
28
28
|
app.add_typer(auth_app, name="auth", help="Login, logout, and check your authentication status.")
|
29
|
-
#app.add_typer(env_app, name="env", help="Manage predictor environment variables.")
|
30
29
|
|
31
30
|
# Add top-level commands
|
32
|
-
#app.command(name="create", help="Create a predictor.")(create_predictor)
|
33
|
-
#app.command(name="delete", help="Delete a predictor.")(delete_predictor)
|
34
31
|
app.command(
|
35
32
|
name="predict",
|
36
33
|
help="Make a prediction.",
|
37
34
|
context_settings={ "allow_extra_args": True, "ignore_unknown_options": True }
|
38
35
|
)(create_prediction)
|
39
|
-
#app.command(
|
40
|
-
#
|
41
|
-
#
|
42
|
-
#
|
36
|
+
# app.command(
|
37
|
+
# name="compile",
|
38
|
+
# help="Create a predictor by compiling a Python function."
|
39
|
+
# )(compile_predictor)
|
40
|
+
app.command(name="retrieve", help="Retrieve a predictor.")(retrieve_predictor)
|
43
41
|
|
44
42
|
# Run
|
45
43
|
if __name__ == "__main__":
|
@@ -0,0 +1,18 @@
|
|
1
|
+
#
|
2
|
+
# Function
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
|
6
|
+
from rich import print_json
|
7
|
+
from typer import Argument
|
8
|
+
|
9
|
+
from ..function import Function
|
10
|
+
from .auth import get_access_key
|
11
|
+
|
12
|
+
def retrieve_predictor (
    tag: str=Argument(..., help="Predictor tag.")
):
    """
    Retrieve a predictor and print it as JSON.

    Prints JSON `null` when the retrieved predictor is falsy.

    Parameters:
        tag (str): Predictor tag.
    """
    client = Function(get_access_key())
    result = client.predictors.retrieve(tag)
    if result:
        payload = result.model_dump()
    else:
        payload = None
    print_json(data=payload)
|
@@ -1,10 +1,14 @@
|
|
1
1
|
#
|
2
2
|
# Function
|
3
|
-
# Copyright ©
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
4
|
#
|
5
5
|
|
6
|
+
from json import loads, JSONDecodeError
|
7
|
+
from pydantic import BaseModel
|
6
8
|
from requests import request
|
7
|
-
from typing import Any, Literal
|
9
|
+
from typing import Any, Literal, Type, TypeVar
|
10
|
+
|
11
|
+
T = TypeVar("T", bound=BaseModel)
|
8
12
|
|
9
13
|
class FunctionClient:
|
10
14
|
|
@@ -17,23 +21,25 @@ class FunctionClient:
|
|
17
21
|
*,
|
18
22
|
method: Literal["GET", "POST", "DELETE"],
|
19
23
|
path: str,
|
20
|
-
body: dict[str, Any]=None
|
21
|
-
|
24
|
+
body: dict[str, Any]=None,
|
25
|
+
response_type: Type[T]=None
|
26
|
+
) -> T:
|
22
27
|
response = request(
|
23
28
|
method=method,
|
24
29
|
url=f"{self.api_url}{path}",
|
25
30
|
json=body,
|
26
31
|
headers={ "Authorization": f"Bearer {self.access_key}" }
|
27
32
|
)
|
28
|
-
data =
|
33
|
+
data = response.text
|
29
34
|
try:
|
30
35
|
data = response.json()
|
31
|
-
except
|
32
|
-
|
33
|
-
if
|
34
|
-
|
36
|
+
except JSONDecodeError:
|
37
|
+
pass
|
38
|
+
if response.ok:
|
39
|
+
return response_type(**data) if response_type is not None else None
|
40
|
+
else:
|
41
|
+
error = _ErrorResponse(**data).errors[0].message if isinstance(data, dict) else data
|
35
42
|
raise FunctionAPIError(error, response.status_code)
|
36
|
-
return data
|
37
43
|
|
38
44
|
class FunctionAPIError (Exception):
|
39
45
|
|
@@ -43,4 +49,10 @@ class FunctionAPIError (Exception):
|
|
43
49
|
self.status_code = status_code
|
44
50
|
|
45
51
|
def __str__(self):
|
46
|
-
return f"FunctionAPIError: {self.message} (Status Code: {self.status_code})"
|
52
|
+
return f"FunctionAPIError: {self.message} (Status Code: {self.status_code})"
|
53
|
+
|
54
|
+
class _APIError (BaseModel):
    """
    Single error entry in an API error response.
    """
    # Error message surfaced through `FunctionAPIError`.
    message: str
|
56
|
+
|
57
|
+
class _ErrorResponse (BaseModel):
    """
    Body of a failed API response when it is a JSON object.
    """
    # Reported errors. `FunctionClient.request` uses the message of the first entry.
    errors: list[_APIError]
|
@@ -0,0 +1,80 @@
|
|
1
|
+
#
|
2
|
+
# Function
|
3
|
+
# Copyright © 2025 NatML Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
|
6
|
+
from collections.abc import Callable
|
7
|
+
from functools import wraps
|
8
|
+
from pathlib import Path
|
9
|
+
from pydantic import BaseModel, Field
|
10
|
+
|
11
|
+
from ..types import AccessMode, Signature
|
12
|
+
from .sandbox import Sandbox
|
13
|
+
from .signature import get_function_type, infer_function_signature, FunctionType
|
14
|
+
|
15
|
+
class PredictorSpec (BaseModel):
    """
    Descriptor of a predictor to be compiled.
    """
    tag: str = Field(description="Predictor tag.")
    # NOTE(review): `max_length=100` accepts exactly 100 characters, while the description says "less than 100" - confirm which is intended.
    description: str = Field(description="Predictor description. MUST be less than 100 characters long.", min_length=4, max_length=100)
    sandbox: Sandbox = Field(description="Sandbox to compile the function.")
    access: AccessMode = Field(description="Predictor access.")
    signature: Signature = Field(description="Predictor signature.")
    # The remaining fields are optional and default to `None`.
    card: str | None = Field(default=None, description="Predictor card (markdown).")
    media: str | None = Field(default=None, description="Predictor media URL.")
    license: str | None = Field(default=None, description="Predictor license URL. This is required for public predictors.")
|
27
|
+
|
28
|
+
def compile (
    tag: str,
    *,
    description: str,
    sandbox: Sandbox=None,
    access: AccessMode=AccessMode.Private,
    card: str | Path=None,
    media: Path=None,
    license: str=None,
):
    """
    Create a predictor by compiling a stateless function.

    Parameters:
        tag (str): Predictor tag.
        description (str): Predictor description. MUST be less than 100 characters long.
        sandbox (Sandbox): Sandbox to compile the function. A default `Sandbox()` is used when omitted.
        access (AccessMode): Predictor access.
        card (str | Path): Predictor card markdown string or path to card.
        media (Path): Predictor thumbnail image (jpeg or png) path. Currently not used.
        license (str): Predictor license URL. This is required for public predictors.

    Returns:
        Callable: Decorator that attaches a `PredictorSpec` to the function it wraps.

    Raises:
        TypeError: If the decorated object is not a regular function or generator.
    """
    def decorator (func: Callable):
        # Check type
        if not callable(func):
            raise TypeError("Cannot compile non-function objects")
        func_type = get_function_type(func)
        if func_type not in { FunctionType.Function, FunctionType.Generator }:
            raise TypeError(f"Function '{func.__name__}' must be a regular function or generator")
        # Gather metadata
        signature = infer_function_signature(func) # throws
        # A `Path` card is read from disk, anything else is used as the card content directly
        card_content = card.read_text() if isinstance(card, Path) else card
        spec = PredictorSpec(
            tag=tag,
            description=description,
            sandbox=sandbox if sandbox is not None else Sandbox(),
            access=access,
            signature=signature,
            card=card_content,
            media=None, # INCOMPLETE
            license=license
        )
        # Wrap
        @wraps(func)
        def wrapper (*args, **kwargs):
            return func(*args, **kwargs)
        wrapper.__predictor_spec = spec
        return wrapper
    return decorator
|