fxn 0.0.52.tar.gz → 0.0.53.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {fxn-0.0.52 → fxn-0.0.53}/PKG-INFO +1 -1
  2. fxn-0.0.53/fxn/beta/__init__.py +13 -0
  3. fxn-0.0.53/fxn/beta/cli/__init__.py +6 -0
  4. fxn-0.0.53/fxn/beta/cli/llm.py +22 -0
  5. {fxn-0.0.52 → fxn-0.0.53}/fxn/beta/client.py +1 -2
  6. fxn-0.0.53/fxn/beta/llm/__init__.py +5 -0
  7. fxn-0.0.53/fxn/beta/llm/server.py +5 -0
  8. fxn-0.0.53/fxn/beta/metadata.py +171 -0
  9. fxn-0.0.53/fxn/beta/services/__init__.py +7 -0
  10. {fxn-0.0.52/fxn/beta → fxn-0.0.53/fxn/beta/services}/prediction.py +1 -1
  11. {fxn-0.0.52/fxn/beta → fxn-0.0.53/fxn/beta/services}/remote.py +4 -4
  12. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/__init__.py +3 -1
  13. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/compile.py +2 -0
  14. {fxn-0.0.52 → fxn-0.0.53}/fxn/compile.py +3 -3
  15. {fxn-0.0.52 → fxn-0.0.53}/fxn/services/prediction.py +0 -12
  16. {fxn-0.0.52 → fxn-0.0.53}/fxn/version.py +1 -1
  17. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/PKG-INFO +1 -1
  18. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/SOURCES.txt +7 -2
  19. fxn-0.0.52/fxn/beta/__init__.py +0 -11
  20. fxn-0.0.52/fxn/beta/metadata.py +0 -89
  21. {fxn-0.0.52 → fxn-0.0.53}/LICENSE +0 -0
  22. {fxn-0.0.52 → fxn-0.0.53}/README.md +0 -0
  23. {fxn-0.0.52 → fxn-0.0.53}/fxn/__init__.py +0 -0
  24. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/__init__.py +0 -0
  25. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/configuration.py +0 -0
  26. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/fxnc.py +0 -0
  27. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/map.py +0 -0
  28. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/prediction.py +0 -0
  29. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/predictor.py +0 -0
  30. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/stream.py +0 -0
  31. {fxn-0.0.52 → fxn-0.0.53}/fxn/c/value.py +0 -0
  32. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/auth.py +0 -0
  33. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/misc.py +0 -0
  34. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/predictions.py +0 -0
  35. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/predictors.py +0 -0
  36. {fxn-0.0.52 → fxn-0.0.53}/fxn/cli/sources.py +0 -0
  37. {fxn-0.0.52 → fxn-0.0.53}/fxn/client.py +0 -0
  38. {fxn-0.0.52 → fxn-0.0.53}/fxn/function.py +0 -0
  39. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/__init__.py +0 -0
  40. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/linux/arm64/libFunction.so +0 -0
  41. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/linux/x86_64/libFunction.so +0 -0
  42. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/macos/arm64/Function.dylib +0 -0
  43. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/macos/x86_64/Function.dylib +0 -0
  44. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/windows/arm64/Function.dll +0 -0
  45. {fxn-0.0.52 → fxn-0.0.53}/fxn/lib/windows/x86_64/Function.dll +0 -0
  46. {fxn-0.0.52 → fxn-0.0.53}/fxn/logging.py +0 -0
  47. {fxn-0.0.52 → fxn-0.0.53}/fxn/sandbox.py +0 -0
  48. {fxn-0.0.52 → fxn-0.0.53}/fxn/services/__init__.py +0 -0
  49. {fxn-0.0.52 → fxn-0.0.53}/fxn/services/predictor.py +0 -0
  50. {fxn-0.0.52 → fxn-0.0.53}/fxn/services/user.py +0 -0
  51. {fxn-0.0.52 → fxn-0.0.53}/fxn/types/__init__.py +0 -0
  52. {fxn-0.0.52 → fxn-0.0.53}/fxn/types/dtype.py +0 -0
  53. {fxn-0.0.52 → fxn-0.0.53}/fxn/types/prediction.py +0 -0
  54. {fxn-0.0.52 → fxn-0.0.53}/fxn/types/predictor.py +0 -0
  55. {fxn-0.0.52 → fxn-0.0.53}/fxn/types/user.py +0 -0
  56. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/dependency_links.txt +0 -0
  57. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/entry_points.txt +0 -0
  58. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/requires.txt +0 -0
  59. {fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/top_level.txt +0 -0
  60. {fxn-0.0.52 → fxn-0.0.53}/pyproject.toml +0 -0
  61. {fxn-0.0.52 → fxn-0.0.53}/setup.cfg +0 -0
{fxn-0.0.52 → fxn-0.0.53}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fxn
-Version: 0.0.52
+Version: 0.0.53
 Summary: Run prediction functions locally in Python. Register at https://fxn.ai.
 Author-email: "NatML Inc." <hi@fxn.ai>
 License: Apache License
fxn-0.0.53/fxn/beta/__init__.py
@@ -0,0 +1,13 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
+from .metadata import (
+    CoreMLInferenceMetadata, LiteRTInferenceMetadata, LlamaCppInferenceMetadata,
+    OnnxInferenceMetadata, OnnxRuntimeInferenceSessionMetadata, OpenVINOInferenceMetadata,
+    QnnInferenceMetadata, QnnInferenceBackend, QnnInferenceQuantization,
+    # Deprecated
+    ONNXInferenceMetadata, ONNXRuntimeInferenceSessionMetadata
+)
+from .services import RemoteAcceleration
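The 0.0.52 class names remain importable from `fxn.beta` as plain aliases of the renamed classes (see the `# Deprecated` entries above). A minimal sketch of what that means for existing code, assuming only that the package is installed:

    from fxn.beta import OnnxInferenceMetadata, ONNXInferenceMetadata

    # The deprecated name is a simple alias, so existing imports, isinstance
    # checks, and annotations written against it keep working.
    assert ONNXInferenceMetadata is OnnxInferenceMetadata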
fxn-0.0.53/fxn/beta/cli/__init__.py
@@ -0,0 +1,6 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
+from .llm import app as llm_app
fxn-0.0.53/fxn/beta/cli/llm.py
@@ -0,0 +1,22 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
+from pathlib import Path
+from typer import Argument, Option, Typer
+from typing_extensions import Annotated
+
+app = Typer(no_args_is_help=True)
+
+@app.command(name="chat", help="Start a chat session.")
+def chat (
+    model: Annotated[str, Argument(help="Model to chat with.")]
+):
+    pass
+
+@app.command(name="serve", help="Start an LLM server.")
+def serve (
+    port: Annotated[int, Option(help="Port to start the server on.")] = 11435
+):
+    pass
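Both `chat` and `serve` are stubs (`pass`) in this release, and the `llm` command group is registered as hidden on the main CLI (see `fxn/cli/__init__.py` below). A hedged sketch of exercising the wiring with Typer's test runner, assuming only that `typer` is installed as it already is for the CLI:

    from typer.testing import CliRunner
    from fxn.beta.cli import llm_app

    runner = CliRunner()
    # Parses the --port option and returns immediately; no server is started yet.
    result = runner.invoke(llm_app, ["serve", "--port", "11435"])
    assert result.exit_code == 0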
{fxn-0.0.52 → fxn-0.0.53}/fxn/beta/client.py
@@ -10,8 +10,7 @@ from typing import get_origin, Callable, Generator, Iterator, TypeVar
 from ..client import FunctionClient
 from ..services import PredictionService as EdgePredictionService
 from ..types import Acceleration
-from .prediction import PredictionService
-from .remote import RemoteAcceleration
+from .services import PredictionService, RemoteAcceleration
 
 F = TypeVar("F", bound=Callable[..., object])
 
fxn-0.0.53/fxn/beta/llm/__init__.py
@@ -0,0 +1,5 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
fxn-0.0.53/fxn/beta/llm/server.py
@@ -0,0 +1,5 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
fxn-0.0.53/fxn/beta/metadata.py
@@ -0,0 +1,171 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
+from os import PathLike
+from pathlib import Path
+from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
+from typing import Annotated, Literal
+
+def _validate_torch_module (module: "torch.nn.Module") -> "torch.nn.Module": # type: ignore
+    try:
+        from torch.nn import Module # type: ignore
+        if not isinstance(module, Module):
+            raise ValueError(f"Expected torch.nn.Module, got {type(module)}")
+        return module
+    except ImportError:
+        raise ImportError("PyTorch is required to create this metadata but is not installed.")
+
+def _validate_ort_inference_session (session: "onnxruntime.InferenceSession") -> "onnxruntime.InferenceSession": # type: ignore
+    try:
+        from onnxruntime import InferenceSession # type: ignore
+        if not isinstance(session, InferenceSession):
+            raise ValueError(f"Expected onnxruntime.InferenceSession, got {type(session)}")
+        return session
+    except ImportError:
+        raise ImportError("ONNXRuntime is required to create this metadata but is not installed.")
+
+class CoreMLInferenceMetadata (BaseModel):
+    """
+    Metadata required to lower a PyTorch model for inference on iOS, macOS, and visionOS with CoreML.
+
+    Members:
+        model (torch.nn.Module): PyTorch module to apply metadata to.
+        model_args (tuple[Tensor,...]): Positional inputs to the model.
+    """
+    kind: Literal["meta.inference.coreml"] = "meta.inference.coreml"
+    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
+        description="PyTorch module to apply metadata to.",
+        exclude=True
+    )
+    model_args: list[object] = Field(
+        description="Positional inputs to the model.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+class OnnxInferenceMetadata (BaseModel):
+    """
+    Metadata required to lower a PyTorch model for inference.
+
+    Members:
+        model (torch.nn.Module): PyTorch module to apply metadata to.
+        model_args (tuple[Tensor,...]): Positional inputs to the model.
+    """
+    kind: Literal["meta.inference.onnx"] = "meta.inference.onnx"
+    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
+        description="PyTorch module to apply metadata to.",
+        exclude=True
+    )
+    model_args: list[object] = Field(
+        description="Positional inputs to the model.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+class OnnxRuntimeInferenceSessionMetadata (BaseModel):
+    """
+    Metadata required to lower an ONNXRuntime `InferenceSession` for inference.
+
+    Members:
+        session (onnxruntime.InferenceSession): ONNXRuntime inference session to apply metadata to.
+        model_path (str | Path): ONNX model path. The model must exist at this path in the compiler sandbox.
+    """
+    kind: Literal["meta.inference.onnxruntime"] = "meta.inference.onnxruntime"
+    session: Annotated[object, BeforeValidator(_validate_ort_inference_session)] = Field(
+        description="ONNXRuntime inference session to apply metadata to.",
+        exclude=True
+    )
+    model_path: str | Path = Field(
+        description="ONNX model path. The model must exist at this path in the compiler sandbox.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+class LiteRTInferenceMetadata (BaseModel):
+    """
+    Metadata required to lower PyTorch model for inference with LiteRT (fka TensorFlow Lite).
+
+    Members:
+        model (torch.nn.Module): PyTorch module to apply metadata to.
+        model_args (tuple[Tensor,...]): Positional inputs to the model.
+    """
+    kind: Literal["meta.inference.litert"] = "meta.inference.litert"
+    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
+        description="PyTorch module to apply metadata to.",
+        exclude=True
+    )
+    model_args: list[object] = Field(
+        description="Positional inputs to the model.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+class OpenVINOInferenceMetadata (BaseModel):
+    """
+    Metadata required to lower PyTorch model for interence with Intel OpenVINO.
+
+    Members:
+        model (torch.nn.Module): PyTorch module to apply metadata to.
+        model_args (tuple[Tensor,...]): Positional inputs to the model.
+    """
+    kind: Literal["meta.inference.openvino"] = "meta.inference.openvino"
+    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
+        description="PyTorch module to apply metadata to.",
+        exclude=True
+    )
+    model_args: list[object] = Field(
+        description="Positional inputs to the model.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+QnnInferenceBackend = Literal["cpu", "gpu"] # `htp` coming soon
+QnnInferenceQuantization = Literal["w8a8", "w8a16", "w4a8", "w4a16"]
+
+class QnnInferenceMetadata (BaseModel):
+    """
+    Metadata required to lower a PyTorch model for inference on Qualcomm accelerators with QNN SDK.
+
+    Members:
+        model (torch.nn.Module): PyTorch module to apply metadata to.
+        model_args (tuple[Tensor,...]): Positional inputs to the model.
+        backend (QnnInferenceBackend): QNN inference backend. Defaults to `cpu`.
+        quantization (QnnInferenceQuantization): QNN model quantization mode. This MUST only be specified when backend is `htp`.
+    """
+    kind: Literal["meta.inference.qnn"] = "meta.inference.qnn"
+    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(
+        description="PyTorch module to apply metadata to.",
+        exclude=True
+    )
+    model_args: list[object] = Field(
+        description="Positional inputs to the model.",
+        exclude=True
+    )
+    backend: QnnInferenceBackend = Field(
+        default="cpu",
+        description="QNN backend to execute the model.",
+        exclude=True
+    )
+    quantization: QnnInferenceQuantization | None = Field(
+        default=None,
+        description="QNN model quantization mode. This MUST only be specified when backend is `htp`.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+class LlamaCppInferenceMetadata (BaseModel): # INCOMPLETE
+    """
+    Metadata required to lower a GGUF model for LLM inference.
+    """
+    kind: Literal["meta.inference.gguf"] = "meta.inference.gguf"
+    model_path: Path = Field(
+        description="GGUF model path. The model must exist at this path in the compiler sandbox.",
+        exclude=True
+    )
+    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
+
+# DEPRECATED
+ONNXInferenceMetadata = OnnxInferenceMetadata
+ONNXRuntimeInferenceSessionMetadata = OnnxRuntimeInferenceSessionMetadata
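These metadata classes are frozen pydantic models, so they can be constructed directly once the relevant framework is installed. A minimal sketch with a throwaway module (illustrative only; assumes PyTorch is available):

    import torch
    from fxn.beta import OnnxInferenceMetadata, QnnInferenceMetadata

    # Any torch.nn.Module passes the BeforeValidator above; model_args is a
    # list of example positional inputs for the model.
    module = torch.nn.Linear(4, 2)
    inputs = [torch.randn(1, 4)]

    onnx_metadata = OnnxInferenceMetadata(model=module, model_args=inputs)
    qnn_metadata = QnnInferenceMetadata(model=module, model_args=inputs, backend="gpu")
    assert onnx_metadata.kind == "meta.inference.onnx"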
fxn-0.0.53/fxn/beta/services/__init__.py
@@ -0,0 +1,7 @@
+#
+# Function
+# Copyright © 2025 NatML Inc. All Rights Reserved.
+#
+
+from .prediction import PredictionService
+from .remote import RemoteAcceleration
{fxn-0.0.52/fxn/beta → fxn-0.0.53/fxn/beta/services}/prediction.py
@@ -3,7 +3,7 @@
 # Copyright © 2025 NatML Inc. All Rights Reserved.
 #
 
-from ..client import FunctionClient
+from ...client import FunctionClient
 from .remote import RemotePredictionService
 
 class PredictionService:
{fxn-0.0.52/fxn/beta → fxn-0.0.53/fxn/beta/services}/remote.py
@@ -15,10 +15,10 @@ from requests import get, put
 from typing import Literal
 from urllib.request import urlopen
 
-from ..c import Configuration
-from ..client import FunctionClient
-from ..services import Value
-from ..types import Dtype, Prediction
+from ...c import Configuration
+from ...client import FunctionClient
+from ...services.prediction import Value
+from ...types import Dtype, Prediction
 
 RemoteAcceleration = Literal["auto", "cpu", "a40", "a100"]
 
{fxn-0.0.52 → fxn-0.0.53}/fxn/cli/__init__.py
@@ -14,6 +14,7 @@ from .misc import cli_options
 from .predictions import create_prediction
 from .predictors import archive_predictor, delete_predictor, retrieve_predictor
 from .sources import retrieve_source
+from ..beta.cli import llm_app
 
 # Define CLI
 typer.main.console_stderr = TracebackMarkupConsole()
@@ -30,6 +31,7 @@ app.callback()(cli_options)
 
 # Add subcommands
 app.add_typer(auth_app, name="auth", help="Login, logout, and check your authentication status.")
+app.add_typer(llm_app, name="llm", hidden=True, help="Work with large language models (LLMs).")
 
 # Add top-level commands
 app.command(
@@ -44,7 +46,7 @@ app.command(
 app.command(name="retrieve", help="Retrieve a predictor.")(retrieve_predictor)
 app.command(name="archive", help="Archive a predictor.")(archive_predictor)
 app.command(name="delete", help="Delete a predictor.")(delete_predictor)
-app.command(name="source", help="Retrieve the native source code for a given prediction.")(retrieve_source)
+app.command(name="source", help="Retrieve the generated native code for a given predictor.")(retrieve_source)
 
 # Run
 if __name__ == "__main__":
{fxn-0.0.52 → fxn-0.0.53}/fxn/cli/compile.py
@@ -86,6 +86,8 @@ def _load_predictor_func (path: str) -> Callable[...,object]:
     if "" not in sys.path:
         sys.path.insert(0, "")
     path: Path = Path(path).resolve()
+    if not path.exists():
+        raise ValueError(f"Cannot compile predictor because no Python module exists at the given path.")
     sys.path.insert(0, str(path.parent))
     name = getmodulename(path)
     spec = spec_from_file_location(name, path)
{fxn-0.0.52 → fxn-0.0.53}/fxn/compile.py
@@ -13,7 +13,7 @@ from typing import Any, Callable, Literal, ParamSpec, TypeVar, cast
 
 from .beta import (
     CoreMLInferenceMetadata, LiteRTInferenceMetadata, LlamaCppInferenceMetadata,
-    ONNXInferenceMetadata, ONNXRuntimeInferenceSessionMetadata, OpenVINOInferenceMetadata,
+    OnnxInferenceMetadata, OnnxRuntimeInferenceSessionMetadata, OpenVINOInferenceMetadata,
     QnnInferenceMetadata
 )
 from .sandbox import Sandbox
@@ -33,8 +33,8 @@ CompileMetadata = (
     CoreMLInferenceMetadata |
     LiteRTInferenceMetadata |
     LlamaCppInferenceMetadata |
-    ONNXInferenceMetadata |
-    ONNXRuntimeInferenceSessionMetadata |
+    OnnxInferenceMetadata |
+    OnnxRuntimeInferenceSessionMetadata |
     OpenVINOInferenceMetadata |
     QnnInferenceMetadata
 )
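`CompileMetadata` is a module-level union in `fxn/compile.py`, so annotations written against it accept every inference metadata variant, including the deprecated aliases (which resolve to the renamed classes). A small hypothetical helper, not part of the package:

    from fxn.compile import CompileMetadata

    def metadata_kind (metadata: CompileMetadata) -> str:
        # Every metadata model carries a `kind` discriminator, e.g. "meta.inference.onnx".
        return metadata.kind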
{fxn-0.0.52 → fxn-0.0.53}/fxn/services/prediction.py
@@ -43,18 +43,6 @@ class PredictionService:
         self.__cache_dir = self.__class__.__get_home_dir() / ".fxn" / "cache"
         self.__cache_dir.mkdir(parents=True, exist_ok=True)
 
-    def ready (self, tag: str, **kwargs) -> bool:
-        """
-        Check whether a predictor has been preloaded and is ready to make predictions.
-
-        Parameters:
-            tag (str): Predictor tag.
-
-        Returns:
-            bool: Whether the predictor is ready to make predictions.
-        """
-        return tag in self.__cache
-
     def create (
         self,
         tag: str,
{fxn-0.0.52 → fxn-0.0.53}/fxn/version.py
@@ -3,4 +3,4 @@
 # Copyright © 2025 NatML Inc. All Rights Reserved.
 #
 
-__version__ = "0.0.52"
+__version__ = "0.0.53"
{fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fxn
-Version: 0.0.52
+Version: 0.0.53
 Summary: Run prediction functions locally in Python. Register at https://fxn.ai.
 Author-email: "NatML Inc." <hi@fxn.ai>
 License: Apache License
{fxn-0.0.52 → fxn-0.0.53}/fxn.egg-info/SOURCES.txt
@@ -17,8 +17,13 @@ fxn.egg-info/top_level.txt
 fxn/beta/__init__.py
 fxn/beta/client.py
 fxn/beta/metadata.py
-fxn/beta/prediction.py
-fxn/beta/remote.py
+fxn/beta/cli/__init__.py
+fxn/beta/cli/llm.py
+fxn/beta/llm/__init__.py
+fxn/beta/llm/server.py
+fxn/beta/services/__init__.py
+fxn/beta/services/prediction.py
+fxn/beta/services/remote.py
 fxn/c/__init__.py
 fxn/c/configuration.py
 fxn/c/fxnc.py
fxn-0.0.52/fxn/beta/__init__.py (deleted)
@@ -1,11 +0,0 @@
-#
-# Function
-# Copyright © 2025 NatML Inc. All Rights Reserved.
-#
-
-from .metadata import (
-    CoreMLInferenceMetadata, LiteRTInferenceMetadata, LlamaCppInferenceMetadata,
-    ONNXInferenceMetadata, ONNXRuntimeInferenceSessionMetadata, OpenVINOInferenceMetadata,
-    QnnInferenceMetadata
-)
-from .remote import RemoteAcceleration
fxn-0.0.52/fxn/beta/metadata.py (deleted)
@@ -1,89 +0,0 @@
-#
-# Function
-# Copyright © 2025 NatML Inc. All Rights Reserved.
-#
-
-from pathlib import Path
-from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
-from typing import Annotated, Literal
-
-def _validate_torch_module (module: "torch.nn.Module") -> "torch.nn.Module": # type: ignore
-    try:
-        from torch.nn import Module # type: ignore
-        if not isinstance(module, Module):
-            raise ValueError(f"Expected torch.nn.Module, got {type(module)}")
-        return module
-    except ImportError:
-        raise ImportError("PyTorch is required to create this metadata but is not installed.")
-
-def _validate_ort_inference_session (session: "onnxruntime.InferenceSession") -> "onnxruntime.InferenceSession": # type: ignore
-    try:
-        from onnxruntime import InferenceSession # type: ignore
-        if not isinstance(session, InferenceSession):
-            raise ValueError(f"Expected onnxruntime.InferenceSession, got {type(session)}")
-        return session
-    except ImportError:
-        raise ImportError("ONNXRuntime is required to create this metadata but is not installed.")
-
-class CoreMLInferenceMetadata (BaseModel):
-    """
-    Metadata required to lower a PyTorch model for inference on iOS, macOS, and visionOS with CoreML.
-    """
-    kind: Literal["meta.inference.coreml"] = "meta.inference.coreml"
-    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(description="PyTorch module to apply metadata to.", exclude=True)
-    model_args: list[object] = Field(description="Positional inputs to the model.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class ONNXInferenceMetadata (BaseModel):
-    """
-    Metadata required to lower a PyTorch model for inference.
-    """
-    kind: Literal["meta.inference.onnx"] = "meta.inference.onnx"
-    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(description="PyTorch module to apply metadata to.", exclude=True)
-    model_args: list[object] = Field(description="Positional inputs to the model.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class ONNXRuntimeInferenceSessionMetadata (BaseModel):
-    """
-    Metadata required to lower an ONNXRuntime `InferenceSession` for inference.
-    """
-    kind: Literal["meta.inference.onnxruntime"] = "meta.inference.onnxruntime"
-    session: Annotated[object, BeforeValidator(_validate_ort_inference_session)] = Field(description="ONNXRuntime inference session to apply metadata to.", exclude=True)
-    model_path: Path = Field(description="ONNX model path. The model must exist at this path in the compiler sandbox.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class LiteRTInferenceMetadata (BaseModel):
-    """
-    Metadata required to lower PyTorch model for inference with LiteRT (fka TensorFlow Lite).
-    """
-    kind: Literal["meta.inference.litert"] = "meta.inference.litert"
-    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(description="PyTorch module to apply metadata to.", exclude=True)
-    model_args: list[object] = Field(description="Positional inputs to the model.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class OpenVINOInferenceMetadata (BaseModel):
-    """
-    Metadata required to lower PyTorch model for interence with Intel OpenVINO.
-    """
-    kind: Literal["meta.inference.openvino"] = "meta.inference.openvino"
-    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(description="PyTorch module to apply metadata to.", exclude=True)
-    model_args: list[object] = Field(description="Positional inputs to the model.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class QnnInferenceMetadata (BaseModel):
-    """
-    Metadata required to lower a PyTorch model for inference on Qualcomm accelerators with QNN SDK.
-    """
-    kind: Literal["meta.inference.qnn"] = "meta.inference.qnn"
-    model: Annotated[object, BeforeValidator(_validate_torch_module)] = Field(description="PyTorch module to apply metadata to.", exclude=True)
-    model_args: list[object] = Field(description="Positional inputs to the model.", exclude=True)
-    backend: Literal["cpu", "gpu"] = Field(default="cpu", description="QNN backend to execute the model.", exclude=True) # CHECK # Add `htp`
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
-
-class LlamaCppInferenceMetadata (BaseModel): # INCOMPLETE
-    """
-    Metadata required to lower a GGUF model for LLM inference.
-    """
-    kind: Literal["meta.inference.gguf"] = "meta.inference.gguf"
-    model_path: Path = Field(description="GGUF model path. The model must exist at this path in the compiler sandbox.", exclude=True)
-    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)