bigdl-core-npu 2.6.0b20250114__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bigdl-core-npu/__init__.py +0 -0
- bigdl-core-npu/include/common.h +96 -0
- bigdl-core-npu/include/npu_llm.h +74 -0
- bigdl-core-npu/npu_llm.dll +0 -0
- bigdl-core-npu/npu_llm.lib +0 -0
- bigdl_core_npu-2.6.0b20250114.dist-info/METADATA +44 -0
- bigdl_core_npu-2.6.0b20250114.dist-info/RECORD +234 -0
- bigdl_core_npu-2.6.0b20250114.dist-info/WHEEL +5 -0
- bigdl_core_npu-2.6.0b20250114.dist-info/top_level.txt +2 -0
- intel_npu_acceleration_library/__init__.py +24 -0
- intel_npu_acceleration_library/_version.py +6 -0
- intel_npu_acceleration_library/backend/__init__.py +37 -0
- intel_npu_acceleration_library/backend/base.py +250 -0
- intel_npu_acceleration_library/backend/bindings.py +383 -0
- intel_npu_acceleration_library/backend/compression.py +24 -0
- intel_npu_acceleration_library/backend/convolution.py +58 -0
- intel_npu_acceleration_library/backend/factory.py +1161 -0
- intel_npu_acceleration_library/backend/linear.py +60 -0
- intel_npu_acceleration_library/backend/matmul.py +59 -0
- intel_npu_acceleration_library/backend/mlp.py +58 -0
- intel_npu_acceleration_library/backend/ops.py +142 -0
- intel_npu_acceleration_library/backend/qlinear.py +75 -0
- intel_npu_acceleration_library/backend/qmatmul.py +66 -0
- intel_npu_acceleration_library/backend/runtime.py +215 -0
- intel_npu_acceleration_library/backend/sdpa.py +107 -0
- intel_npu_acceleration_library/backend/tensor.py +1120 -0
- intel_npu_acceleration_library/backend/utils.py +70 -0
- intel_npu_acceleration_library/compiler.py +194 -0
- intel_npu_acceleration_library/device.py +230 -0
- intel_npu_acceleration_library/dtypes.py +155 -0
- intel_npu_acceleration_library/external/openvino/__init__.py +72 -0
- intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py +21 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/experimental/__init__.py +14 -0
- intel_npu_acceleration_library/external/openvino/frontend/__init__.py +34 -0
- intel_npu_acceleration_library/external/openvino/frontend/frontend.py +44 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/jaxpr_decoder.py +293 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/passes.py +65 -0
- intel_npu_acceleration_library/external/openvino/frontend/jax/utils.py +182 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/__init__.py +19 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/fx_decoder.py +370 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/gptq.py +180 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/module_extension.py +39 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/patch_model.py +118 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend.py +131 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend_utils.py +85 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/compile.py +141 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/decompositions.py +116 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/execute.py +189 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/op_support.py +290 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/partition.py +126 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/ts_decoder.py +568 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/utils.py +258 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/__init__.py +16 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/graph_iterator.py +116 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/node_decoder.py +219 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/utils.py +481 -0
- intel_npu_acceleration_library/external/openvino/helpers/__init__.py +6 -0
- intel_npu_acceleration_library/external/openvino/helpers/packing.py +87 -0
- intel_npu_acceleration_library/external/openvino/preprocess/README.md +60 -0
- intel_npu_acceleration_library/external/openvino/preprocess/__init__.py +28 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/preprocess_converter.py +47 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/requirements.txt +5 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/torchvision_preprocessing.py +347 -0
- intel_npu_acceleration_library/external/openvino/properties/__init__.py +22 -0
- intel_npu_acceleration_library/external/openvino/properties/_properties.py +55 -0
- intel_npu_acceleration_library/external/openvino/properties/device/__init__.py +14 -0
- intel_npu_acceleration_library/external/openvino/properties/hint/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_auto/__init__.py +12 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_cpu/__init__.py +8 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_gpu/__init__.py +12 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_gpu/hint/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/properties/log/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/properties/streams/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/runtime/__init__.py +85 -0
- intel_npu_acceleration_library/external/openvino/runtime/exceptions.py +17 -0
- intel_npu_acceleration_library/external/openvino/runtime/ie_api.py +631 -0
- intel_npu_acceleration_library/external/openvino/runtime/op/__init__.py +19 -0
- intel_npu_acceleration_library/external/openvino/runtime/op/util/__init__.py +22 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset1/__init__.py +112 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset1/ops.py +3068 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset10/__init__.py +179 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset10/ops.py +173 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset11/__init__.py +179 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset11/ops.py +107 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset12/__init__.py +180 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset12/ops.py +120 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset13/__init__.py +188 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset13/ops.py +398 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset14/__init__.py +190 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset14/ops.py +171 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset15/__init__.py +17 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset15/ops.py +276 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset2/__init__.py +118 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset2/ops.py +216 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset3/__init__.py +134 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset3/ops.py +638 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset4/__init__.py +145 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset4/ops.py +464 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset5/__init__.py +152 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset5/ops.py +372 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset6/__init__.py +154 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset6/ops.py +215 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset7/__init__.py +158 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset7/ops.py +169 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset8/__init__.py +169 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset8/ops.py +787 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset9/__init__.py +175 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset9/ops.py +341 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset_utils.py +22 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/__init__.py +19 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/graph_rewrite.py +33 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/manager.py +26 -0
- intel_npu_acceleration_library/external/openvino/runtime/properties/__init__.py +40 -0
- intel_npu_acceleration_library/external/openvino/runtime/properties/hint/__init__.py +25 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/__init__.py +7 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/broadcasting.py +44 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/__init__.py +8 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/data_dispatcher.py +447 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/wrappers.py +148 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/decorators.py +156 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/input_validation.py +133 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/node_factory.py +127 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/reduction.py +25 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/types.py +175 -0
- intel_npu_acceleration_library/external/openvino/tools/__init__.py +4 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/__init__.py +3 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/benchmark.py +186 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/main.py +695 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/parameters.py +199 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/__init__.py +3 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/constants.py +26 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/inputs_filling.py +482 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/logging.py +8 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/statistics_report.py +296 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/utils.py +836 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/__init__.py +20 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/__main__.py +10 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/cli_parser.py +633 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert.py +102 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert_data_type.py +82 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert_impl.py +550 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/environment_setup_utils.py +50 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/error.py +49 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/get_ov_update_message.py +16 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/help.py +45 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/logger.py +91 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/main.py +40 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/__init__.py +2 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/analysis.py +46 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/check_config.py +57 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/extractor.py +447 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/jax_frontend_utils.py +19 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/layout_utils.py +73 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/moc_emit_ir.py +32 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/offline_transformations.py +107 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/paddle_frontend_utils.py +83 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pipeline.py +298 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/preprocessing.py +220 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +214 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/shape_utils.py +109 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/type_utils.py +82 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/ovc.py +13 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_params.py +6 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_stub.py +28 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_utils.py +118 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/utils.py +196 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/version.py +80 -0
- intel_npu_acceleration_library/external/openvino/torch/__init__.py +5 -0
- intel_npu_acceleration_library/external/openvino/utils.py +115 -0
- intel_npu_acceleration_library/functional/__init__.py +8 -0
- intel_npu_acceleration_library/functional/scaled_dot_product_attention.py +47 -0
- intel_npu_acceleration_library/lib/Release/cache.json +113732 -0
- intel_npu_acceleration_library/lib/Release/intel_npu_acceleration_library.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_batch_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_c.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_hetero_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_cpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_gpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_npu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_ir_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_onnx_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_paddle_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_pytorch_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_lite_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy_debug.dll +0 -0
- intel_npu_acceleration_library/modelling.py +150 -0
- intel_npu_acceleration_library/nn/__init__.py +20 -0
- intel_npu_acceleration_library/nn/autograd.py +68 -0
- intel_npu_acceleration_library/nn/conv.py +257 -0
- intel_npu_acceleration_library/nn/functional.py +1207 -0
- intel_npu_acceleration_library/nn/linear.py +162 -0
- intel_npu_acceleration_library/nn/llm.py +417 -0
- intel_npu_acceleration_library/nn/module.py +393 -0
- intel_npu_acceleration_library/optimizations.py +157 -0
- intel_npu_acceleration_library/quantization.py +174 -0
@@ -0,0 +1,72 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
|
6
|
+
|
7
|
+
# Required for Windows OS platforms
|
8
|
+
# Note: always top-level
|
9
|
+
try:
|
10
|
+
from openvino.utils import _add_openvino_libs_to_search_path
|
11
|
+
_add_openvino_libs_to_search_path()
|
12
|
+
except ImportError:
|
13
|
+
pass
|
14
|
+
|
15
|
+
# #
|
16
|
+
# # OpenVINO API
|
17
|
+
# # This __init__.py forces checking of runtime modules to propagate errors.
|
18
|
+
# # It is not compared with init files from openvino-dev package.
|
19
|
+
# #
|
20
|
+
# Import all public modules
|
21
|
+
from openvino import runtime as runtime
|
22
|
+
from openvino import frontend as frontend
|
23
|
+
from openvino import helpers as helpers
|
24
|
+
from openvino import experimental as experimental
|
25
|
+
from openvino import preprocess as preprocess
|
26
|
+
from openvino import utils as utils
|
27
|
+
from openvino import properties as properties
|
28
|
+
|
29
|
+
# Import most important classes and functions from openvino.runtime
|
30
|
+
from openvino.runtime import Model
|
31
|
+
from openvino.runtime import Core
|
32
|
+
from openvino.runtime import CompiledModel
|
33
|
+
from openvino.runtime import InferRequest
|
34
|
+
from openvino.runtime import AsyncInferQueue
|
35
|
+
|
36
|
+
from openvino.runtime import Symbol
|
37
|
+
from openvino.runtime import Dimension
|
38
|
+
from openvino.runtime import Strides
|
39
|
+
from openvino.runtime import PartialShape
|
40
|
+
from openvino.runtime import Shape
|
41
|
+
from openvino.runtime import Layout
|
42
|
+
from openvino.runtime import Type
|
43
|
+
from openvino.runtime import Tensor
|
44
|
+
from openvino.runtime import OVAny
|
45
|
+
|
46
|
+
from openvino.runtime import compile_model
|
47
|
+
from openvino.runtime import get_batch
|
48
|
+
from openvino.runtime import set_batch
|
49
|
+
from openvino.runtime import serialize
|
50
|
+
from openvino.runtime import shutdown
|
51
|
+
from openvino.runtime import tensor_from_file
|
52
|
+
from openvino.runtime import save_model
|
53
|
+
from openvino.runtime import layout_helpers
|
54
|
+
|
55
|
+
from openvino._pyopenvino import RemoteContext
|
56
|
+
from openvino._pyopenvino import RemoteTensor
|
57
|
+
from openvino._pyopenvino import Op
|
58
|
+
|
59
|
+
# libva related:
|
60
|
+
from openvino._pyopenvino import VAContext
|
61
|
+
from openvino._pyopenvino import VASurfaceTensor
|
62
|
+
|
63
|
+
# Set version for openvino package
|
64
|
+
from openvino.runtime import get_version
|
65
|
+
__version__ = get_version()
|
66
|
+
|
67
|
+
# Tools
|
68
|
+
try:
|
69
|
+
# Model Conversion API - ovc should reside in the main namespace
|
70
|
+
from openvino.tools.ovc import convert_model
|
71
|
+
except ImportError:
|
72
|
+
pass
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
# flake8: noqa
|
6
|
+
|
7
|
+
from openvino._pyopenvino import get_version
|
8
|
+
|
9
|
+
__version__ = get_version()
|
10
|
+
|
11
|
+
from openvino._pyopenvino._offline_transformations import apply_fused_names_cleanup
|
12
|
+
from openvino._pyopenvino._offline_transformations import apply_moc_transformations
|
13
|
+
from openvino._pyopenvino._offline_transformations import apply_moc_legacy_transformations
|
14
|
+
from openvino._pyopenvino._offline_transformations import apply_low_latency_transformation
|
15
|
+
from openvino._pyopenvino._offline_transformations import apply_pruning_transformation
|
16
|
+
from openvino._pyopenvino._offline_transformations import apply_make_stateful_transformation
|
17
|
+
from openvino._pyopenvino._offline_transformations import compress_model_transformation
|
18
|
+
from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
|
19
|
+
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
|
20
|
+
from openvino._pyopenvino._offline_transformations import paged_attention_transformation
|
21
|
+
from openvino._pyopenvino._offline_transformations import stateful_to_stateless_transformation
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
"""
|
5
|
+
Package: openvino
|
6
|
+
This module provides access to experimental functionality that is subject to change without prior notice.
|
7
|
+
"""
|
8
|
+
|
9
|
+
# flake8: noqa
|
10
|
+
|
11
|
+
from openvino._pyopenvino.experimental import evaluate_as_partial_shape
|
12
|
+
from openvino._pyopenvino.experimental import evaluate_both_bounds
|
13
|
+
from openvino._pyopenvino.experimental import set_element_type
|
14
|
+
from openvino._pyopenvino.experimental import set_tensor_type
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
"""
|
5
|
+
Package: openvino
|
6
|
+
Low level wrappers for the FrontEnd C++ API.
|
7
|
+
"""
|
8
|
+
|
9
|
+
# flake8: noqa
|
10
|
+
|
11
|
+
from openvino._pyopenvino import get_version
|
12
|
+
|
13
|
+
__version__ = get_version()
|
14
|
+
|
15
|
+
# main classes
|
16
|
+
from openvino.frontend.frontend import FrontEndManager
|
17
|
+
from openvino.frontend.frontend import FrontEnd
|
18
|
+
from openvino._pyopenvino import InputModel
|
19
|
+
from openvino._pyopenvino import NodeContext
|
20
|
+
from openvino._pyopenvino import Place
|
21
|
+
|
22
|
+
# extensions
|
23
|
+
from openvino._pyopenvino import DecoderTransformationExtension
|
24
|
+
from openvino._pyopenvino import ConversionExtension
|
25
|
+
from openvino._pyopenvino import OpExtension
|
26
|
+
from openvino._pyopenvino import ProgressReporterExtension
|
27
|
+
from openvino._pyopenvino import TelemetryExtension
|
28
|
+
|
29
|
+
# exceptions
|
30
|
+
from openvino._pyopenvino import NotImplementedFailure
|
31
|
+
from openvino._pyopenvino import InitializationFailure
|
32
|
+
from openvino._pyopenvino import OpConversionFailure
|
33
|
+
from openvino._pyopenvino import OpValidationFailure
|
34
|
+
from openvino._pyopenvino import GeneralFailure
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
from typing import Union
|
6
|
+
|
7
|
+
from openvino._pyopenvino import FrontEnd as FrontEndBase
|
8
|
+
from openvino._pyopenvino import FrontEndManager as FrontEndManagerBase
|
9
|
+
from openvino._pyopenvino import InputModel
|
10
|
+
from openvino.runtime import Model
|
11
|
+
|
12
|
+
|
13
|
+
class FrontEnd(FrontEndBase):
    """Python-side wrapper around the native ``FrontEnd`` binding.

    Methods that produce a graph from an ``InputModel`` re-wrap the native
    result in the Python ``Model`` class from ``openvino.runtime``.
    """

    def __init__(self, fe: FrontEndBase) -> None:
        """Wrap the already-loaded native frontend ``fe``."""
        super().__init__(fe)

    def convert(self, model: Union[Model, InputModel]) -> Model:
        """Convert ``model`` with the native frontend.

        If ``model`` is an ``InputModel`` the native result is wrapped in
        ``Model``; otherwise (a ``Model`` argument) the native result is
        returned unchanged.
        """
        native_result = super().convert(model)
        return Model(native_result) if isinstance(model, InputModel) else native_result

    def convert_partially(self, model: InputModel) -> Model:
        """Run the native partial conversion of ``model`` and wrap it in ``Model``."""
        partial_graph = super().convert_partially(model)
        return Model(partial_graph)

    def decode(self, model: InputModel) -> Model:
        """Run the native decode step on ``model`` and wrap the result in ``Model``."""
        decoded_graph = super().decode(model)
        return Model(decoded_graph)

    def normalize(self, model: Model) -> None:
        """Delegate to the native ``normalize``; always returns ``None``."""
        super().normalize(model)
|
31
|
+
|
32
|
+
|
33
|
+
class FrontEndManager(FrontEndManagerBase):
    """Native ``FrontEndManager`` whose loaders hand back Python ``FrontEnd`` wrappers."""

    def load_by_framework(self, framework: str) -> Union[FrontEnd, None]:
        """Ask the native manager for the frontend named ``framework``.

        Returns a wrapping ``FrontEnd``, or ``None`` when the native loader
        returned ``None``.
        """
        native_fe = super().load_by_framework(framework)
        return None if native_fe is None else FrontEnd(native_fe)

    def load_by_model(self, model: str) -> Union[FrontEnd, None]:
        """Ask the native manager for a frontend matching ``model``.

        Returns a wrapping ``FrontEnd``, or ``None`` when the native loader
        returned ``None``.
        """
        native_fe = super().load_by_model(model)
        return None if native_fe is None else FrontEnd(native_fe)
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
"""
|
5
|
+
Package: openvino
|
6
|
+
Low level wrappers for the FrontEnd C++ API.
|
7
|
+
"""
|
8
|
+
|
9
|
+
# flake8: noqa
|
10
|
+
|
11
|
+
try:
    from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
except ImportError as err:
    # Re-raise with an actionable hint. The trailing space keeps the hint and the
    # underlying message from running together, and `from err` records the
    # original loader failure as the explicit cause in the traceback.
    raise ImportError(
        "OpenVINO JAX frontend is not available, please make sure the frontend is built. "
        "{}".format(err)
    ) from err
|
@@ -0,0 +1,293 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# flake8: noqa
|
5
|
+
# mypy: ignore-errors
|
6
|
+
|
7
|
+
import jax.core
|
8
|
+
from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
|
9
|
+
from openvino.runtime import PartialShape, Type as OVType, OVAny
|
10
|
+
from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \
|
11
|
+
ivalue_to_constant, param_to_constants
|
12
|
+
|
13
|
+
import jax
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from typing import List
|
17
|
+
import logging
|
18
|
+
# Logger dedicated to this decoder module.
logger = logging.getLogger(name=__name__)
# Pin this logger's own level so records below WARNING emitted from this module
# are dropped regardless of the root logger's configured level.
logger.setLevel(level=logging.WARNING)
|
20
|
+
|
21
|
+
class JaxprPythonDecoder (Decoder):
|
22
|
+
'''
|
23
|
+
The jaxpr decoder uses Jaxpr to get graph information from a jax module.
|
24
|
+
It takes use of the following parts.
|
25
|
+
|
26
|
+
- `ClosedJaxpr`: the jaxpr object that contains the jaxpr and literals.
|
27
|
+
- `Jaxpr`: the jaxpr object that contains the invars, outvars, and eqns.
|
28
|
+
- `JaxEqns`: A list of jaxpr equations, which contains the information of the operation.
|
29
|
+
- `Primitive`: the operation that is used in the equation.
|
30
|
+
- `invars`: the input variables of the equation.
|
31
|
+
- `aval`: the abstract value.
|
32
|
+
- `outvars`: the output variables of the equation.
|
33
|
+
- `aval`: the abstract value.
|
34
|
+
- `params`: the named params of this equation.
|
35
|
+
- `invars`: the inputs of the model (traced graph).
|
36
|
+
- `aval`: the abstract value.
|
37
|
+
- `outvars`: the outputs of the model (traced graph).
|
38
|
+
- `aval`: the abstract value.
|
39
|
+
- `constvars`: the constant variables used in this model.
|
40
|
+
- `aval`: the abstract value.
|
41
|
+
- `Literal`: the literal object that contains the value of the constants.
|
42
|
+
'''
|
43
|
+
|
44
|
+
def __init__(self, jaxpr, name=None, literals=None):
    '''
    Build a decoder around a jaxpr-like object.

    Inputs:
        - jaxpr: for users, a `ClosedJaxpr` is expected; a bare `Jaxpr` or a single
          `JaxprEqn` is also accepted. See https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L197
        - name: the model name; "jax_module" is used when None.
        - literals: the literals (constants) used by the model. When not None they
          replace any literals taken from a `ClosedJaxpr`.

    Raises:
        ValueError: if `jaxpr` is not a `JaxprEqn`, `Jaxpr` or `ClosedJaxpr`.
    '''
    Decoder.__init__(self)

    is_open_form = isinstance(jaxpr, (jax.core.JaxprEqn, jax.core.Jaxpr))
    if not is_open_form and not isinstance(jaxpr, jax.core.ClosedJaxpr):
        raise ValueError(f"Unexpected type of jaxpr: {type(jaxpr)}")

    if is_open_form:
        self.jaxpr = jaxpr
    else:
        # Unwrap the inner `Jaxpr` of the `ClosedJaxpr` and keep its literals,
        # see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
        self.jaxpr = jaxpr.jaxpr
        self.literals = jaxpr.literals

    self.name = "jax_module" if name is None else name
    # NOTE(review): `self.literals` is only assigned on the `ClosedJaxpr` path or
    # here; a bare `Jaxpr`/`JaxprEqn` built without `literals` leaves it unset, so
    # `inputs()` would raise AttributeError on a Literal input. Confirm callers
    # always pass `literals` for equations.
    if literals is not None:
        self.literals = literals

    # Collect constant nodes for the equation's named params (only objects that
    # expose a dict `params`, such as equations, contribute anything).
    self.params = {}
    eqn_params = getattr(self.jaxpr, 'params', None)
    if isinstance(eqn_params, dict):
        for param_name in eqn_params:
            constant_node = self.convert_param_to_constant_node(self.jaxpr, param_name)
            if constant_node is not None:
                self.params.update(constant_node)

    # TODO: holding on to every child decoder here may increase memory usage.
    self.m_decoders = []
|
77
|
+
|
78
|
+
def inputs(self) -> List[int]:
    '''
    Return the ids that identify this node's inputs.

    For a whole jaxpr, each entry of `invars` is reported as `id(var)`.
    For a single equation, Literal operands have no graph variable. Each one is
    resolved through the next decoder in `self.literals` (consumed in order of
    appearance) and reported by that decoder's output id. Every other operand
    is reported as `id(var)`.
    '''
    if not isinstance(self.jaxpr, jax.core.JaxprEqn):
        return [id(var) for var in self.jaxpr.invars]

    resolved = []
    next_literal = 0
    for var in self.jaxpr.invars:
        if not isinstance(var, jax.core.Literal):
            resolved.append(id(var))
            continue
        resolved.append(self.literals[next_literal].output(0))
        next_literal += 1
    return resolved
|
91
|
+
|
92
|
+
def input(self, idx: int) -> int:
    '''
    Return the id of the input at position `idx`.

    This delegates to `inputs()` so both APIs agree. Previously this returned
    `id(self.jaxpr.invars[idx])`, which for a Literal operand of an equation is
    the id of the Literal object itself. `inputs()` instead reports the output id
    of the corresponding literal decoder, which is the node actually fed to
    `node_visitor`. For non-Literal operands the result is unchanged.
    '''
    return self.inputs()[idx]
|
94
|
+
|
95
|
+
def get_input_shape(self, index):
    '''Return the shape of input `index`, taken from its abstract value (`aval`), as a `PartialShape`.'''
    abstract_value = self.jaxpr.invars[index].aval
    return PartialShape(abstract_value.shape)
|
97
|
+
|
98
|
+
def get_input_signature_name(self, index) -> str:
    '''Return the synthetic signature name of input `index`, e.g. "jaxpr_invar_0".'''
    prefix = "jaxpr_invar_"
    return prefix + str(index)
|
100
|
+
|
101
|
+
def get_input_type(self, index) -> OVType:
    '''Return the OpenVINO element type of input `index`, derived with `get_ov_type_for_value`.'''
    in_var = self.jaxpr.invars[index]
    return get_ov_type_for_value(in_var)
|
103
|
+
|
104
|
+
def get_named_param(self, name):
    '''
    Return the output id of the constant decoder that holds the named parameter `name`.

    Raises KeyError if `name` is not one of `get_param_names()`.
    '''
    param_decoder = self.params[name]
    return param_decoder.output(0)
|
109
|
+
|
110
|
+
def get_named_param_as_constant(self, name):
    '''
    Return the named parameter `name` as a constant.

    A named parameter in JAX is a Python object, but the cpp side needs its value.
    This returns whatever the parameter's constant decoder yields from `as_constant()`,
    so the value can be extracted at cpp level.
    '''
    param_decoder = self.params[name]
    return param_decoder.as_constant()
|
117
|
+
|
118
|
+
def get_param_names(self):
    '''
    Return the names of the named params converted for this node.

    In JAX, the named parameters of an equation live in the `params` attribute of
    `JaxprEqn`. For example, `jax.lax.concatenate` produces a `concatenate` equation
    whose `dimension` param selects the axis along which tensors are concatenated.

    The returned names are the keys of `self.params`, i.e. the names produced by
    `convert_param_to_constant_node`. These may differ from the raw JAX param names.
    '''
    return list(self.params)
|
127
|
+
|
128
|
+
def get_output_type(self, index) -> OVType:
    '''Return the OpenVINO element type of output `index`, derived with `get_ov_type_for_value`.'''
    out_var = self.jaxpr.outvars[index]
    return get_ov_type_for_value(out_var)
|
130
|
+
|
131
|
+
def get_output_name(self, index) -> str:
    '''Return the synthetic name of output `index`, e.g. "jaxpr_outvar_0".'''
    prefix = "jaxpr_outvar_"
    return prefix + str(index)
|
133
|
+
|
134
|
+
def get_output_shape(self, index):
    '''Return the shape of output `index`, taken from its abstract value (`aval`), as a `PartialShape`.'''
    abstract_value = self.jaxpr.outvars[index].aval
    return PartialShape(abstract_value.shape)
|
136
|
+
|
137
|
+
def visit_subgraph(self, node_visitor) -> None:
    '''
    Feed every child decoder of a whole jaxpr to `node_visitor`, in this order:

    1. the decoders of this node's named params;
    2. one constant decoder per entry of `constvars`, paired by position with `self.literals`;
    3. for each equation, a constant decoder for each of its Literal operands,
       followed by a `JaxprPythonDecoder` for the equation itself.

    All children except the per-equation literal decoders are also kept in
    `self.m_decoders`. This is presumably to keep them alive while the frontend
    holds them (TODO confirm).

    A decoder wrapping a single `JaxprEqn` has no subgraph, so this is a no-op for it.
    '''
    if isinstance(self.jaxpr, jax.core.JaxprEqn):
        return

    for param_decoder in self.params.values():
        self.m_decoders.append(param_decoder)
        node_visitor(param_decoder)

    # The output id of each constvar decoder is `id(var)`, so equations that consume
    # the var (reported as `id(var)` by their `inputs()`) resolve to this constant.
    for position, const_var in enumerate(self.jaxpr.constvars):
        var_id = id(const_var)
        const_decoder = self.convert_literal_to_constant_node(
            literal=self.literals[position],
            name=self.name + "/" + f"const({var_id})",
            output_id=var_id,
        )
        self.m_decoders.append(const_decoder)
        node_visitor(const_decoder)

    # Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285
    for eqn in self.jaxpr.eqns:
        eqn_literals = []
        for operand in eqn.invars:
            if not isinstance(operand, jax.core.Literal):
                continue
            literal_decoder = self.convert_literal_to_constant_node(operand)
            eqn_literals.append(literal_decoder)
            node_visitor(literal_decoder)
        eqn_decoder = JaxprPythonDecoder(
            eqn,
            name=self.name + "/" + eqn.primitive.name,
            literals=eqn_literals,
        )
        self.m_decoders.append(eqn_decoder)
        node_visitor(eqn_decoder)
|
162
|
+
|
163
|
+
def get_op_type(self) -> str:
    '''Return the primitive name for an equation node, or "root" for a whole jaxpr.'''
    node = self.jaxpr
    return node.primitive.name if isinstance(node, jax.core.JaxprEqn) else "root"
|
168
|
+
|
169
|
+
def outputs(self) -> List[int]:
    '''Return `id(var)` for every output variable of the wrapped jaxpr / equation.'''
    return list(map(id, self.jaxpr.outvars))
|
171
|
+
|
172
|
+
def output(self, idx: int) -> int:
    '''Return `id(var)` of the output variable at position `idx`.'''
    out_var = self.jaxpr.outvars[idx]
    return id(out_var)
|
174
|
+
|
175
|
+
def num_inputs(self) -> int:
    '''Return how many input variables the wrapped jaxpr / equation has.'''
    in_vars = self.jaxpr.invars
    return len(in_vars)
|
177
|
+
|
178
|
+
def num_outputs(self) -> int:
    '''Return how many output variables the wrapped jaxpr / equation has.'''
    out_vars = self.jaxpr.outvars
    return len(out_vars)
|
180
|
+
|
181
|
+
def as_constant(self):
    '''
    Convert `self.literals` into an OpenVINO constant and return its outputs.

    This is only allowed when `get_op_type()` reports 'constant'.

    Raises:
    - ValueError: for any other op type.
    '''
    if self.get_op_type() != 'constant':
        raise ValueError("This is not a constant node so it cannot be converted to a constant.")
    # TODO: dig out how to share the memory.
    # Currently, using shared_memory will raise `ValueError: array is not writeable`
    ov_const = jax_array_to_ov_const(self.literals, shared_memory=False)
    return ov_const.outputs()
|
190
|
+
|
191
|
+
@staticmethod
def convert_param_to_constant_node(jaxpr, param) -> dict:
    '''
    Wrap the named parameter `param` of `jaxpr` into constant decoders.

    For an equation (an object with a `primitive`), `param_to_constants` maps the
    parameter to one or more named constants. Entries that come back as None are
    dropped. Otherwise the raw parameter value is converted with
    `ivalue_to_constant`, and nothing is returned if that yields None.

    Returns a dict mapping constant names to `_JaxprPythonConstantDecoder`s,
    possibly empty.
    '''
    assert hasattr(jaxpr, 'params'), "The jaxpr does not have params."
    if not hasattr(jaxpr, 'primitive'):
        constant = ivalue_to_constant(jaxpr.params[param], shared_memory=False)
        if constant is None:
            return {}
        return {param: _JaxprPythonConstantDecoder(constant=constant)}

    param_map = param_to_constants(jaxpr.primitive.name, param, jaxpr, shared_memory=False)
    return {
        const_name: _JaxprPythonConstantDecoder(constant=const_value)
        for const_name, const_value in param_map.items()
        if const_value is not None
    }
|
204
|
+
|
205
|
+
@staticmethod
|
206
|
+
def convert_literal_to_constant_node(literal, name=None, output_id=None):
|
207
|
+
if isinstance(literal, jax.core.Literal):
|
208
|
+
constant = ivalue_to_constant(literal.val, shared_memory=False)
|
209
|
+
elif isinstance(literal, (jax.Array, np.ndarray)):
|
210
|
+
constant = ivalue_to_constant(literal, shared_memory=False)
|
211
|
+
else:
|
212
|
+
raise TypeError( f"The input should be a literal or jax array, but got {type(literal)}.")
|
213
|
+
return _JaxprPythonConstantDecoder(constant=constant, name=name, output_id=output_id)
|
214
|
+
|
215
|
+
class _JaxprPythonConstantDecoder(Decoder):
    '''
    A leaf decoder for a single constant: either a named parameter of an equation
    or a constant value (literal / constvar) of a jaxpr.

    It has no inputs, no named params, no subgraph and exactly one output.
    '''

    def __init__(self, name=None, constant=None, output_id=None):
        '''
        A decoder specially for constants and named parameters.

        Inputs:
        - name: the name for this constant node.
        - constant: the converted constant. `get_output_type` / `get_output_shape` expect
          a sequence with exactly one element and read `element_type` / `shape` from it.
        - output_id: the id specified for this decoder's output. If None, use `id(self.constant)`.
        '''
        Decoder.__init__(self)

        self.name = name
        self.constant = constant
        self.output_id = id(self.constant) if output_id is None else output_id

    def inputs(self) -> List[int]:
        '''A constant has no inputs.'''
        return []

    def input(self, idx: int) -> int:
        raise ValueError("This is a constant node so it does not have input.")

    def get_input_shape(self, index):
        raise ValueError("This is a constant node so it does not have input shape.")

    def get_input_signature_name(self, index) -> str:
        raise ValueError("This is a constant node so it does not have input signature name.")

    def get_input_type(self, index) -> OVType:
        raise ValueError("This is a constant node so it does not have input type.")

    def get_named_param(self, name):
        raise ValueError("This is a constant node so it does not have named param.")

    def get_named_param_as_constant(self, name):
        raise ValueError("This is a constant node so it does not have named param.")

    def get_param_names(self):
        '''
        In JAX, the named parameters of an equation live in the `params` attribute of
        `JaxprEqn`. For example, `jax.lax.concatenate` produces a `concatenate` equation
        whose `dimension` param selects the axis along which tensors are concatenated.

        However, `_JaxprPythonConstantDecoder` is already a named param or a constant,
        so it never has named params of its own.
        '''
        return []

    def get_output_type(self, index) -> OVType:
        '''Return the element type of the single constant output; `index` is not used.'''
        assert len(self.constant) == 1
        return OVAny(self.constant[0].element_type)

    def get_output_name(self, index) -> str:
        return "jaxpr_outvar_" + str(index)

    def get_output_shape(self, index):
        '''Return the shape of the single constant output; `index` is not used.'''
        assert len(self.constant) == 1
        return PartialShape(self.constant[0].shape)

    def visit_subgraph(self, node_visitor) -> None:
        '''A constant has no subgraph, so there is nothing to visit.'''
        return

    def get_op_type(self) -> str:
        return "constant"

    def outputs(self) -> List[int]:
        return [self.output_id]

    def output(self, idx: int) -> int:
        '''Return the single output id regardless of `idx`.'''
        return self.output_id

    def num_inputs(self) -> int:
        return 0

    def num_outputs(self) -> int:
        return 1

    def as_constant(self):
        '''Return the stored constant as-is.'''
        return self.constant
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# flake8: noqa
|
5
|
+
# mypy: ignore-errors
|
6
|
+
|
7
|
+
from enum import Enum
|
8
|
+
from jax.lax import ConvDimensionNumbers
|
9
|
+
|
10
|
+
def enum_values_pass(value):
    '''Return `value.value` for an `Enum` member; any other value is returned unchanged.'''
    return value.value if isinstance(value, Enum) else value
|
14
|
+
|
15
|
+
|
16
|
+
def conv_dimension_numbers_pass(value):
    '''
    Flatten a `ConvDimensionNumbers` into `[lhs_spec, rhs_spec, out_spec]`,
    with each spec converted to a list.

    Any other value is returned unchanged.
    '''
    if not isinstance(value, ConvDimensionNumbers):
        return value
    return [list(spec) for spec in (value.lhs_spec, value.rhs_spec, value.out_spec)]
|
24
|
+
|
25
|
+
|
26
|
+
def filter_element(value):
    '''
    Normalize a single element by running it through the element-level passes in order.

    Currently the only pass is `enum_values_pass`, which unwraps `Enum` members.
    '''
    for element_pass in (enum_values_pass,):
        value = element_pass(value)
    return value
|
31
|
+
|
32
|
+
|
33
|
+
def filter_ivalue(value):
    '''
    Normalize a parameter value by running it through the value-level passes in order.

    Currently the only pass is `conv_dimension_numbers_pass`, which flattens
    `ConvDimensionNumbers`.
    '''
    for value_pass in (conv_dimension_numbers_pass,):
        value = value_pass(value)
    return value
|
38
|
+
|
39
|
+
|
40
|
+
def dot_general_param_pass(param_name: str, jax_eqn):
    '''
    Translate a named parameter of a `dot_general` equation into plain lists.

    Only `dimension_numbers` is translated; it becomes:
    - 'contract_dimensions': [lhs_contracting, rhs_contracting]
    - 'batch_dimensions': [lhs_batch, rhs_batch], only when there are batch dims.

    Any other parameter name yields an empty dict, i.e. it is dropped.
    The parameter is still looked up first, so a missing name raises KeyError.
    '''
    value = jax_eqn.params[param_name]
    if param_name != 'dimension_numbers':
        return {}

    contracting = value[0]
    assert len(contracting) == 2
    translated = {'contract_dimensions': [list(contracting[0]), list(contracting[1])]}

    batching = value[1]
    assert len(batching) == 2
    batch_count = len(batching[0])
    assert batch_count == len(batching[1])
    if batch_count:
        translated['batch_dimensions'] = [list(batching[0]), list(batching[1])]
    return translated
|
56
|
+
|
57
|
+
# Mapping from a JAX primitive name to the pass that rewrites its named params.
# Primitives without an entry keep their params unchanged (see `filter_param`).
param_passes = dict(
    dot_general=dot_general_param_pass,
)
|
61
|
+
|
62
|
+
def filter_param(primitive: str, param_name: str, jax_eqn):
    '''
    Return the named parameter `param_name` of `jax_eqn` as a dict of name -> value.

    If `primitive` has a registered pass in `param_passes`, that pass decides the
    resulting names and values; it may rename, split, or drop the parameter.
    Otherwise the parameter is returned unchanged under its own name.
    '''
    primitive_pass = param_passes.get(primitive)
    if primitive_pass is None:
        return {param_name: jax_eqn.params[param_name]}
    return primitive_pass(param_name, jax_eqn)
|