fnnx 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fnnx/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1"
fnnx/device.py ADDED
@@ -0,0 +1,14 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class DeviceConfig:
6
+ accelerator: str
7
+ device_config: dict | None
8
+
9
+
10
+ @dataclass
11
+ class DeviceMap:
12
+ accelerator: str
13
+ node_device_map: dict[str, dict]
14
+ variant_device_config: dict | str | None = None
fnnx/dtypes.py ADDED
@@ -0,0 +1,191 @@
1
+ from copy import deepcopy
2
+ import jsonschema
3
+ from copy import copy, deepcopy
4
+
5
+
6
class FlatList:
    """Wrapper around a shallow copy of a list.

    It is deliberately not a ``list`` subclass, so ``isinstance(x, list)``
    checks do not match it.
    """

    def __init__(self, data: list):
        """Store a shallow copy of ``data``.

        Raises:
            ValueError: if ``data`` is not a ``list``.
        """
        if isinstance(data, list):
            self.data = copy(data)
        else:
            raise ValueError("FlatList only accepts lists")

    def __getitem__(self, index):
        """Return the wrapped list's item (or slice) at ``index``."""
        return self.data[index]

    def __setitem__(self, index, value):
        """Assign ``value`` at ``index`` in the wrapped list."""
        self.data[index] = value

    def __repr__(self):
        return "FlatList({})".format(self.data)

    def append(self, value):
        """Append ``value`` to the wrapped list."""
        self.data.append(value)
23
+
24
+
25
class DtypesManager:
    """Holds named dtype schemas and validates data against them."""

    def __init__(self, external_dtypes: dict, builtins: dict):
        """Merge deep copies of both mappings (builtins win on key clashes).

        Raises:
            ValueError: if any dtype name contains ``[`` or equals one of the
                reserved names.
        """
        merged = deepcopy(external_dtypes)
        merged.update(deepcopy(builtins))
        self.dtypes = merged

        for name in self.dtypes:
            if "[" in name:
                raise ValueError(f"Invalid dtype name: {name}")

        reserved_names = (
            "string",
            "integer",
            "float",
            "Array",
            "NDContainer",
            "FlatList",
        )
        for name in reserved_names:
            if name in self.dtypes:
                raise ValueError(f"Invalid dtype name: {name}")

    def get_dtype(self, dtype_name: str):
        """Return the schema registered under ``dtype_name``.

        Raises:
            ValueError: if the name is not registered.
        """
        if dtype_name in self.dtypes:
            return self.dtypes[dtype_name]
        raise ValueError(f"Unknown dtype: {dtype_name}")

    def _validate_dtype(self, dtype_name: str, data: dict):
        """Validate one dict against the registered JSON schema."""
        jsonschema.validate(data, self.get_dtype(dtype_name))

    def validate_dtype(self, dtype_name: str, data):
        """Recursively validate ``data`` against ``dtype_name``.

        Lists and ``FlatList`` instances are walked element by element, dicts
        are checked against the registered schema, and ``str``/``int``/``float``
        values must be declared as ``string``/``integer``/``float``.

        Raises:
            TypeError: on a scalar/dtype mismatch or an unsupported value type.
        """
        if isinstance(data, list):
            for element in data:
                self.validate_dtype(dtype_name, element)
            return
        if isinstance(data, FlatList):
            self.validate_dtype(dtype_name, data.data)
            return
        if isinstance(data, dict):
            self._validate_dtype(dtype_name, data)
            return

        # Order matters: it mirrors the str -> int -> float precedence.
        scalar_kinds = ((str, "string"), (int, "integer"), (float, "float"))
        for python_type, scalar_name in scalar_kinds:
            if isinstance(data, python_type):
                if dtype_name != scalar_name:
                    raise TypeError(
                        f"Invalid data type, expected `{scalar_name}`, got `{dtype_name}`"
                    )
                return
        raise TypeError(f"Invalid data type: {type(data)}")
78
+
79
+
80
+ class NDContainer:
81
+ def __init__(self, data, dtype, dtypes_manager: DtypesManager):
82
+
83
+ if dtype.startswith("Array["):
84
+ raise ValueError("NDContainer does not support Array dtype")
85
+ elif dtype.startswith("NDContainer["):
86
+ dtype = dtype[12:-1]
87
+
88
+ self.data = deepcopy(data if isinstance(data, list) else [data])
89
+
90
+ if dtypes_manager:
91
+ dtypes_manager.validate_dtype(dtype, self.data)
92
+ self.dtypes_manager = dtypes_manager
93
+ self._dtype = dtype
94
+ if "FlatList[" in self._dtype:
95
+ self.data = self._inner_to_flatlist(self.data)
96
+ self.shape = tuple(self._compute_shape(self.data))
97
+
98
+ def _inner_to_flatlist(self, nested_list, root=True):
99
+ if root:
100
+ # check base case
101
+ if all(not isinstance(item, (list, FlatList)) for item in nested_list):
102
+ return [FlatList(nested_list)]
103
+
104
+ for i, item in enumerate(nested_list):
105
+ if isinstance(item, list):
106
+ if all(not isinstance(sub_item, (list, FlatList)) for sub_item in item):
107
+ nested_list[i] = FlatList(item)
108
+ else:
109
+ self._inner_to_flatlist(item, root=False)
110
+ return nested_list
111
+
112
+ def _compute_shape(self, data):
113
+ if not isinstance(data, list) or not data:
114
+ return []
115
+ sub_shape = self._compute_shape(data[0])
116
+ return [len(data)] + sub_shape
117
+
118
+ def __getitem__(self, index):
119
+ if isinstance(index, tuple):
120
+ result = self.data
121
+ for idx in index:
122
+ result = result[idx]
123
+ return result
124
+ return self.data[index]
125
+
126
+ def reshape(self, *new_shape):
127
+ if isinstance(new_shape[0], tuple) or isinstance(new_shape[0], list):
128
+ new_shape = new_shape[0]
129
+ # Check if the total number of elements matches
130
+ if self._product(new_shape) != self._product(self.shape):
131
+ raise ValueError(
132
+ "Cannot reshape array of size {} into shape {}".format(
133
+ self._product(self.shape), new_shape
134
+ )
135
+ )
136
+ flat_list = self.flatten(self.data)
137
+ if self._product(new_shape) != self._product(self.shape):
138
+ raise ValueError(
139
+ "Cannot reshape array of size {} into shape {}".format(
140
+ self._product(self.shape), new_shape
141
+ )
142
+ )
143
+ reshaped = self._reshape_helper(flat_list, list(new_shape))
144
+ return NDContainer(
145
+ data=reshaped, dtype=self.dtype, dtypes_manager=self.dtypes_manager
146
+ )
147
+
148
+ def flatten(self, data: list | None = None):
149
+ return NDContainer(
150
+ data=self._flatten(data or self.data),
151
+ dtype=self.dtype,
152
+ dtypes_manager=self.dtypes_manager,
153
+ )
154
+
155
+ def _flatten(self, data) -> list:
156
+ result = []
157
+ for item in data:
158
+ if isinstance(item, list):
159
+ result.extend(self._flatten(item))
160
+ else:
161
+ result.append(item)
162
+ return result
163
+
164
+ def _reshape_helper(self, flat_list, shape):
165
+ if len(shape) == 1:
166
+ return flat_list[: shape[0]]
167
+ step = self._product(shape[1:])
168
+ return [
169
+ self._reshape_helper(flat_list[i * step : (i + 1) * step], shape[1:])
170
+ for i in range(shape[0])
171
+ ]
172
+
173
+ def _product(self, shape):
174
+ product = 1
175
+ for dim in shape:
176
+ product *= dim
177
+ return product
178
+
179
+ @property
180
+ def dtype(self):
181
+ return self._dtype
182
+
183
+ @dtype.setter
184
+ def dtype(self, value):
185
+ raise AttributeError("Cannot modify immutable attribute dtype")
186
+
187
+ def __repr__(self) -> str:
188
+ return f"NDContainer(shape={self.shape}, dtype={self._dtype}, data={self.data})"
189
+
190
+
191
+ BUILTINS = {}
fnnx/handlers/_base.py ADDED
@@ -0,0 +1,28 @@
1
+ from abc import ABC, abstractmethod
2
+ from fnnx.device import DeviceMap
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class BaseHandlerConfig:
    """Field-less base class for handler configuration dataclasses."""
9
+
10
+
11
class BaseHandler(ABC):
    """Abstract interface that every handler implements."""

    @abstractmethod
    def __init__(
        self,
        model_path: str,
        device_map: DeviceMap,
        handler_config: BaseHandlerConfig | None = None,
        **kwargs,
    ):
        """Set up the handler for the model located at ``model_path``."""

    @abstractmethod
    def compute(self, inputs: dict, dynamic_attributes: dict) -> dict:
        """Run the model synchronously and return its outputs as a dict."""

    @abstractmethod
    async def compute_async(self, inputs: dict, dynamic_attributes: dict) -> dict:
        """Awaitable counterpart of ``compute``."""
fnnx/handlers/_common.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+ import tarfile
3
+ import tempfile
4
+
5
+
6
+ def unpack_model(model_path: str) -> tuple[str, bool]:
7
+ if os.path.isdir(model_path):
8
+ return model_path, False
9
+ with tarfile.open(model_path, "r") as tar:
10
+ tmp_dir = tempfile.mkdtemp(prefix="fnnx_") #
11
+ tar.extractall(tmp_dir, filter="data")
12
+ return tmp_dir, True
fnnx/handlers/local.py ADDED
@@ -0,0 +1,167 @@
1
+ try:
2
+ import numpy as np
3
+ except ImportError:
4
+ np = None
5
+ from os.path import join as pjoin
6
+ from shutil import rmtree
7
+ import atexit
8
+ import json
9
+ from dataclasses import dataclass
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from typing import Type
12
+ from fnnx.device import DeviceMap
13
+ from fnnx.dtypes import DtypesManager, BUILTINS, NDContainer
14
+ from fnnx.variants.pipeline import Pipeline
15
+ from fnnx.handlers._base import BaseHandler, BaseHandlerConfig
16
+ from fnnx.handlers._common import unpack_model
17
+ from fnnx.registry import Registry
18
+ from fnnx.variants.pyfunc import PyFuncVariant
19
+ from fnnx.validators.model_schema import (
20
+ validate_manifest,
21
+ validate_op_instances,
22
+ validate_variant,
23
+ )
24
+ from fnnx.ops._base import BaseOp
25
+ from fnnx.variants._base import BaseVariant
26
+
27
+
28
@dataclass
class LocalHandlerConfig(BaseHandlerConfig):
    """Options for ``LocalHandler``."""

    # max_workers of the thread pool handed to the variant as `executor`.
    n_workers: int = 1
    # max_workers of the thread pool handed to the variant as `op_executor`.
    n_workers_node: int = 1
    # When True, a model extracted to a temporary directory is removed at exit.
    auto_cleanup: bool = True
    # Extra op classes to register, keyed by the name they are registered under.
    extra_ops: dict[str, Type[BaseOp]] | None = None
34
+
35
+
36
class LocalHandler(BaseHandler):
    """Handler that loads a model bundle from disk and runs it in-process."""

    def __init__(
        self,
        model_path: str,
        device_map: DeviceMap,
        handler_config: LocalHandlerConfig | None = None,
        **kwargs,
    ):
        """Unpack the bundle, read its JSON descriptors and build the variant.

        Raises:
            ValueError: if ``device_map`` is not a ``DeviceMap`` or the
                manifest names an unknown variant.
        """
        config = LocalHandlerConfig() if handler_config is None else handler_config

        if not isinstance(device_map, DeviceMap):
            raise ValueError("device_map must be an instance of DeviceMap")

        unpacked_path, is_temporary = unpack_model(model_path)

        self.model_path = unpacked_path
        self.cleanup = config.auto_cleanup and is_temporary

        # should this be done on exit or on delete?
        if self.cleanup:
            # the closure captures the local path, not self, so no reference
            # to the handler is kept alive by atexit
            atexit.register(lambda: _cleanup(unpacked_path))

        def read_json(filename):
            # Load one JSON descriptor from the bundle directory.
            with open(pjoin(self.model_path, filename), "r") as handle:
                return json.load(handle)

        self.manifest = read_json("manifest.json")
        validate_manifest(self.manifest)

        self.input_specs = {
            entry["name"]: entry for entry in self.manifest["inputs"]
        }
        self.output_specs = {
            entry["name"]: entry for entry in self.manifest["outputs"]
        }

        self.ops = read_json("ops.json")
        validate_op_instances(self.ops)

        self.variant_config = read_json("variant_config.json")
        validate_variant(self.manifest["variant"], self.variant_config)

        self.dtypes_manager = DtypesManager(read_json("dtypes.json"), BUILTINS)

        variant = self.manifest.get("variant")

        registry = Registry()
        registry.register_default_ops()
        for op_name, op_cls in (config.extra_ops or {}).items():
            registry.register_op(op_cls, op_name)

        if variant == "pipeline":
            variant_cls = Pipeline
        elif variant == "pyfunc":
            variant_cls = PyFuncVariant
        else:
            raise ValueError(f"Unknown variant: {variant}")

        self.executor = ThreadPoolExecutor(max_workers=config.n_workers)
        self.op_executor = ThreadPoolExecutor(max_workers=config.n_workers_node)
        built_variant = variant_cls(
            self.model_path,
            self.ops,
            self.variant_config,
            registry=registry,
            device_map=device_map,
            dtypes_manager=self.dtypes_manager,
            executor=self.executor,
            op_executor=self.op_executor,
        )
        self.vrt: BaseVariant = built_variant.warmup()

    def _prepare_inputs(self, inputs):
        """Convert raw inputs to the runtime types declared in the manifest.

        ``NDContainer[...]`` inputs are wrapped in ``NDContainer`` unless they
        already are one; ``Array[...]`` inputs become numpy arrays.

        Raises:
            ValueError: on an unknown content type or dtype.
            RuntimeError: if an ``Array`` dtype is requested without numpy.
        """
        prepared = {}
        for name, value in inputs.items():
            spec = self.input_specs[name]
            if spec["content_type"] != "NDJSON":
                raise ValueError(f"Unknown input type {spec['content_type']}")

            if "NDContainer[" in spec["dtype"]:
                if isinstance(value, NDContainer):
                    prepared[name] = value
                else:
                    prepared[name] = NDContainer(
                        value,
                        dtype=spec["dtype"],
                        dtypes_manager=self.dtypes_manager,
                    )
            elif "Array[" in spec["dtype"]:
                if np is None:
                    raise RuntimeError(
                        "You must have numpy installed to use Array dtype"
                    )
                # Strip the 6-character "Array[" prefix and the closing "]".
                element_dtype = spec["dtype"][6:-1]
                target = np.str_ if element_dtype == "string" else element_dtype
                prepared[name] = np.asarray(value).astype(target)
            else:
                raise ValueError(f"Unknown dtype {spec['dtype']}")
        return prepared

    def _prepare_outputs(self, outputs: dict) -> dict:
        """Keep only the outputs declared in the manifest, in manifest order."""
        return {name: outputs[name] for name in self.output_specs}

    def compute(self, inputs: dict, dynamic_attributes: dict) -> dict:
        """Run the variant synchronously on the prepared inputs."""
        prepared = self._prepare_inputs(inputs)
        raw = self.vrt.compute(prepared, dynamic_attributes=dynamic_attributes)
        return self._prepare_outputs(raw)

    async def compute_async(self, inputs: dict, dynamic_attributes: dict) -> dict:
        """Awaitable counterpart of ``compute``."""
        prepared = self._prepare_inputs(inputs)
        raw = await self.vrt.compute_async(
            prepared, dynamic_attributes=dynamic_attributes
        )
        return self._prepare_outputs(raw)

    def __del__(self):
        # Best-effort shutdown: the executors may not exist if __init__ failed.
        try:
            self.executor.shutdown()
            self.op_executor.shutdown()
        except Exception:
            pass
163
+
164
+
165
+ def _cleanup(model_path):
166
+ # print("Cleaning up temporary model files at", model_path)
167
+ rmtree(model_path)
fnnx/node_instance.py ADDED
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass
2
+ from fnnx.ops._base import BaseOp
3
+ from typing import TypedDict
4
+
5
+
6
+ class IO(TypedDict):
7
+ dtype: str
8
+ shape: list[int | str]
9
+
10
+
11
@dataclass
class OpInstance:
    """An operator bundled with its input and output specs."""

    # The operator object.
    operator: BaseOp
    # Spec of what the operator consumes.
    input_specs: IO
    # Spec of what the operator produces.
    output_specs: IO
fnnx/ops/_base.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+ from fnnx.device import DeviceConfig
6
+ from fnnx.dtypes import DtypesManager
7
+ from concurrent.futures._base import Executor
8
+
9
+
10
class BaseOp(ABC):
    """Abstract base class for operators.

    Subclasses implement ``warmup``, ``compute`` and ``compute_async``.
    """

    # NOTE: class-level lists are shared between classes that do not override
    # them; override in subclasses instead of mutating in place.
    supported_dynamic_attributes: list[str] = []
    required_dynamic_attributes: list[str] = []

    def __init__(
        self,
        artifact_path: str,
        *args,
        attributes: dict,
        dynamic_attribute_map: dict,
        device_config: DeviceConfig,
        input_specs,
        output_specs,
        dtypes_manager: DtypesManager,
        executor: Executor,
        **kwargs,
    ):
        """Store the construction arguments; the op starts out not warmed up."""
        self.dynamic_attribute_map = dynamic_attribute_map
        self._warmed_up = False
        self.artifact_path = artifact_path
        self._device_config: DeviceConfig = device_config
        self.attributes = attributes
        self.input_specs = input_specs
        self.output_specs = output_specs
        self.dtypes_manager = dtypes_manager
        self.executor = executor

    @abstractmethod
    def warmup(self, *args, **kwargs) -> BaseOp:
        """Prepare the op for ``compute`` and return it."""

    @abstractmethod
    def compute(self, inputs: list, dynamic_attributes: dict, **kwargs):
        """Run the op synchronously."""

    @abstractmethod
    async def compute_async(self, inputs: list, dynamic_attributes: dict, **kwargs):
        """Awaitable counterpart of ``compute``."""

    def extract_dynamic_attribute(self, dynamic_attributes: dict):
        """Resolve this op's dynamic attributes from the caller-supplied dict.

        Each entry of ``dynamic_attribute_map`` maps a target key to a dict
        with a source ``name`` and a ``default_value``. The default is used
        only when the source value is missing or ``None``, so explicit falsy
        values such as ``0``, ``False`` or ``""`` are preserved.

        Returns:
            dict mapping each target key to its resolved value.
        """
        extracted = {}
        for key, mapping in self.dynamic_attribute_map.items():
            supplied = dynamic_attributes.get(mapping.get("name"))
            if supplied is None:
                extracted[key] = mapping.get("default_value")
            else:
                extracted[key] = supplied
        return extracted

    def verify_required_dynamic_attributes(self, dynamic_attributes_map: dict):
        """Check that every required dynamic attribute is a key of the map.

        Raises:
            ValueError: naming the first missing attribute.
        """
        for key in self.required_dynamic_attributes:
            if key not in dynamic_attributes_map:
                raise ValueError(f"Missing required dynamic attribute: {key}")
64
+
65
+
66
+ @dataclass
67
+ class OpOutput:
68
+ value: list[Any]
69
+ metadata: dict | None = None
fnnx/ops/onnx.py ADDED
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+ from fnnx.ops._base import BaseOp, OpOutput
3
+ from os.path import join as pjoin
4
+
5
+ try:
6
+ import onnxruntime as ort
7
+ except ImportError:
8
+ ort = None
9
+
10
+ try:
11
+ from onnxruntime_extensions import get_library_path as _get_extensions_library_path # type: ignore
12
+ except ImportError:
13
+ _get_extensions_library_path = None
14
+
15
+ from fnnx.utils import to_thread
16
+ from fnnx.device import DeviceConfig
17
+ from fnnx.dtypes import DtypesManager
18
+ from concurrent.futures._base import Executor
19
+
20
+ CPU_EXECUTION_PROVIDER = "CPUExecutionProvider"
21
+ CUDA_EXECUTION_PROVIDER = "CUDAExecutionProvider"
22
+
23
+
24
class OnnxOp_V1(BaseOp):
    """Op that runs ``model.onnx`` from the artifact directory via onnxruntime."""

    def warmup(
        self,
    ) -> OnnxOp_V1:
        """Create the inference session and cache its input/output names.

        Raises:
            ImportError: if onnxruntime is missing, or if the
                ``use_onnxruntime_extensions`` attribute is set and
                onnxruntime_extensions is missing.
        """
        self.model_path = pjoin(self.artifact_path, "model.onnx")
        if not ort:
            raise ImportError("onnxruntime is not installed")

        # CUDA is tried first with CPU as fallback; otherwise CPU only.
        wants_cuda = self._device_config.accelerator == "cuda"
        providers = (
            [CUDA_EXECUTION_PROVIDER, CPU_EXECUTION_PROVIDER]
            if wants_cuda
            else [CPU_EXECUTION_PROVIDER]
        )

        options = ort.SessionOptions()
        if self.attributes.get("use_onnxruntime_extensions", False):
            if not _get_extensions_library_path:
                raise ImportError("onnxruntime_extensions is not installed")
            options.register_custom_ops_library(_get_extensions_library_path())

        self._sess = ort.InferenceSession(
            self.model_path, providers=providers, sess_options=options
        )
        self._ort_inputs = [node.name for node in self._sess.get_inputs()]
        self._ort_outputs = [node.name for node in self._sess.get_outputs()]
        self._warmed_up = True
        return self

    def compute(self, inputs: list, dynamic_attributes: dict, **kwargs):
        """Run the session, pairing ``inputs`` positionally with input names.

        Raises:
            RuntimeError: if ``warmup`` has not been called.
        """
        if not self._warmed_up:
            raise RuntimeError("Op is not warmed up")
        feed = dict(zip(self._ort_inputs, inputs))
        results = self._sess.run(self._ort_outputs, feed)
        return OpOutput(value=list(results), metadata={})

    async def compute_async(self, inputs: list, dynamic_attributes: dict, **kwargs):
        """Run ``compute`` on the op's executor and await the result."""
        return await to_thread(self.executor, self.compute, inputs, dynamic_attributes)
fnnx/registry.py ADDED
@@ -0,0 +1,23 @@
1
+ from typing import Type
2
+ from fnnx.ops._base import BaseOp
3
+ from fnnx.ops.onnx import OnnxOp_V1
4
+ import warnings
5
+
6
+
7
class Registry:
    """Name-to-class lookup table for ops."""

    def __init__(self):
        self.ops: dict[str, Type[BaseOp]] = {}

    def register_op(self, op: Type[BaseOp], name: str):
        """Register ``op`` under ``name``, replacing any previous entry."""
        self.ops[name] = op

    def get_op(self, name: str) -> Type[BaseOp]:
        """Return the op class registered under ``name`` (KeyError if absent)."""
        return self.ops[name]

    def register_default_ops(self):
        """Register the built-in ops, warning if the registry is not empty."""
        if self.ops:
            warnings.warn(
                "Attempting to register default ops into a non-empty registry."
            )
        self.register_op(OnnxOp_V1, "ONNX_v1")
fnnx/runtime.py ADDED
@@ -0,0 +1,35 @@
1
+ from fnnx.handlers._base import BaseHandler
2
+ from fnnx.handlers.local import LocalHandler, LocalHandlerConfig
3
+ from fnnx.device import DeviceMap
4
+ from typing import Any, Type
5
+
6
+
7
class Runtime:
    """Entry point that wires a model bundle to a handler."""

    def __init__(
        self,
        bundle_path: str,
        handler: Type[BaseHandler] | None = None,
        handler_config: Any = None,
        device_map: str | DeviceMap | None = None,
        cleanup: bool = True,
    ):
        """Create the handler for ``bundle_path``.

        Args:
            bundle_path: path passed to the handler as the model location.
            handler: handler class to instantiate; defaults to ``LocalHandler``.
            handler_config: configuration passed to the handler. With the
                default handler, anything that is not a ``LocalHandlerConfig``
                is replaced by a fresh one.
            device_map: a ``DeviceMap``, an accelerator name, or ``None`` for
                CPU.
            cleanup: forwarded as ``auto_cleanup`` when the default
                ``LocalHandlerConfig`` is created here. An explicitly supplied
                ``LocalHandlerConfig`` takes precedence.
        """
        self.cleanup = cleanup
        if handler is None:
            handler = LocalHandler
            if not isinstance(handler_config, LocalHandlerConfig):
                # Honour the caller's `cleanup` choice instead of silently
                # falling back to the config default.
                handler_config = LocalHandlerConfig(auto_cleanup=cleanup)

        if device_map is None:
            device_map = DeviceMap(accelerator="cpu", node_device_map={})
        elif isinstance(device_map, str):
            device_map = DeviceMap(accelerator=device_map, node_device_map={})
        self.handler: BaseHandler = handler(bundle_path, device_map, handler_config)

    def compute(self, inputs: dict, dynamic_attributes: dict):
        """Delegate to the handler's synchronous ``compute``."""
        return self.handler.compute(inputs, dynamic_attributes)

    async def compute_async(self, inputs: dict, dynamic_attributes: dict):
        """Delegate to the handler's ``compute_async``."""
        return await self.handler.compute_async(inputs, dynamic_attributes)