bigdl-core-npu 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bigdl_core_npu-2.5.0.dist-info/METADATA +35 -0
- bigdl_core_npu-2.5.0.dist-info/RECORD +223 -0
- bigdl_core_npu-2.5.0.dist-info/WHEEL +5 -0
- bigdl_core_npu-2.5.0.dist-info/top_level.txt +1 -0
- intel_npu_acceleration_library/__init__.py +24 -0
- intel_npu_acceleration_library/_version.py +6 -0
- intel_npu_acceleration_library/backend/__init__.py +37 -0
- intel_npu_acceleration_library/backend/base.py +215 -0
- intel_npu_acceleration_library/backend/bindings.py +279 -0
- intel_npu_acceleration_library/backend/compression.py +24 -0
- intel_npu_acceleration_library/backend/convolution.py +58 -0
- intel_npu_acceleration_library/backend/factory.py +944 -0
- intel_npu_acceleration_library/backend/linear.py +60 -0
- intel_npu_acceleration_library/backend/matmul.py +59 -0
- intel_npu_acceleration_library/backend/mlp.py +58 -0
- intel_npu_acceleration_library/backend/ops.py +141 -0
- intel_npu_acceleration_library/backend/qlinear.py +71 -0
- intel_npu_acceleration_library/backend/qmatmul.py +66 -0
- intel_npu_acceleration_library/backend/runtime.py +210 -0
- intel_npu_acceleration_library/backend/sdpa.py +107 -0
- intel_npu_acceleration_library/backend/tensor.py +1050 -0
- intel_npu_acceleration_library/backend/utils.py +70 -0
- intel_npu_acceleration_library/compiler.py +194 -0
- intel_npu_acceleration_library/device.py +230 -0
- intel_npu_acceleration_library/dtypes.py +122 -0
- intel_npu_acceleration_library/external/openvino/__init__.py +71 -0
- intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py +20 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/_pyopenvino.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/__init__.py +34 -0
- intel_npu_acceleration_library/external/openvino/frontend/frontend.py +44 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/__init__.py +19 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/fx_decoder.py +352 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/gptq.py +139 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/module_extension.py +39 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/patch_model.py +98 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend.py +119 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend_utils.py +85 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/compile.py +141 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/decompositions.py +116 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/execute.py +189 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/op_support.py +289 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/partition.py +118 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/ts_decoder.py +536 -0
- intel_npu_acceleration_library/external/openvino/frontend/pytorch/utils.py +256 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/__init__.py +16 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/graph_iterator.py +116 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/node_decoder.py +219 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp310-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp311-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp312-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp38-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp39-win_amd64.pyd +0 -0
- intel_npu_acceleration_library/external/openvino/frontend/tensorflow/utils.py +460 -0
- intel_npu_acceleration_library/external/openvino/helpers/__init__.py +6 -0
- intel_npu_acceleration_library/external/openvino/helpers/packing.py +87 -0
- intel_npu_acceleration_library/external/openvino/preprocess/README.md +60 -0
- intel_npu_acceleration_library/external/openvino/preprocess/__init__.py +26 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/preprocess_converter.py +47 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/requirements.txt +4 -0
- intel_npu_acceleration_library/external/openvino/preprocess/torchvision/torchvision_preprocessing.py +347 -0
- intel_npu_acceleration_library/external/openvino/properties/__init__.py +21 -0
- intel_npu_acceleration_library/external/openvino/properties/_properties.py +55 -0
- intel_npu_acceleration_library/external/openvino/properties/device/__init__.py +14 -0
- intel_npu_acceleration_library/external/openvino/properties/hint/__init__.py +15 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_auto/__init__.py +12 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_cpu/__init__.py +8 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_gpu/__init__.py +12 -0
- intel_npu_acceleration_library/external/openvino/properties/intel_gpu/hint/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/properties/log/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/properties/streams/__init__.py +11 -0
- intel_npu_acceleration_library/external/openvino/runtime/__init__.py +85 -0
- intel_npu_acceleration_library/external/openvino/runtime/exceptions.py +17 -0
- intel_npu_acceleration_library/external/openvino/runtime/ie_api.py +631 -0
- intel_npu_acceleration_library/external/openvino/runtime/op/__init__.py +18 -0
- intel_npu_acceleration_library/external/openvino/runtime/op/util/__init__.py +22 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset1/__init__.py +112 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset1/ops.py +3067 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset10/__init__.py +179 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset10/ops.py +173 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset11/__init__.py +179 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset11/ops.py +107 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset12/__init__.py +180 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset12/ops.py +120 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset13/__init__.py +188 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset13/ops.py +399 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset14/__init__.py +190 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset14/ops.py +171 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset15/__init__.py +10 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset15/ops.py +85 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset2/__init__.py +118 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset2/ops.py +216 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset3/__init__.py +134 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset3/ops.py +638 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset4/__init__.py +145 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset4/ops.py +464 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset5/__init__.py +152 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset5/ops.py +372 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset6/__init__.py +154 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset6/ops.py +189 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset7/__init__.py +158 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset7/ops.py +169 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset8/__init__.py +169 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset8/ops.py +783 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset9/__init__.py +175 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset9/ops.py +341 -0
- intel_npu_acceleration_library/external/openvino/runtime/opset_utils.py +22 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/__init__.py +19 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/graph_rewrite.py +33 -0
- intel_npu_acceleration_library/external/openvino/runtime/passes/manager.py +26 -0
- intel_npu_acceleration_library/external/openvino/runtime/properties/__init__.py +38 -0
- intel_npu_acceleration_library/external/openvino/runtime/properties/hint/__init__.py +25 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/__init__.py +7 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/broadcasting.py +44 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/__init__.py +8 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/data_dispatcher.py +429 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/wrappers.py +148 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/decorators.py +70 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/input_validation.py +133 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/node_factory.py +127 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/reduction.py +25 -0
- intel_npu_acceleration_library/external/openvino/runtime/utils/types.py +175 -0
- intel_npu_acceleration_library/external/openvino/tools/__init__.py +4 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/__init__.py +3 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/benchmark.py +186 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/main.py +695 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/parameters.py +199 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/__init__.py +3 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/constants.py +26 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/inputs_filling.py +482 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/logging.py +8 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/statistics_report.py +296 -0
- intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/utils.py +836 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/__init__.py +20 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/__main__.py +10 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/cli_parser.py +633 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert.py +102 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert_data_type.py +82 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/convert_impl.py +536 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/environment_setup_utils.py +50 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/error.py +49 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/get_ov_update_message.py +16 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/help.py +45 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/logger.py +91 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/main.py +35 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/__init__.py +2 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/analysis.py +46 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/check_config.py +57 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/extractor.py +447 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/layout_utils.py +73 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/moc_emit_ir.py +32 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/offline_transformations.py +107 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/paddle_frontend_utils.py +83 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pipeline.py +246 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/preprocessing.py +220 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +205 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/shape_utils.py +109 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/type_utils.py +82 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/ovc.py +13 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_params.py +6 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_stub.py +28 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_utils.py +118 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/utils.py +109 -0
- intel_npu_acceleration_library/external/openvino/tools/ovc/version.py +80 -0
- intel_npu_acceleration_library/external/openvino/torch/__init__.py +5 -0
- intel_npu_acceleration_library/external/openvino/utils.py +98 -0
- intel_npu_acceleration_library/functional/__init__.py +8 -0
- intel_npu_acceleration_library/functional/scaled_dot_product_attention.py +47 -0
- intel_npu_acceleration_library/lib/Release/cache.json +113732 -0
- intel_npu_acceleration_library/lib/Release/intel_npu_acceleration_library.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_batch_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_auto_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_c.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_hetero_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_cpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_gpu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_intel_npu_plugin.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_ir_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_onnx_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_paddle_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_pytorch_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/openvino_tensorflow_lite_frontend.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbb12_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbbind_2_5_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_debug.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy.dll +0 -0
- intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy_debug.dll +0 -0
- intel_npu_acceleration_library/modelling.py +150 -0
- intel_npu_acceleration_library/nn/__init__.py +20 -0
- intel_npu_acceleration_library/nn/autograd.py +68 -0
- intel_npu_acceleration_library/nn/conv.py +257 -0
- intel_npu_acceleration_library/nn/functional.py +1207 -0
- intel_npu_acceleration_library/nn/linear.py +162 -0
- intel_npu_acceleration_library/nn/llm.py +417 -0
- intel_npu_acceleration_library/nn/module.py +393 -0
- intel_npu_acceleration_library/optimizations.py +157 -0
- intel_npu_acceleration_library/quantization.py +174 -0
@@ -0,0 +1,536 @@
|
|
1
|
+
# Copyright (C) 2018-2024 Intel Corporation
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
# flake8: noqa
|
5
|
+
# mypy: ignore-errors
|
6
|
+
|
7
|
+
from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
|
8
|
+
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
|
9
|
+
from openvino.runtime import op, PartialShape, Type as OVType, OVAny
|
10
|
+
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor, graph_has_ops
|
11
|
+
from openvino.runtime import opset11 as ops
|
12
|
+
from openvino.frontend.pytorch import gptq
|
13
|
+
from openvino.frontend.pytorch import patch_model
|
14
|
+
from openvino.frontend.pytorch.module_extension import ModuleExtension
|
15
|
+
|
16
|
+
import typing
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
class TorchScriptPythonDecoder (Decoder):
|
21
|
+
def __init__(
        self,
        pt_module,
        graph_element=None,
        example_input=None,
        alias_db=None,
        shared_memory=True,
        skip_freeze=False,
        constant_cache=None,
        module_extensions=None):
    """Create a decoder for a TorchScript graph or for one element of it.

    When ``graph_element`` is None, ``pt_module`` is first converted to
    TorchScript (scripted when ``example_input`` is None, traced otherwise)
    and its inlined graph becomes the decoded element. Otherwise the given
    ``graph_element`` (e.g. a node or sub-graph handed over by a parent
    decoder) is decoded directly, reusing the parent's ``alias_db``.

    Args:
        pt_module: PyTorch / TorchScript module being converted.
        graph_element: Optional graph element to decode instead of the module graph.
        example_input: Optional example inputs; when given, the model is traced.
        alias_db: Alias database of the parent graph (used with ``graph_element``).
        shared_memory: Stored on the decoder and forwarded to child decoders.
        skip_freeze: If True, ``torch.jit.freeze`` is never applied.
        constant_cache: Dict shared between decoders; a new dict is created when None.
        module_extensions: Optional module extensions applied while tracing.

    Raises:
        RuntimeError: If the module cannot be scripted or traced.
    """
    Decoder.__init__(self)
    # We store every decoder created by this decoder so that all of them are
    # not deleted until the first decoder is deleted.
    self.m_decoders = []
    self._input_signature = None
    self._shared_memory = shared_memory
    self._input_is_list = False
    self.constant_cache = constant_cache if constant_cache is not None else dict()
    self.module_extensions = module_extensions
    if graph_element is None:
        try:
            pt_module = self._get_scripted_model(
                pt_module, example_input, skip_freeze)
        except Exception as e:
            if example_input is not None:
                msg = "tracing"
                # Parenthesized so all three literals form one message; without the
                # parentheses only the first sentence was assigned to help_msg.
                help_msg = ("Please check correctness of provided 'example_input'. "
                            "Sometimes models can be converted in scripted mode, please try running "
                            "conversion without 'example_input'.")
            else:
                msg = "scripting"
                help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument."
            raise RuntimeError(
                f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n{help_msg} "
                "You can also provide TorchScript module that you obtained"
                " yourself, please refer to PyTorch documentation: "
                "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.") from e
        self.graph_element = pt_module.inlined_graph
        self.alias_db = self.graph_element.alias_db()
    else:
        self.graph_element = graph_element
        self.alias_db = alias_db
    self.pt_module = pt_module
    self.raw_inputs = list(self.graph_element.inputs())
    self.raw_outputs = list(self.graph_element.outputs())
    if self._input_signature is not None:
        # Guard against a graph without inputs before peeking at the first one.
        if self.raw_inputs and "self" in self.raw_inputs[0].debugName():
            self._input_signature.insert(0, "self")
        if 0 < len(self._input_signature) < len(self.raw_inputs):
            # The last signature entry expanded into several graph inputs
            # (presumably a ``*args`` parameter): drop it and name each of
            # those extra inputs by its graph debug name instead.
            self._input_signature = self._input_signature[:-1]
            n = len(self._input_signature)
            for i in range(len(self.raw_inputs) - n):
                self._input_signature.append(
                    self.raw_inputs[i + n].debugName())

    if isinstance(self.graph_element, torch.Graph):
        self._transform_tensor_list_constants_to_listconstruct(
            self.graph_element)
        self._transform_optional_constants(self.graph_element)
    self.out_debug_name_overwrites = {}
|
81
|
+
|
82
|
+
@staticmethod
|
83
|
+
def _get_preserved_attributes(model) -> list:
|
84
|
+
preserved_attributes = []
|
85
|
+
for name, module in model.named_modules():
|
86
|
+
if hasattr(module, "weight"):
|
87
|
+
if module.weight is not None and getattr(module.weight, "dtype", None) in [torch.int8, torch.uint8, torch.float16, torch.bfloat16]:
|
88
|
+
preserved_attributes.append(name)
|
89
|
+
return preserved_attributes
|
90
|
+
|
91
|
+
def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
    """Return a TorchScript version of ``pt_module``, scripting or tracing as needed.

    A plain ``torch.nn.Module`` (not already traced or scripted) is scripted
    when ``example_inputs`` is None and traced otherwise. Tracing temporarily
    patches the module with ``self.module_extensions`` and, for detected GPTQ
    models, with the GPTQ patch. Both patches are removed again even if
    tracing fails. The result is then frozen (``torch.jit.freeze``) when
    freezing is the default for the path taken. That default is enabled for
    scripting or when the graph contains ``prim::Uninitialized``,
    ``prim::unchecked_cast`` or ``aten::append``. It is disabled when the
    graph contains quantized ops or ``aten::as_strided``, or when
    ``skip_freeze`` is True. Any other input is returned unchanged.

    Side effects: ``pt_module.eval()`` is called on ``torch.nn.Module`` inputs,
    ``self._input_signature`` is set to the derived signature names (None when
    none was derived) and, on the tracing path, ``self._input_is_list`` is set.

    Raises:
        RuntimeError: If module extensions are configured but no example inputs are given.
    """
    import inspect

    freeze_by_default = False
    if isinstance(pt_module, torch.nn.Module):
        pt_module.eval()
    input_signature = None
    if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
        # input params is dictionary contains input names and their signature values (type hints and default values if any)
        input_params = inspect.signature(pt_module.forward if hasattr(
            pt_module, "forward") else pt_module.__call__).parameters
        input_signature = list(input_params)

        if example_inputs is None:
            if self.module_extensions:
                raise RuntimeError("ModuleExtension is not supported for scripting. Please provide valid example_input argument to run tracing.")
            scripted = torch.jit.script(pt_module)
            freeze_by_default = True
        else:
            input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                example_inputs, input_params, pt_module)

            # name of attribute in a patched module where the original forward method is kept
            orig_forward_name = '_openvino_module_extension_patch_orig_forward'
            if self.module_extensions:
                patch_model.patch_model(pt_module, self.module_extensions, orig_forward_name)

            gptq_patched = False
            # Everything after the module-extension patch runs inside try/finally
            # so that patch is undone even if GPTQ detection or the GPTQ
            # patch-failure recovery raises (previously the patch leaked then).
            try:
                if gptq.detect_gptq_model(pt_module):
                    try:
                        gptq.patch_model(pt_module)
                        gptq_patched = True
                    except Exception as error:
                        print(
                            '[ WARNING ] Failed patching of AutoGPTQ model. Error message:\n', error)
                        print(
                            '[ WARNING ] Tracing of the model will likely be unsuccessful or incorrect')
                        gptq.unpatch_model(pt_module)
                        gptq_patched = False

                scripted = torch.jit.trace(
                    pt_module, **input_parameters, strict=False)
            finally:
                if gptq_patched:
                    gptq.unpatch_model(pt_module)
                if self.module_extensions:
                    patch_model.unpatch_model(pt_module, orig_forward_name)

        if not freeze_by_default and graph_has_ops(scripted.inlined_graph, ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"]):
            # freeze models with unsupported ops
            freeze_by_default = True
        if freeze_by_default and graph_has_ops(scripted.inlined_graph, ["quantized", "aten::as_strided"]):
            # do not freeze quantized models and can't freeze for aten::as_strided it will result in incorrect inference
            freeze_by_default = False
        if freeze_by_default and not skip_freeze:
            preserved_attrs = self._get_preserved_attributes(scripted)
            f_model = torch.jit.freeze(
                scripted, preserved_attrs=preserved_attrs)
        else:
            f_model = scripted
    else:
        f_model = pt_module

    self._input_signature = input_signature
    return f_model
|
158
|
+
|
159
|
+
def inputs(self) -> list:
    """Return the ``unique()`` identifier of every raw input, in input order."""
    return [graph_value.unique() for graph_value in self.raw_inputs]
|
161
|
+
|
162
|
+
def get_input(self, index: int):
    """Return the ``unique()`` identifier of the input at position ``index``."""
    all_inputs = self.inputs()
    return all_inputs[index]
|
164
|
+
|
165
|
+
def get_input_debug_name(self, index: int) -> str:
    """Return the TorchScript ``debugName()`` of the raw input at ``index``."""
    raw_value = self._raw_input(index)
    return raw_value.debugName()
|
167
|
+
|
168
|
+
def get_input_signature_name(self, index: int) -> str:
    """Return the signature name of input ``index``.

    Falls back to the input's graph debug name when no signature is known or
    ``index`` is not below the signature length.
    """
    signature = self._input_signature
    if signature is None or index >= len(signature):
        return self.get_input_debug_name(index)
    return signature[index]
|
172
|
+
|
173
|
+
def get_input_shape(self, index: int):
    """Return the shape reported by ``get_shape_for_value`` for input ``index``."""
    return self.get_shape_for_value(self._raw_input(index))
|
176
|
+
|
177
|
+
def get_input_strides(self, index: int) -> typing.List[int]:
    """Return the strides recorded on input ``index``, or ``[]`` when unavailable.

    Strides are only reported for a ``torch.Value`` whose type is a
    ``torch.TensorType`` with non-empty strides.
    """
    candidate = self._raw_input(index)
    if not isinstance(candidate, torch.Value):
        return []
    value_type = candidate.type()
    if not isinstance(value_type, torch.TensorType):
        return []
    return value_type.strides() or []
|
186
|
+
|
187
|
+
def get_input_type(self, index: int):
    """Return the decoder type reported by ``get_type_for_value`` for input ``index``."""
    return self.get_type_for_value(self._raw_input(index))
|
190
|
+
|
191
|
+
def get_output_debug_name(self, index: int) -> str:
    """Return the debug name of output ``index``.

    A name registered in ``out_debug_name_overwrites`` takes precedence over
    the raw output's ``debugName()``.
    """
    overrides = self.out_debug_name_overwrites
    if index not in overrides:
        return self._raw_output(index).debugName()
    return overrides[index]
|
195
|
+
|
196
|
+
def get_output_shape(self, index: int):
    """Return the shape reported by ``get_shape_for_value`` for output ``index``."""
    return self.get_shape_for_value(self._raw_output(index))
|
199
|
+
|
200
|
+
def get_output_type(self, index: int):
    """Return the decoder type reported by ``get_type_for_value`` for output ``index``."""
    return self.get_type_for_value(self._raw_output(index))
|
203
|
+
|
204
|
+
def _get_known_type_for_value(self, pt_type):
    """Returns known/unknown types wrapped as OVAny.

    Mapping rules, checked in order:
      * ``None`` -> ``OVType.dynamic``
      * ``int`` / ``float`` / ``bool`` -> ``DecoderType.PyScalar`` of the mapped type
      * any other name in ``pt_to_ov_type_map`` -> the mapped type
      * ``torch.TensorType`` -> ``DecoderType.Tensor`` of its (recursively mapped) dtype
      * ``torch.ListType`` -> ``DecoderType.List`` of its (recursively mapped) element type
      * string / device types -> ``DecoderType.Str``
      * ``torch.NoneType`` -> ``DecoderType.PyNone``
      * anything else -> ``OVType.dynamic``
    """
    if pt_type is None:
        return OVAny(OVType.dynamic)
    # TODO: Don't use str, use native types
    type_name = str(pt_type)
    if type_name in ("int", "float", "bool"):
        return OVAny(DecoderType.PyScalar(OVAny(pt_to_ov_type_map[type_name])))
    if type_name in pt_to_ov_type_map:
        return OVAny(pt_to_ov_type_map[type_name])
    if isinstance(pt_type, torch.TensorType):
        element = self._get_known_type_for_value(pt_type.dtype())
        return OVAny(DecoderType.Tensor(element))
    if isinstance(pt_type, torch.ListType):
        element = self._get_known_type_for_value(pt_type.getElementType())
        return OVAny(DecoderType.List(element))
    if isinstance(pt_type, (torch.StringType, torch.DeviceObjType)):
        return OVAny(DecoderType.Str())
    if isinstance(pt_type, torch.NoneType):
        return OVAny(DecoderType.PyNone())
    # Not yet recognized
    return OVAny(OVType.dynamic)
|
227
|
+
|
228
|
+
def get_shape_for_value(self, value: torch.Value):
    """Return a PartialShape for ``value`` that keeps at most its rank.

    For a complete tensor every dimension is reported as dynamic (-1), since
    static shapes don't generalize to other inputs. Any other value gets a
    fully dynamic shape.
    """
    if not value.isCompleteTensor():
        # TODO: Recognize types that we can represent as a nested constructs with objects from DecoderType
        # If recognized, return scalar instead of dynamic. Scalar means a single value of that custom type.
        # See get_type_for_value for reference
        return PartialShape.dynamic()
    rank = len(value.type().sizes())
    return PartialShape([-1] * rank)
|
239
|
+
|
240
|
+
def get_type_for_value(self, value: torch.Value):
    """Return the ``_get_known_type_for_value`` mapping of ``value.type()``."""
    return self._get_known_type_for_value(value.type())
|
243
|
+
|
244
|
+
def get_subgraph_size(self) -> int:
    """Return how many sub-graphs the decoded element has.

    A ``torch.Node`` reports the length of ``get_subgraphs()``; any other
    element (e.g. a whole graph) counts as exactly one.
    """
    if not isinstance(self.graph_element, torch.Node):
        return 1
    return len(self.get_subgraphs())
|
249
|
+
|
250
|
+
def visit_subgraph(self, node_visitor) -> None:
    """Wrap each node of the decoded graph in a child decoder and pass it to ``node_visitor``.

    Nodes are walked in ``graph_element.nodes()`` order so topological order
    is satisfied. Each child shares this decoder's alias db, shared-memory
    flag, constant cache and module extensions, and is retained in
    ``self.m_decoders`` so it stays alive as long as this decoder.
    """
    for graph_node in self.graph_element.nodes():
        child = TorchScriptPythonDecoder(
            self.pt_module,
            graph_node,
            alias_db=self.alias_db,
            shared_memory=self._shared_memory,
            constant_cache=self.constant_cache,
            module_extensions=self.module_extensions,
        )
        self.m_decoders.append(child)
        node_visitor(child)
|
261
|
+
|
262
|
+
def decoder_type_name(self) -> str:
    """Identify this decoder as the TorchScript decoder (``"ts"``)."""
    type_tag = "ts"
    return type_tag
|
264
|
+
|
265
|
+
def get_subgraphs(self) -> list:
    """Return the sub-graphs attached to the decoded element.

    A ``prim::PythonOp`` node yields its ``Subgraph`` attribute when present
    (and nothing otherwise); every other element yields its blocks.
    """
    element = self.graph_element
    if element.kind() != "prim::PythonOp":
        return list(element.blocks())
    if "Subgraph" not in element.attributeNames():
        # Attribute "Subgraph" is only available if Graph was created using tracing.
        # TODO Find way to extract subgraph for scripted Graph.
        return []
    assert isinstance(element, torch.Node), "Graph element must be of type torch.Node."
    # kindOf() names the typed accessor (e.g. "g") used to read the attribute.
    read_attribute = getattr(element, element.kindOf("Subgraph"))
    return [read_attribute("Subgraph")]
|
276
|
+
|
277
|
+
def get_subgraph_decoder(self, index: int):
    """Create, retain and return a decoder for sub-graph ``index``.

    The child shares this decoder's alias db, shared-memory flag and module
    extensions. Unlike ``visit_subgraph``, ``constant_cache`` is not
    forwarded, so the child starts with its own empty cache.
    NOTE(review): presumably deliberate because a sub-graph is converted as a
    separate body; confirm before sharing the cache here.
    """
    subgraph = self.get_subgraphs()[index]
    child = TorchScriptPythonDecoder(
        self.pt_module,
        subgraph,
        alias_db=self.alias_db,
        shared_memory=self._shared_memory,
        module_extensions=self.module_extensions,
    )
    self.m_decoders.append(child)
    return child
|
285
|
+
|
286
|
+
def get_op_type(self) -> str:
    """Return the operation type the frontend should translate for this node.

    Normally this is the TorchScript node kind (e.g. ``"aten::add"``).

    A ``prim::PythonOp`` may have a Python object bound to a trampoline
    whose ``target_extension`` is a ``ModuleExtension``. In that case the
    extension's ``target_op`` decides the type: a string is used as-is, and a
    callable is invoked with the trampoline's ``original_module`` and its
    result is used.

    Raises:
        AssertionError: if ``self.graph_element`` is not a ``torch.Node``.
        RuntimeError: if the extension's ``target_op`` is neither a string
            nor a callable.
    """
    assert isinstance(
        self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
    node = self.graph_element
    if node.kind() == "prim::PythonOp":
        if hasattr(node, 'pyobj') and callable(node.pyobj) and hasattr(node.pyobj(), '__self__'):
            trampoline = node.pyobj().__self__
            if hasattr(trampoline, 'target_extension') and isinstance(trampoline.target_extension, ModuleExtension):
                target_op = trampoline.target_extension.target_op
                if callable(target_op):
                    target = target_op(trampoline.original_module)
                elif isinstance(target_op, str):
                    target = target_op
                else:
                    # Previously this case fell through to `return target`
                    # and failed with an opaque UnboundLocalError.
                    raise RuntimeError(
                        "ModuleExtension target_op must be a str or a callable, "
                        f"got {type(target_op).__name__}")
                # TODO: Support target as a callable that will play a role of ConversionExtension for an entire module instead of a single op.
                # Without supporting target as a callable here, ConversionExtension functionality is still possible to implement
                # by combining two extensions: ModuleExtension that use temporary name as a target op and another extension of type ConversionExtension
                # that translates that particular temporary name to custom graph. But providing conversion code as a callable `target` is more convenient.
                return target
    return node.kind()
|
304
|
+
|
305
|
+
def get_schema(self) -> str:
    """Return the schema reported by the wrapped graph element."""
    element = self.graph_element
    return element.schema()
|
307
|
+
|
308
|
+
def outputs(self) -> list:
    """Return ``unique()`` of every raw output, in output order."""
    return [raw_value.unique() for raw_value in self.raw_outputs]
|
310
|
+
|
311
|
+
def _raw_output(self, index: int):
    """Return the raw TorchScript output value at position ``index``."""
    raw_outputs = self.raw_outputs
    return raw_outputs[index]
|
313
|
+
|
314
|
+
def _raw_input(self, index: int):
    """Return the raw TorchScript input value at position ``index``."""
    raw_inputs = self.raw_inputs
    return raw_inputs[index]
|
316
|
+
|
317
|
+
def num_of_outputs(self):
    """Return the number of raw TorchScript outputs of the wrapped element."""
    raw_outputs = self.raw_outputs
    return len(raw_outputs)
|
319
|
+
|
320
|
+
def output(self, index: int):
    """Return the entry of ``self.outputs()`` at position ``index``."""
    output_ids = self.outputs()
    return output_ids[index]
|
322
|
+
|
323
|
+
def mark_node(self, node):
    """Set a friendly name on an OpenVINO ``node`` derived from this element.

    The name is the TorchScript kind. It is followed by ``/<node type name>``
    unless the node's type name contains ``FrameworkNode``, and prefixed by
    the last ``/``-separated component of the element's scope name when
    that scope name is non-empty.

    Returns:
        The same ``node`` that was passed in.
    """
    type_name = node.get_type_name()
    label = self.graph_element.kind()
    if "FrameworkNode" not in type_name:
        label = f"{label}/{type_name}"
    scope = self.graph_element.scopeName()
    if scope:
        label = scope.split("/")[-1] + "/" + label
    node.set_friendly_name(label)
    return node
|
333
|
+
|
334
|
+
def _add_name_to_const_and_cache(self, outputs, name):
    """Store ``outputs`` in ``self.constant_cache`` under ``name``.

    When there is exactly one output, its producing node is also renamed to
    ``name``, and ``name`` is recorded as the debug-name override for
    output 0.
    """
    if len(outputs) == 1:
        # set name corresponding to state_dict name
        only_output = outputs[0]
        only_output.get_node().set_friendly_name(name)
        self.out_debug_name_overwrites[0] = name
    self.constant_cache[name] = outputs
|
340
|
+
|
341
|
+
def try_decode_get_attr(self):
    """Decode the ``prim::GetAttr`` node into OpenVINO constant outputs.

    The attribute value is resolved with ``get_value_from_getattr``.

    * ``torch.ScriptObject`` values are assumed to be quantized packed
      parameters (conv or linear). They yield ``[weight, bias]``; when no
      bias is packed, a zero converted like the weight is used instead.
      If the object exposes the conv accessors, ``stride``, ``padding``,
      ``dilation`` and ``groups`` constants follow.
    * Any other value that is not a scripted/traced submodule becomes a
      constant. Constants are cached by attribute name in
      ``self.constant_cache`` so repeated uses share one constant.
    * Scripted/traced submodules produce an empty list.

    Returns:
        list: OpenVINO outputs representing the attribute (possibly empty).
    """
    pt_value, name = get_value_from_getattr(self.graph_element,
                                            self.pt_module)
    assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
    if isinstance(pt_value, torch.ScriptObject):
        # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
        # TODO: but can be anything. Figure a better way to distinguish
        weight, bias = pt_value.unpack()
        w_name = name + ".weight"
        if w_name in self.constant_cache:
            weight_outputs = self.constant_cache[w_name]
        else:
            weight_outputs = convert_quantized_tensor(weight, self._shared_memory)
            self._add_name_to_const_and_cache(weight_outputs, w_name)
        # Build the result in a fresh list. The cache entry for `w_name` is
        # the same list object, so extending it in place with bias/conv
        # params would corrupt the cache and give the next GetAttr of these
        # packed params duplicated outputs.
        res = list(weight_outputs)

        if isinstance(bias, torch.Tensor):
            b_name = name + ".bias"
            if b_name in self.constant_cache:
                res += self.constant_cache[b_name]
            else:
                b_res = ivalue_to_constant(bias)
                self._add_name_to_const_and_cache(b_res, b_name)
                res += b_res
        else:
            res += ops.convert_like(ivalue_to_constant(torch.zeros(1))
                                    [0], res[0]).outputs()
        try:
            # these params exist only for conv params
            conv_params = (pt_value.stride(), pt_value.padding(),
                           pt_value.dilation(), pt_value.groups())
        except Exception:
            # Non-conv (e.g. linear) packed params lack these accessors;
            # return weight and bias only.
            conv_params = ()
        for param in conv_params:
            res += ivalue_to_constant(param,
                                      shared_memory=self._shared_memory)
        return res
    if isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
        return []
    # this tensor can be used multiple times in the model, so we have to reuse constants
    if name in self.constant_cache:
        return self.constant_cache[name]
    const = ivalue_to_constant(pt_value,
                               shared_memory=self._shared_memory)
    self._add_name_to_const_and_cache(const, name)
    return const
|
395
|
+
|
396
|
+
def as_constant(self):
    """Convert a ``prim::Constant`` node into OpenVINO constant outputs.

    Returns ``None`` when the wrapped element is not a ``torch.Node`` or is
    not a ``prim::Constant``. Otherwise:

    * tensor constants are converted with ``ivalue_to_constant``;
    * list constants are converted with ``_as_constant_list``;
    * any other constant is converted with ``ivalue_to_constant``, and the
      first resulting node is named after output 0's debug name.
    """
    is_constant_node = (isinstance(self.graph_element, torch.Node)
                        and self.get_op_type() == "prim::Constant")
    if not is_constant_node:
        return None
    raw = self._raw_output(0)
    raw_type = raw.type()
    if isinstance(raw_type, torch.TensorType):
        return ivalue_to_constant(raw.toIValue(),
                                  shared_memory=self._shared_memory)
    if isinstance(raw_type, torch.ListType):
        return self._as_constant_list(raw)
    result = ivalue_to_constant(raw.toIValue(),
                                shared_memory=self._shared_memory)
    if len(result) > 0:
        # set name corresponding to state_dict name
        first_node = result[0].get_node()
        first_node.set_friendly_name(self.get_output_debug_name(0))
    return result
|
415
|
+
|
416
|
+
def as_string(self):
    """Return a string payload for constant or device nodes, else ``None``.

    * ``prim::Constant`` whose type prints as ``str``/``torch.StringType``:
      the constant value itself.
    * ``prim::Constant`` whose type prints as ``Device``: the value's
      ``type`` attribute.
    * ``prim::device``: the result of ``_get_device_string``.
    """
    if self.get_op_type() == "prim::Constant":
        raw = self._raw_output(0)
        type_str = str(raw.type())
        if type_str in ("torch.StringType", "str"):
            return raw.toIValue()
        if type_str == "Device":
            return raw.toIValue().type
        return None
    if self.get_op_type() == "prim::device":
        return self._get_device_string()
    return None
|
426
|
+
|
427
|
+
@staticmethod
def _as_constant_list(pt_value: torch.Value):
    """Convert a constant TorchScript list into a 1-D OpenVINO constant.

    Lists are represented as 1-D tensors so converters can read constant
    list attributes without being rewritten.

    Returns:
        The constant's outputs, or ``None`` when the list's element type has
        no entry in ``pt_to_ov_type_map``.
    """
    element_type_name = str(pt_value.type().getElementType())
    items = pt_value.toIValue()
    if element_type_name not in pt_to_ov_type_map:
        return None
    shape = PartialShape([len(items)]).get_shape()
    ov_type = pt_to_ov_type_map[element_type_name]
    return op.Constant(ov_type, shape, items).outputs()
|
440
|
+
|
441
|
+
def _get_device_string(self) -> str:
    """Return the device of the tensor feeding this ``prim::device`` node.

    Falls back to ``"cpu"`` when the input is not a tensor type or its
    device is not statically known.
    """
    assert self.graph_element.kind(
    ) == "prim::device", "This function can be called for prim::device node."
    input_type = self.raw_inputs[0].type()
    if not input_type.isSubtypeOf(torch.TensorType.get()):
        # Device cannot be statically determined.
        return "cpu"
    device = typing.cast(torch.TensorType, input_type).device()
    return str(device) if device else "cpu"
|
452
|
+
|
453
|
+
def input_is_none(self, index: int) -> bool:
    """Tell whether input ``index`` is absent or statically ``None``.

    The input counts as ``None`` in any of these cases:

    * ``index`` is past the last input;
    * the raw input is ``None``;
    * its type prints as ``NoneType``/``torch.NoneType``;
    * it is produced by a ``prim::GetAttr`` whose resolved value is ``None``.
    """
    if index >= len(self.inputs()):
        return True
    raw = self._raw_input(index)
    if raw is None:
        return True
    if str(raw.type()) in ("torch.NoneType", "NoneType"):
        return True
    producer = raw.node()
    if producer.kind() != "prim::GetAttr":
        return False
    attr_value, _ = get_value_from_getattr(producer, self.pt_module)
    return attr_value is None
|
467
|
+
|
468
|
+
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
    """Tell whether output ``out_index`` may alias input ``in_index``.

    The answer comes from ``self.alias_db.may_contain_alias``. Convolution,
    matmul and clone ops are always reported as non-aliasing, and so is any
    case where the alias query raises an ``Exception``.
    """
    if self.get_op_type() in ("aten::conv1d", "aten::conv2d", "aten::conv3d",
                              "aten::_convolution", "aten::matmul", "aten::clone"):
        # AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
        return False
    try:
        return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
    except Exception:
        # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node.
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        return False
|
477
|
+
|
478
|
+
def inlined_input(self, index):
    """Return the inlined outputs for input ``index``; always an empty list here."""
    return list()
|
480
|
+
|
481
|
+
def is_input_inlined(self, index):
    """Report whether input ``index`` is inlined; this decoder always says no."""
    inlined = False
    return inlined
|
483
|
+
|
484
|
+
def get_attribute(self, name):
    """Return ``OVAny(None)`` for every attribute ``name``."""
    empty = OVAny(None)
    return empty
|
486
|
+
|
487
|
+
def get_named_input(self, name):
    """Always fail: TorchScript graphs have no named inputs.

    Raises:
        RuntimeError: unconditionally.
    """
    message = "There is no named inputs in TS graph"
    raise RuntimeError(message)
|
489
|
+
|
490
|
+
@staticmethod
|
491
|
+
def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
|
492
|
+
# Function replaces prim::Constant containing List of Tensors with
|
493
|
+
# prim::ListConstruct containing prim::Constant Tensors.
|
494
|
+
assert isinstance(
|
495
|
+
graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
|
496
|
+
for node in graph.nodes():
|
497
|
+
if node.kind() != "prim::Constant":
|
498
|
+
continue
|
499
|
+
output_type = node.output().type()
|
500
|
+
allowed_types = [
|
501
|
+
output_type.isSubtypeOf(torch.ListType.ofTensors()),
|
502
|
+
output_type.isSubtypeOf(torch.ListType(
|
503
|
+
torch.OptionalType.ofTensor())),
|
504
|
+
]
|
505
|
+
if not any(allowed_types):
|
506
|
+
continue
|
507
|
+
const_inputs = []
|
508
|
+
for val in node.output().toIValue():
|
509
|
+
const_input = graph.insertConstant(val)
|
510
|
+
const_input.node().moveBefore(node)
|
511
|
+
const_input.node().copyMetadata(node)
|
512
|
+
const_inputs.append(const_input)
|
513
|
+
|
514
|
+
replacement = graph.create("prim::ListConstruct", const_inputs)
|
515
|
+
replacement.insertBefore(node)
|
516
|
+
replacement.output().setType(torch.ListType.ofTensors())
|
517
|
+
replacement.copyMetadata(node)
|
518
|
+
node.output().replaceAllUsesWith(replacement.output())
|
519
|
+
|
520
|
+
@staticmethod
|
521
|
+
def _transform_optional_constants(graph: torch.Graph):
|
522
|
+
# Function replaces prim::Constant containing torch.OptionalType with
|
523
|
+
# prim::Constant containing torch.NoneType or type of IValue.
|
524
|
+
assert isinstance(
|
525
|
+
graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
|
526
|
+
for node in graph.nodes():
|
527
|
+
if node.kind() != "prim::Constant":
|
528
|
+
continue
|
529
|
+
output_type = node.output().type()
|
530
|
+
if not isinstance(output_type, torch.OptionalType):
|
531
|
+
continue
|
532
|
+
value = node.output().toIValue()
|
533
|
+
const_input = graph.insertConstant(value)
|
534
|
+
const_input.node().moveBefore(node)
|
535
|
+
const_input.node().copyMetadata(node)
|
536
|
+
node.output().replaceAllUsesWith(const_input)
|