bigdl-core-npu 2.6.0b20241110__cp310-cp310-win_amd64.whl → 2.6.0b20241113__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bigdl_core_npu-2.6.0b20241110.dist-info → bigdl_core_npu-2.6.0b20241113.dist-info}/METADATA +1 -1
- {bigdl_core_npu-2.6.0b20241110.dist-info → bigdl_core_npu-2.6.0b20241113.dist-info}/RECORD +95 -85
- intel_npu_acceleration_library/_version.py +1 -1
- intel_npu_acceleration_library/backend/bindings.py +10 -0
- intel_npu_acceleration_library/backend/factory.py +2 -26
- intel_npu_acceleration_library/backend/tensor.py +69 -0
- intel_npu_acceleration_library/device.py +2 -2
- intel_npu_acceleration_library/dtypes.py +34 -1
- intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py +1 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/jaxpr_decoder.py +283 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/utils.py +129 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/fx_decoder.py +8 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/gptq.py +1 -1
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/patch_model.py +28 -8
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/op_support.py +1 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/ts_decoder.py +3 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/helpers/packing.py +4 -4
- intel_npu_acceleration_library/external/openvino/preprocess/__init__.py +2 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/requirements.txt +1 -0
- intel_npu_acceleration_library/external/openvino/properties/__init__.py +1 -0
- intel_npu_acceleration_library/external/openvino/runtime/op/__init__.py +1 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset1/ops.py +2 -1
- intel_npu_acceleration_library/external/openvino/runtime/opset13/ops.py +5 -6
- intel_npu_acceleration_library/external/openvino/runtime/opset15/__init__.py +2 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset15/ops.py +62 -1
- intel_npu_acceleration_library/external/openvino/runtime/opset6/ops.py +60 -43
- intel_npu_acceleration_library/external/openvino/runtime/opset8/ops.py +4 -0
- intel_npu_acceleration_library/external/openvino/runtime/properties/__init__.py +1 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/decorators.py +67 -1
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/inputs_filling.py +9 -9
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert_impl.py +16 -2
- intel_npu_acceleration_library/external/openvino/tools/ovc/main.py +5 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/jax_frontend_utils.py +19 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pipeline.py +68 -16
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +70 -60
- intel_npu_acceleration_library/external/openvino/tools/ovc/utils.py +90 -3
- intel_npu_acceleration_library/external/openvino/utils.py +17 -0
- intel_npu_acceleration_library/lib/Release/intel_npu_acceleration_library.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_batch_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_c.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_hetero_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_cpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_gpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_npu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_ir_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_jax_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_onnx_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_paddle_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_pytorch_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_lite_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy_debug.dll +0 -0
- intel_npu_acceleration_library/nn/module.py +17 -17
- {bigdl_core_npu-2.6.0b20241110.dist-info → bigdl_core_npu-2.6.0b20241113.dist-info}/WHEEL +0 -0
- {bigdl_core_npu-2.6.0b20241110.dist-info → bigdl_core_npu-2.6.0b20241113.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,283 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# flake8: noqa
|
5
|
+
# mypy: ignore-errors
|
6
|
+
|
7
|
+
import jax.core
|
8
|
+
from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
|
9
|
+
from openvino.runtime import PartialShape, Type as OVType, OVAny
|
10
|
+
from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \
|
11
|
+
ivalue_to_constant
|
12
|
+
|
13
|
+
import jax
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from typing import List
|
17
|
+
import logging
|
18
|
+
logger = logging.getLogger(__name__)
|
19
|
+
logger.setLevel(logging.WARNING)
|
20
|
+
|
21
|
+
class JaxprPythonDecoder (Decoder):
    '''
    The jaxpr decoder uses Jaxpr to get graph information from a jax module.
    It makes use of the following parts.

    - `ClosedJaxpr`: the jaxpr object that contains the jaxpr and literals.
        - `Jaxpr`: the jaxpr object that contains the invars, outvars, and eqns.
            - `JaxEqns`: A list of jaxpr equations, which contains the information of the operation.
                - `Primitive`: the operation that is used in the equation.
                - `invars`: the input variables of the equation.
                    - `aval`: the abstract value.
                - `outvars`: the output variables of the equation.
                    - `aval`: the abstract value.
                - `params`: the named params of this equation.
            - `invars`: the inputs of the model (traced graph).
                - `aval`: the abstract value.
            - `outvars`: the outputs of the model (traced graph).
                - `aval`: the abstract value.
            - `constvars`: the constant variables used in this model.
                - `aval`: the abstract value.
        - `Literal`: the literal object that contains the value of the constants.
    '''

    def __init__(self, jaxpr, name=None, literals=None):
        '''
        Inputs:
            - jaxpr: for users, `ClosedJaxpr` is expected here. See https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L197
            - name: the name for the model.
            - literals: the literals (constants) that are used in the model.

        Raises:
            - ValueError: if `jaxpr` is not a `JaxprEqn`, `Jaxpr` or `ClosedJaxpr`.
        '''
        Decoder.__init__(self)

        # Always define the attribute so that a decoder built from a bare
        # `Jaxpr`/`JaxprEqn` without explicit literals does not fail later
        # with an AttributeError.
        self.literals = []

        if isinstance(jaxpr, (jax.core.JaxprEqn, jax.core.Jaxpr)):
            self.jaxpr = jaxpr
        elif isinstance(jaxpr, jax.core.ClosedJaxpr):
            # Take the `Jaxpr` from `ClosedJaxpr`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
            self.jaxpr = jaxpr.jaxpr
            # Literal should be a `Jax.core.Var`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
            self.literals = jaxpr.literals
        else:
            raise ValueError(f"Unexpected type of jaxpr: {type(jaxpr)}")

        self.name = name
        if self.name is None:
            self.name = "jax_module"

        # Explicitly passed literals take precedence over the ones of a `ClosedJaxpr`.
        if literals is not None:
            self.literals = literals

        # Named params are only converted when the wrapped object exposes a
        # `params` dict; each one becomes its own constant decoder.
        self.params = {}
        if hasattr(self.jaxpr, 'params') and isinstance(self.jaxpr.params, dict):
            for k in self.jaxpr.params:
                self.params[k] = self.convert_param_to_constant_node(self.jaxpr, k)

        # TODO: this implementation may lead to memory increasing. Any better solution?
        # Holds a reference to every decoder created while visiting the subgraph.
        self.m_decoders = []

    def inputs(self) -> List[int]:
        '''
        Return the ids of all inputs.

        For an equation, a `Literal` input is represented by the output id of the
        matching constant decoder in `self.literals` (consumed in order);
        every other input is represented by the `id()` of its variable.
        '''
        if not isinstance(self.jaxpr, jax.core.JaxprEqn):
            return [id(v) for v in self.jaxpr.invars]

        res = []
        literal_idx = 0
        for inp in self.jaxpr.invars:
            if isinstance(inp, jax.core.Literal):
                res.append(self.literals[literal_idx].output(0))
                literal_idx += 1
            else:
                res.append(id(inp))
        return res

    def input(self, idx: int) -> int:
        '''Return the `id()` of the input variable at position `idx`.'''
        return id(self.jaxpr.invars[idx])

    def get_input_shape(self, index):
        '''Return the shape of the abstract value of input `index` as a `PartialShape`.'''
        return PartialShape(self.jaxpr.invars[index].aval.shape)

    def get_input_signature_name(self, index) -> str:
        '''Return the generated signature name for input `index`.'''
        return "jaxpr_invar_" + str(index)

    def get_input_type(self, index) -> OVType:
        '''Return the OpenVINO type of input `index`, as computed by `get_ov_type_for_value`.'''
        return get_ov_type_for_value(self.jaxpr.invars[index])

    def get_named_param(self, name):
        '''
        Get the object id of the named parameter by the name.
        '''
        return self.params[name].output(0)

    def get_named_param_as_constant(self, name):
        '''
        The named parameter in JAX is a python object but we want to use its value in cpp.
        Therefore this API is used to get the named parameter as a constant, which can be used
        to extract the value of it in cpp-level.
        '''
        return self.params[name].as_constant()

    def get_param_names(self):
        '''
        In JAX, the named parameters may exist in `params` attribute of `JaxEqn`.
        For example, the `jax.lax.cat` operation has a named parameter `dim`,
        which is used to indicate the dimension to concatenate the tensors.

        Here we return the names of all the named params that appear in the model for the current `JaxEqn`.
        '''
        return list(self.params.keys())

    def get_output_type(self, index) -> OVType:
        '''Return the OpenVINO type of output `index`, as computed by `get_ov_type_for_value`.'''
        return get_ov_type_for_value(self.jaxpr.outvars[index])

    def get_output_name(self, index) -> str:
        '''Return the generated name for output `index`.'''
        return "jaxpr_outvar_" + str(index)

    def get_output_shape(self, index):
        '''Return the shape of the abstract value of output `index` as a `PartialShape`.'''
        return PartialShape(self.jaxpr.outvars[index].aval.shape)

    def visit_subgraph(self, node_visitor) -> None:
        '''
        Call `node_visitor` on a decoder for every node of the wrapped jaxpr.

        Nothing is visited when this decoder wraps a single equation. Otherwise the
        visiting order is: named-param decoders, one constant decoder per
        `constvar`, then for each equation its literal inputs followed by the
        equation itself.
        '''
        if isinstance(self.jaxpr, jax.core.JaxprEqn):
            return

        for decoder in self.params.values():
            self.m_decoders.append(decoder)
            node_visitor(decoder)

        constvars = self.jaxpr.constvars
        if len(self.literals) < len(constvars):
            # Fail with an explicit message instead of an opaque index error.
            raise ValueError(
                f"The jaxpr has {len(constvars)} constant variables but only "
                f"{len(self.literals)} literals were provided."
            )
        for idx, node in enumerate(constvars):
            # The constant decoder reuses the id of the constvar as its output id
            # so that equations consuming the constvar refer to the same id.
            decoder = self.convert_literal_to_constant_node(
                literal=self.literals[idx],
                name=self.name + "/" + f"const({id(node)})",
                output_id=id(node)
            )
            self.m_decoders.append(decoder)
            node_visitor(decoder)

        # Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285
        for node in self.jaxpr.eqns:
            literal_decoders = []
            for inp in node.invars:
                if isinstance(inp, jax.core.Literal):
                    literal_decoder = self.convert_literal_to_constant_node(inp)
                    literal_decoders.append(literal_decoder)
                    node_visitor(literal_decoder)
            # The literal decoders are handed to the equation decoder, which stores
            # them as its `literals` and is itself kept in `m_decoders`.
            decoder = JaxprPythonDecoder(node, name=self.name + "/" + node.primitive.name, literals=literal_decoders)
            self.m_decoders.append(decoder)
            node_visitor(decoder)

    def get_op_type(self) -> str:
        '''Return the primitive name for an equation, or "root" for a whole jaxpr.'''
        if isinstance(self.jaxpr, jax.core.JaxprEqn):
            return self.jaxpr.primitive.name
        return "root"

    def outputs(self) -> List[int]:
        '''Return the `id()` of every output variable.'''
        return [id(v) for v in self.jaxpr.outvars]

    def output(self, idx: int) -> int:
        '''Return the `id()` of the output variable at position `idx`.'''
        return id(self.jaxpr.outvars[idx])

    def num_inputs(self) -> int:
        '''Return the number of input variables.'''
        return len(self.jaxpr.invars)

    def num_outputs(self) -> int:
        '''Return the number of output variables.'''
        return len(self.jaxpr.outvars)

    def as_constant(self):
        '''
        Convert this node to the outputs of an OpenVINO constant.

        Raises:
            - ValueError: if the op type of this node is not 'constant'.
        '''
        if self.get_op_type() != 'constant':
            raise ValueError("This is not a constant node so it cannot be converted to a constant.")

        value = self.literals
        # TODO: dig out how to share the memory.
        # Currently, using shared_memory will raise `ValueError: array is not writeable``
        ov_const = jax_array_to_ov_const(value, shared_memory=False)
        return ov_const.outputs()

    @staticmethod
    def convert_param_to_constant_node(jaxpr, param):
        '''Wrap the named param `param` of `jaxpr` into a constant decoder.'''
        assert hasattr(jaxpr, 'params'), "The jaxpr does not have params."
        constant = ivalue_to_constant(jaxpr.params[param], shared_memory=False)
        return _JaxprPythonConstantDecoder(constant=constant)

    @staticmethod
    def convert_literal_to_constant_node(literal, name=None, output_id=None):
        '''
        Wrap a `jax.core.Literal`, a jax array or a numpy array into a constant decoder.

        Raises:
            - TypeError: if `literal` is none of the supported types.
        '''
        if isinstance(literal, jax.core.Literal):
            constant = ivalue_to_constant(literal.val, shared_memory=False)
        elif isinstance(literal, (jax.Array, np.ndarray)):
            constant = ivalue_to_constant(literal, shared_memory=False)
        else:
            raise TypeError( f"The input should be a literal or jax array, but got {type(literal)}.")
        return _JaxprPythonConstantDecoder(constant=constant, name=name, output_id=output_id)
|
204
|
+
|
205
|
+
class _JaxprPythonConstantDecoder (Decoder):
    def __init__(self, name=None, constant=None, output_id=None):
        '''
        A decoder specially for constants and named parameters.

        Inputs:
            - name: the name for this node.
            - constant: the already converted constant; the output accessors expect
              a sequence holding exactly one element with `element_type` and `shape`.
            - output_id: the id specified for this decoder's output. If none, use `id(self.constant)`.
        '''
        Decoder.__init__(self)

        self.name = name
        self.constant = constant
        if output_id is None:
            self.output_id = id(self.constant)
        else:
            self.output_id = output_id

    # ---- inputs: a constant node has none ----

    def inputs(self) -> List[int]:
        return []

    def input(self, idx: int) -> int:
        raise ValueError("This is a constant node so it does not have input.")

    def get_input_shape(self, index):
        raise ValueError("This is a constant node so it does not have input shape.")

    def get_input_signature_name(self, index) -> str:
        raise ValueError("This is a constant node so it does not have input signature name.")

    def get_input_type(self, index) -> OVType:
        raise ValueError("This is a constant node so it does not have input type.")

    def num_inputs(self) -> int:
        return 0

    # ---- named params: a constant node has none ----

    def get_named_param(self, name):
        raise ValueError("This is a constant node so it does not have named param.")

    def get_named_param_as_constant(self, name):
        raise ValueError("This is a constant node so it does not have named param.")

    def get_param_names(self):
        '''
        Named parameters live in the `params` attribute of a `JaxEqn`.
        A `_JaxprPythonConstantDecoder` already is a named param or a constant,
        so it never has named params of its own and the list is always empty.
        '''
        return []

    # ---- outputs: exactly one, described by the wrapped constant ----

    def get_output_type(self, index) -> OVType:
        '''Return the element type of the single wrapped constant, wrapped in `OVAny`.'''
        assert len(self.constant) == 1
        only_output = self.constant[0]
        return OVAny(only_output.element_type)

    def get_output_name(self, index) -> str:
        return "jaxpr_outvar_" + str(index)

    def get_output_shape(self, index):
        '''Return the shape of the single wrapped constant as a `PartialShape`.'''
        assert len(self.constant) == 1
        only_output = self.constant[0]
        return PartialShape(only_output.shape)

    def outputs(self) -> List[int]:
        return [self.output_id]

    def output(self, idx: int) -> int:
        # The same id is returned whatever `idx` is.
        return self.output_id

    def num_outputs(self) -> int:
        return 1

    # ---- node identity ----

    def visit_subgraph(self, node_visitor) -> None:
        # A constant has no subgraph, so the visitor is never called.
        return

    def get_op_type(self) -> str:
        return "constant"

    def as_constant(self):
        '''Return the wrapped constant unchanged.'''
        return self.constant
|
intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp310-win_amd64.pyd
ADDED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp311-win_amd64.pyd
ADDED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp312-win_amd64.pyd
ADDED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp38-win_amd64.pyd
ADDED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/jax/py_jax_frontend.cp39-win_amd64.pyd
ADDED
Binary file
|
@@ -0,0 +1,129 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# flake8: noqa
|
5
|
+
# mypy: ignore-errors
|
6
|
+
|
7
|
+
import jax
|
8
|
+
import numpy as np
|
9
|
+
import jax.numpy as jnp
|
10
|
+
|
11
|
+
from openvino.runtime import op, Type as OVType, Shape, OVAny
|
12
|
+
|
13
|
+
# Lookup tables translating numpy / JAX / plain Python types to OpenVINO types.

# numpy scalar types (plus `bool` and JAX's bfloat16) -> OpenVINO element type.
numpy_to_ov_type_map = {
    np.float32: OVType.f32,
    bool: OVType.boolean,
    jax.dtypes.bfloat16: OVType.bf16,  # TODO: check this
    np.float16: OVType.f16,
    np.float64: OVType.f64,
    np.uint8: OVType.u8,
    np.int8: OVType.i8,
    np.int16: OVType.i16,
    np.int32: OVType.i32,
    np.int64: OVType.i64,
}

# jax.numpy scalar types -> OpenVINO element type.
jax_to_ov_type_map = {
    jnp.float32: OVType.f32,
    jnp.bfloat16: OVType.bf16,  # TODO: check this
    jnp.float16: OVType.f16,
    jnp.float64: OVType.f64,
    jnp.uint8: OVType.u8,
    jnp.int8: OVType.i8,
    jnp.int16: OVType.i16,
    jnp.int32: OVType.i32,
    jnp.int64: OVType.i64,
}

# `jnp.bool` is registered on a best-effort basis: when the installed jax.numpy
# has no such attribute the entry is simply left out.
try:
    jax_to_ov_type_map[jnp.bool] = OVType.boolean
except AttributeError:
    pass

# Built-in Python scalar types -> OpenVINO element type.
basic_to_ov_type_map = {
    int: OVType.i64,
    float: OVType.f32,
    bool: OVType.boolean,
}

# OpenVINO element type -> integer code used when a dtype itself has to be
# passed on as an i64 constant (see `ivalue_to_constant`).
ov_type_to_int_map = {
    OVType.u8: 0,
    OVType.i8: 1,
    OVType.i16: 2,
    OVType.i32: 3,
    OVType.i64: 4,
    OVType.f16: 5,
    OVType.f32: 6,
    OVType.f64: 7,
    OVType.boolean: 11,
    OVType.bf16: 15,
}
|
62
|
+
|
63
|
+
def get_type_from_py_type(value):
    '''
    Map a plain Python scalar to an OpenVINO type.

    Returns f32 for a float, boolean for a bool, i64 for an int and
    `OVType.dynamic` for anything else.
    '''
    # `bool` has to be tested before `int`: every bool is also an int instance.
    if isinstance(value, float):
        deduced = OVType.f32
    elif isinstance(value, bool):
        deduced = OVType.boolean
    elif isinstance(value, int):
        deduced = OVType.i64
    else:
        deduced = OVType.dynamic
    return deduced
|
71
|
+
|
72
|
+
def get_ov_type_for_value(value):
    '''
    Return the OpenVINO type of `value`, wrapped in `OVAny`.

    - For a `jax.core.Var` or `jax.core.Literal` the dtype of its abstract value
      is looked up in the JAX, numpy and basic type maps, in that order.
      `None` is returned when none of the maps has a match.
    - For a plain Python `int`, `float` or `bool` the basic type map is used.

    Raises:
        - NotImplementedError: if `value` is of an unsupported type.
    '''
    if isinstance(value, (jax.core.Var, jax.core.Literal)):
        dtype = value.aval.dtype
        if dtype in jax_to_ov_type_map:
            return OVAny(jax_to_ov_type_map[dtype])
        for k, v in numpy_to_ov_type_map.items():
            if dtype == k:
                return OVAny(v)
        for k, v in basic_to_ov_type_map.items():
            if isinstance(dtype, k):
                return OVAny(v)
        # No table knows this dtype; the lack of a result is reported as None.
        return None

    if isinstance(value, (int, float, bool)):
        # Python scalars are keyed by their exact type in `basic_to_ov_type_map`;
        # `jax_to_ov_type_map` only holds jax.numpy scalar types and has no
        # entries for `int`, `float` or `bool`.
        py_type = type(value)
        if py_type in basic_to_ov_type_map:
            return OVAny(basic_to_ov_type_map[py_type])

    raise NotImplementedError(f"dtype for {value} of type {type(value)} has not been supported yet.")
|
86
|
+
|
87
|
+
def get_ov_type_from_jax_type(dtype):
    '''
    Translate a dtype-like object to an OpenVINO type wrapped in `OVAny`.

    The JAX map is consulted first (by key lookup), then the numpy map (by
    equality), then the basic map (by `isinstance`). Returns `None` when
    nothing matches.
    '''
    if dtype in jax_to_ov_type_map:
        return OVAny(jax_to_ov_type_map[dtype])

    # First numpy entry that compares equal to `dtype`, if any.
    found = next(
        (ov_type for np_type, ov_type in numpy_to_ov_type_map.items() if dtype == np_type),
        None,
    )
    if found is None:
        # Otherwise the first basic Python type that `dtype` is an instance of.
        found = next(
            (ov_type for py_type, ov_type in basic_to_ov_type_map.items() if isinstance(dtype, py_type)),
            None,
        )

    return None if found is None else OVAny(found)
|
97
|
+
|
98
|
+
def jax_array_to_ov_const(arr: np.ndarray, shared_memory=True):
    '''
    Build an OpenVINO `Constant` from a numpy array or a jax array.

    A jax array is first fetched with `jax.device_get` and converted with
    `np.array`; a numpy array is passed through as is.

    Raises:
        - ValueError: if `arr` is neither a numpy array nor a jax array.
    '''
    # TODO: deal with bfloat16 dtype here.
    if isinstance(arr, np.ndarray):
        host_array = arr
    elif isinstance(arr, jax.Array):
        host_array = np.array(jax.device_get(arr))
    else:
        raise ValueError(f"Constant is expected to be a numpy array or jax array but got {type(arr)}")
    return op.Constant(host_array, shared_memory=shared_memory)
|
106
|
+
|
107
|
+
def ivalue_to_constant(ivalue, shared_memory=True):
    '''
    Convert a python object to an openvino constant.

    Handled, in this order: a Python scalar, a non-empty list/tuple whose first
    element is a Python scalar, a jax or numpy array, and a dtype-like object
    (encoded as an i64 code). Returns the outputs of the created constant, or
    `None` (after printing a warning) when the value cannot be converted.
    '''
    # 1) Plain Python scalar -> scalar (rank-0) constant.
    scalar_type = get_type_from_py_type(ivalue)
    if scalar_type.is_static():
        scalar_const = op.Constant(scalar_type, Shape([]), [ivalue])
        return scalar_const.outputs()

    # 2) List / tuple -> 1D constant typed after its first element.
    if isinstance(ivalue, (list, tuple)):
        assert len(ivalue) > 0, "Can't deduce type for empty list"
        element_type = get_type_from_py_type(ivalue[0])
        assert element_type.is_static(), "Can't deduce type for list"
        list_const = op.Constant(element_type, Shape([len(ivalue)]), ivalue)
        return list_const.outputs()

    # 3) Array -> constant built from the array data.
    if isinstance(ivalue, (jax.Array, np.ndarray)):
        array_const = jax_array_to_ov_const(ivalue, shared_memory=shared_memory)
        return array_const.outputs()

    # 4) A dtype itself -> i64 constant holding its integer code.
    ov_dtype_value = get_ov_type_from_jax_type(ivalue)
    if ov_dtype_value is not None:
        dtype_code = ov_type_to_int_map[ov_dtype_value]
        return op.Constant(OVType.i64, Shape([]), [dtype_code]).outputs()

    print(f"[WARNING][JAX FE] Cannot get constant from value {ivalue}")
    return None
|
intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd
CHANGED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd
CHANGED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd
CHANGED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd
CHANGED
Binary file
|
intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd
CHANGED
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -30,6 +30,7 @@ class TorchFXPythonDecoder (Decoder):
|
|
30
30
|
self.input_shapes = input_shapes
|
31
31
|
|
32
32
|
self._input_signature = []
|
33
|
+
self._example_input = None
|
33
34
|
|
34
35
|
if issubclass(type(pt_module), torch.fx.graph_module.GraphModule):
|
35
36
|
|
@@ -316,6 +317,13 @@ class TorchFXPythonDecoder (Decoder):
|
|
316
317
|
def num_of_outputs(self):
|
317
318
|
return len(self.outputs())
|
318
319
|
|
320
|
+
def output_list_size(self):
    """Return one more than the largest index used by the `getitem` users of the node.

    Only users whose target prints as "<built-in function getitem>" are
    considered; their index is read from `user.args[1]`. The result is 0 when
    there is no such user.
    """
    getitem_indices = [
        user.args[1]
        for user in self.pt_module.users
        if "<built-in function getitem>" == str(user.target)
    ]
    # -1 seeds the maximum, so an empty list yields 0.
    return max([-1] + getitem_indices) + 1
|
326
|
+
|
319
327
|
def output(self, index):
|
320
328
|
return self.outputs()[index]
|
321
329
|
|
@@ -32,7 +32,7 @@ def patched_forward(self, *args, **kwargs):
|
|
32
32
|
x = args[0]
|
33
33
|
dtype = x.dtype
|
34
34
|
outshape = x.shape[:-1] + (self.width,)
|
35
|
-
x = x.view(-1, x.shape[-1])
|
35
|
+
x = x.contiguous().view(-1, x.shape[-1])
|
36
36
|
groups = self.qzeros.shape[0]
|
37
37
|
height = self.qweight.shape[0]
|
38
38
|
|
@@ -30,6 +30,7 @@ def patch_model(model, module_extensions, orig_forward_name):
|
|
30
30
|
|
31
31
|
if extension:
|
32
32
|
# The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
|
33
|
+
|
33
34
|
class Trampoline(torch.autograd.Function):
|
34
35
|
target_extension = extension
|
35
36
|
original_module = m
|
@@ -83,16 +84,35 @@ def unpatch_model(model, orig_forward_name):
|
|
83
84
|
|
84
85
|
|
85
86
|
def __make_16bit_traceable(model: torch.nn.Module):
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
87
|
+
"""
|
88
|
+
Prepare a 16-bit PyTorch model for tracing with OpenVINO.
|
89
|
+
- Replace known list of modules with ModuleExtension.
|
90
|
+
- Convert other modules with weights to FP32.
|
91
|
+
"""
|
92
|
+
extensions = {
|
93
|
+
torch.nn.Linear: ModuleExtension(
|
94
|
+
torch.nn.Linear, "ov_ext::linear",
|
95
|
+
evaluate=lambda module, *args, **kwargs: torch.full(
|
96
|
+
list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
|
97
|
+
convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias)),
|
98
|
+
torch.nn.Embedding: ModuleExtension(
|
99
|
+
torch.nn.Embedding, "ov_ext::embedding",
|
100
|
+
evaluate=lambda module, *args, **kwargs: torch.full(
|
101
|
+
list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
|
102
|
+
convert=lambda module, target_op, *args, **kwargs: target_op(module.weight, args[0], module.padding_idx, module.scale_grad_by_freq, module.sparse)),
|
93
103
|
}
|
104
|
+
try:
|
105
|
+
from transformers.pytorch_utils import Conv1D
|
106
|
+
extensions[Conv1D] = ModuleExtension(
|
107
|
+
Conv1D, "ov_ext::conv1d",
|
108
|
+
evaluate=lambda module, *args, **kwargs: torch.full(
|
109
|
+
list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
|
110
|
+
convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias))
|
111
|
+
except:
|
112
|
+
pass
|
94
113
|
patch_model(model, extensions,
|
95
114
|
"_openvino_module_extension_patch_orig_forward")
|
96
115
|
for _, module in model.named_modules():
|
97
|
-
if module.__class__ not in extensions and
|
116
|
+
if module.__class__ not in extensions and (any([p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False)])
|
117
|
+
or any([b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False)])):
|
98
118
|
module.float()
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -241,6 +241,7 @@ class OperatorSupport(OperatorSupport):
|
|
241
241
|
"torch.ops.aten.transpose.int": None,
|
242
242
|
"torch.ops.aten.tril.default": None,
|
243
243
|
"torch.ops.aten.tril_.default": None,
|
244
|
+
"torch.ops.aten.triu.default": None,
|
244
245
|
"torch.ops.aten.unbind.int": None,
|
245
246
|
"torch.ops.aten.unfold.default": None,
|
246
247
|
"torch.ops.aten.unsqueeze.default": None,
|
@@ -96,6 +96,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|
96
96
|
if isinstance(pt_module, torch.nn.Module):
|
97
97
|
pt_module.eval()
|
98
98
|
input_signature = None
|
99
|
+
input_parameters = None
|
99
100
|
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
|
100
101
|
# input params is dictionary contains input names and their signature values (type hints and default values if any)
|
101
102
|
input_params = inspect.signature(pt_module.forward if hasattr(
|
@@ -150,8 +151,10 @@ class TorchScriptPythonDecoder (Decoder):
|
|
150
151
|
scripted, preserved_attrs=preserved_attrs)
|
151
152
|
else:
|
152
153
|
f_model = scripted
|
154
|
+
self._example_input = input_parameters["example_inputs"] if input_parameters else None
|
153
155
|
else:
|
154
156
|
f_model = pt_module
|
157
|
+
self._example_input = example_inputs
|
155
158
|
|
156
159
|
self._input_signature = input_signature
|
157
160
|
return f_model
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -20,10 +20,10 @@ def pack_data(array: np.ndarray, type: Type) -> np.ndarray:
|
|
20
20
|
|
21
21
|
:param array: numpy array with values to pack.
|
22
22
|
:type array: numpy array
|
23
|
-
:param type: Type to interpret the array values. Type must be u1, u4, i4 or
|
23
|
+
:param type: Type to interpret the array values. Type must be u1, u4, i4, nf4 or f4e2m1.
|
24
24
|
:type type: openvino.runtime.Type
|
25
25
|
"""
|
26
|
-
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4], "Packing algorithm for the" "data types stored in 1, 2 or 4 bits"
|
26
|
+
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4, Type.f4e2m1], "Packing algorithm for the" "data types stored in 1, 2 or 4 bits"
|
27
27
|
|
28
28
|
minimum_regular_dtype = np.int8 if type == Type.i4 else np.uint8
|
29
29
|
casted_to_regular_type = array.astype(dtype=minimum_regular_dtype, casting="unsafe")
|
@@ -57,12 +57,12 @@ def unpack_data(array: np.ndarray, type: Type, shape: Union[list, Shape]) -> np.
|
|
57
57
|
|
58
58
|
:param array: numpy array to unpack.
|
59
59
|
:type array: numpy array
|
60
|
-
:param type: Type to extract from array values. Type must be u1, u4, i4 or
|
60
|
+
:param type: Type to extract from array values. Type must be u1, u4, i4, nf4 or f4e2m1.
|
61
61
|
:type type: openvino.runtime.Type
|
62
62
|
:param shape: the new shape for the unpacked array.
|
63
63
|
:type shape: Union[list, openvino.runtime.Shape]
|
64
64
|
"""
|
65
|
-
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4], "Unpacking algorithm for the" "data types stored in 1, 2 or 4 bits"
|
65
|
+
assert type in [Type.u1, Type.u4, Type.i4, Type.nf4, Type.f4e2m1], "Unpacking algorithm for the" "data types stored in 1, 2 or 4 bits"
|
66
66
|
unpacked = np.unpackbits(array.view(np.uint8))
|
67
67
|
shape = list(shape)
|
68
68
|
if type.bitwidth == 1:
|
@@ -24,3 +24,5 @@ from openvino._pyopenvino.preprocess import PreProcessSteps
|
|
24
24
|
from openvino._pyopenvino.preprocess import PostProcessSteps
|
25
25
|
from openvino._pyopenvino.preprocess import ColorFormat
|
26
26
|
from openvino._pyopenvino.preprocess import ResizeAlgorithm
|
27
|
+
from openvino._pyopenvino.preprocess import PaddingMode
|
28
|
+
|
@@ -15,4 +15,5 @@ from openvino._pyopenvino.op import Parameter
|
|
15
15
|
from openvino._pyopenvino.op import if_op
|
16
16
|
from openvino._pyopenvino.op import loop
|
17
17
|
from openvino._pyopenvino.op import tensor_iterator
|
18
|
+
from openvino._pyopenvino.op import read_value
|
18
19
|
from openvino._pyopenvino.op import Result
|
@@ -31,7 +31,7 @@ from openvino.runtime.utils.types import (
|
|
31
31
|
get_element_type_str,
|
32
32
|
make_constant_node,
|
33
33
|
)
|
34
|
-
|
34
|
+
from openvino.utils import deprecated
|
35
35
|
|
36
36
|
_get_node_factory_opset1 = partial(_get_node_factory, "opset1")
|
37
37
|
|
@@ -1532,6 +1532,7 @@ def lstm_cell(
|
|
1532
1532
|
return _get_node_factory_opset1().create("LSTMCell", node_inputs, attributes)
|
1533
1533
|
|
1534
1534
|
|
1535
|
+
@deprecated(version="2025.0", message="Use lstm_sequence from opset 5")
|
1535
1536
|
@nameable_op
|
1536
1537
|
def lstm_sequence(
|
1537
1538
|
X: NodeInput,
|