bigdl-core-npu 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.
Files changed (146)
  1. bigdl-core-npu/__init__.py +0 -0
  2. bigdl-core-npu/common.lib +0 -0
  3. bigdl-core-npu/ggml.dll +0 -0
  4. bigdl-core-npu/ggml.lib +0 -0
  5. bigdl-core-npu/include/llamacpp/arg.h +77 -0
  6. bigdl-core-npu/include/llamacpp/common.h +563 -0
  7. bigdl-core-npu/include/llamacpp/ggml-alloc.h +76 -0
  8. bigdl-core-npu/include/llamacpp/ggml-backend.h +241 -0
  9. bigdl-core-npu/include/llamacpp/ggml.h +2679 -0
  10. bigdl-core-npu/include/llamacpp/llama.h +1234 -0
  11. bigdl-core-npu/include/llamacpp/log.h +92 -0
  12. bigdl-core-npu/include/npu/npu_common.h +119 -0
  13. bigdl-core-npu/include/npu/npu_llm.h +77 -0
  14. bigdl-core-npu/llama-cli-npu.exe +0 -0
  15. bigdl-core-npu/llama.dll +0 -0
  16. bigdl-core-npu/llama.lib +0 -0
  17. bigdl-core-npu/llm-cli.exe +0 -0
  18. bigdl-core-npu/npu_llm.dll +0 -0
  19. bigdl-core-npu/npu_llm.lib +0 -0
  20. bigdl-core-npu/zlib1.dll +0 -0
  21. bigdl_core_npu-2.6.0.data/scripts/init-llama-cpp.bat +29 -0
  22. {bigdl_core_npu-2.5.0.dist-info → bigdl_core_npu-2.6.0.dist-info}/METADATA +12 -3
  23. {bigdl_core_npu-2.5.0.dist-info → bigdl_core_npu-2.6.0.dist-info}/RECORD +146 -96
  24. {bigdl_core_npu-2.5.0.dist-info → bigdl_core_npu-2.6.0.dist-info}/WHEEL +1 -1
  25. {bigdl_core_npu-2.5.0.dist-info → bigdl_core_npu-2.6.0.dist-info}/top_level.txt +1 -0
  26. intel_npu_acceleration_library/_version.py +1 -1
  27. intel_npu_acceleration_library/backend/base.py +39 -4
  28. intel_npu_acceleration_library/backend/bindings.py +109 -5
  29. intel_npu_acceleration_library/backend/factory.py +264 -47
  30. intel_npu_acceleration_library/backend/ops.py +2 -1
  31. intel_npu_acceleration_library/backend/qlinear.py +8 -4
  32. intel_npu_acceleration_library/backend/runtime.py +7 -2
  33. intel_npu_acceleration_library/backend/tensor.py +73 -3
  34. intel_npu_acceleration_library/bigdl-core-npu/cache.json +113732 -0
  35. intel_npu_acceleration_library/bigdl-core-npu/openvino.dll +0 -0
  36. intel_npu_acceleration_library/bigdl-core-npu/openvino_auto_batch_plugin.dll +0 -0
  37. intel_npu_acceleration_library/bigdl-core-npu/openvino_auto_plugin.dll +0 -0
  38. intel_npu_acceleration_library/bigdl-core-npu/openvino_c.dll +0 -0
  39. intel_npu_acceleration_library/bigdl-core-npu/openvino_hetero_plugin.dll +0 -0
  40. intel_npu_acceleration_library/bigdl-core-npu/openvino_intel_cpu_plugin.dll +0 -0
  41. intel_npu_acceleration_library/bigdl-core-npu/openvino_intel_gpu_plugin.dll +0 -0
  42. intel_npu_acceleration_library/bigdl-core-npu/openvino_intel_npu_plugin.dll +0 -0
  43. intel_npu_acceleration_library/bigdl-core-npu/openvino_ir_frontend.dll +0 -0
  44. intel_npu_acceleration_library/bigdl-core-npu/openvino_onnx_frontend.dll +0 -0
  45. intel_npu_acceleration_library/bigdl-core-npu/openvino_paddle_frontend.dll +0 -0
  46. intel_npu_acceleration_library/bigdl-core-npu/openvino_pytorch_frontend.dll +0 -0
  47. intel_npu_acceleration_library/bigdl-core-npu/openvino_tensorflow_frontend.dll +0 -0
  48. intel_npu_acceleration_library/bigdl-core-npu/openvino_tensorflow_lite_frontend.dll +0 -0
  49. intel_npu_acceleration_library/bigdl-core-npu/tbb12.dll +0 -0
  50. intel_npu_acceleration_library/bigdl-core-npu/tbb12_debug.dll +0 -0
  51. intel_npu_acceleration_library/bigdl-core-npu/tbbbind_2_5.dll +0 -0
  52. intel_npu_acceleration_library/bigdl-core-npu/tbbbind_2_5_debug.dll +0 -0
  53. intel_npu_acceleration_library/bigdl-core-npu/tbbmalloc.dll +0 -0
  54. intel_npu_acceleration_library/bigdl-core-npu/tbbmalloc_debug.dll +0 -0
  55. intel_npu_acceleration_library/bigdl-core-npu/tbbmalloc_proxy.dll +0 -0
  56. intel_npu_acceleration_library/bigdl-core-npu/tbbmalloc_proxy_debug.dll +0 -0
  57. intel_npu_acceleration_library/device.py +2 -2
  58. intel_npu_acceleration_library/dtypes.py +34 -1
  59. intel_npu_acceleration_library/external/openvino/__init__.py +1 -0
  60. intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py +1 -0
  61. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp310-win_amd64.pyd +0 -0
  62. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp311-win_amd64.pyd +0 -0
  63. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp312-win_amd64.pyd +0 -0
  64. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp38-win_amd64.pyd +0 -0
  65. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp39-win_amd64.pyd +0 -0
  66. intel_npu_acceleration_library/external/openvino/experimental/__init__.py +14 -0
  67. intel_npu_acceleration_library/external/openvino/frontend/jax/__init__.py +15 -0
  68. intel_npu_acceleration_library/external/openvino/frontend/jax/jaxpr_decoder.py +293 -0
  69. intel_npu_acceleration_library/external/openvino/frontend/jax/passes.py +65 -0
  70. intel_npu_acceleration_library/external/openvino/frontend/jax/utils.py +182 -0
  71. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd +0 -0
  72. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd +0 -0
  73. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd +0 -0
  74. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd +0 -0
  75. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd +0 -0
  76. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp310-win_amd64.pyd +0 -0
  77. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp311-win_amd64.pyd +0 -0
  78. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp312-win_amd64.pyd +0 -0
  79. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp38-win_amd64.pyd +0 -0
  80. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp39-win_amd64.pyd +0 -0
  81. intel_npu_acceleration_library/external/openvino/frontend/pytorch/fx_decoder.py +37 -19
  82. intel_npu_acceleration_library/external/openvino/frontend/pytorch/gptq.py +47 -6
  83. intel_npu_acceleration_library/external/openvino/frontend/pytorch/patch_model.py +28 -8
  84. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp310-win_amd64.pyd +0 -0
  85. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp311-win_amd64.pyd +0 -0
  86. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp312-win_amd64.pyd +0 -0
  87. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp38-win_amd64.pyd +0 -0
  88. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp39-win_amd64.pyd +0 -0
  89. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend.py +17 -5
  90. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/op_support.py +1 -0
  91. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/partition.py +55 -47
  92. intel_npu_acceleration_library/external/openvino/frontend/pytorch/ts_decoder.py +95 -63
  93. intel_npu_acceleration_library/external/openvino/frontend/pytorch/utils.py +12 -10
  94. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp310-win_amd64.pyd +0 -0
  95. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp311-win_amd64.pyd +0 -0
  96. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp312-win_amd64.pyd +0 -0
  97. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp38-win_amd64.pyd +0 -0
  98. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp39-win_amd64.pyd +0 -0
  99. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/utils.py +31 -10
  100. intel_npu_acceleration_library/external/openvino/helpers/packing.py +4 -4
  101. intel_npu_acceleration_library/external/openvino/preprocess/__init__.py +2 -0
  102. intel_npu_acceleration_library/external/openvino/preprocess/torchvision/requirements.txt +1 -0
  103. intel_npu_acceleration_library/external/openvino/properties/__init__.py +1 -0
  104. intel_npu_acceleration_library/external/openvino/runtime/ie_api.py +1 -1
  105. intel_npu_acceleration_library/external/openvino/runtime/op/__init__.py +1 -0
  106. intel_npu_acceleration_library/external/openvino/runtime/opset1/ops.py +2 -1
  107. intel_npu_acceleration_library/external/openvino/runtime/opset13/ops.py +5 -6
  108. intel_npu_acceleration_library/external/openvino/runtime/opset15/__init__.py +7 -0
  109. intel_npu_acceleration_library/external/openvino/runtime/opset15/ops.py +193 -2
  110. intel_npu_acceleration_library/external/openvino/runtime/opset6/ops.py +69 -43
  111. intel_npu_acceleration_library/external/openvino/runtime/opset8/ops.py +4 -0
  112. intel_npu_acceleration_library/external/openvino/runtime/properties/__init__.py +2 -0
  113. intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/data_dispatcher.py +21 -3
  114. intel_npu_acceleration_library/external/openvino/runtime/utils/decorators.py +88 -2
  115. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/inputs_filling.py +9 -9
  116. intel_npu_acceleration_library/external/openvino/tools/ovc/convert_impl.py +16 -2
  117. intel_npu_acceleration_library/external/openvino/tools/ovc/main.py +5 -0
  118. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/jax_frontend_utils.py +19 -0
  119. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pipeline.py +68 -16
  120. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +69 -60
  121. intel_npu_acceleration_library/external/openvino/tools/ovc/utils.py +90 -3
  122. intel_npu_acceleration_library/external/openvino/utils.py +17 -0
  123. intel_npu_acceleration_library/lib/Release/intel_npu_acceleration_library.dll +0 -0
  124. intel_npu_acceleration_library/lib/Release/openvino.dll +0 -0
  125. intel_npu_acceleration_library/lib/Release/openvino_auto_batch_plugin.dll +0 -0
  126. intel_npu_acceleration_library/lib/Release/openvino_auto_plugin.dll +0 -0
  127. intel_npu_acceleration_library/lib/Release/openvino_c.dll +0 -0
  128. intel_npu_acceleration_library/lib/Release/openvino_hetero_plugin.dll +0 -0
  129. intel_npu_acceleration_library/lib/Release/openvino_intel_cpu_plugin.dll +0 -0
  130. intel_npu_acceleration_library/lib/Release/openvino_intel_gpu_plugin.dll +0 -0
  131. intel_npu_acceleration_library/lib/Release/openvino_intel_npu_plugin.dll +0 -0
  132. intel_npu_acceleration_library/lib/Release/openvino_ir_frontend.dll +0 -0
  133. intel_npu_acceleration_library/lib/Release/openvino_onnx_frontend.dll +0 -0
  134. intel_npu_acceleration_library/lib/Release/openvino_paddle_frontend.dll +0 -0
  135. intel_npu_acceleration_library/lib/Release/openvino_pytorch_frontend.dll +0 -0
  136. intel_npu_acceleration_library/lib/Release/openvino_tensorflow_frontend.dll +0 -0
  137. intel_npu_acceleration_library/lib/Release/openvino_tensorflow_lite_frontend.dll +0 -0
  138. intel_npu_acceleration_library/lib/Release/tbb12.dll +0 -0
  139. intel_npu_acceleration_library/lib/Release/tbb12_debug.dll +0 -0
  140. intel_npu_acceleration_library/lib/Release/tbbbind_2_5.dll +0 -0
  141. intel_npu_acceleration_library/lib/Release/tbbbind_2_5_debug.dll +0 -0
  142. intel_npu_acceleration_library/lib/Release/tbbmalloc.dll +0 -0
  143. intel_npu_acceleration_library/lib/Release/tbbmalloc_debug.dll +0 -0
  144. intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy.dll +0 -0
  145. intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy_debug.dll +0 -0
  146. intel_npu_acceleration_library/nn/module.py +17 -17
intel_npu_acceleration_library/device.py
@@ -4,6 +4,7 @@
 #
 
 from intel_npu_acceleration_library.nn.module import convert_to_npu_module
+from intel_npu_acceleration_library.backend.tensor import RemoteTensor
 from torch.overrides import TorchFunctionMode
 from functools import lru_cache
 from typing import Any, MutableMapping
@@ -165,8 +166,7 @@ def to(super_fn: Any, self: Any, *args: Any, **kwargs: Any):
     """
     npu_device, args, kwargs = parse_to_arguments(*args, **kwargs)
     if npu_device:
-        # None for now, once the remote tensor feature lands, it can be converted to a remote tensor
-        pass
+        return super_fn(RemoteTensor.from_torch(self), *args, **kwargs)
     return super_fn(self, *args, **kwargs)
 
 
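Note: the device.py hunks above turn what was previously a silent no-op into a real conversion: when .to() targets the NPU, the tensor is first wrapped with RemoteTensor.from_torch before the original to() is delegated to. A minimal, hedged sketch of exercising that conversion directly (it assumes an NPU device and driver are available; the allocation semantics of from_torch are not shown in this diff):

    import torch

    from intel_npu_acceleration_library.backend.tensor import RemoteTensor

    # Wrap an ordinary torch tensor as an NPU-backed remote tensor, which is what the
    # overridden .to() in device.py now does whenever an NPU device is requested.
    x = torch.randn(1, 128, dtype=torch.float16)
    remote = RemoteTensor.from_torch(x)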
intel_npu_acceleration_library/dtypes.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from typing import Union
 import numpy as np
 import torch
-
+import ctypes
 
 @dataclass(frozen=True)
 class NPUDtype:
@@ -81,6 +81,39 @@ class NPUDtype:
         return self.name
 
 
+def get_backend_dtype(dtype) -> ctypes.c_char_p:
+    """Get the string representation of the dtype.
+    Args:
+        dtype: numpy dtype
+    Raises:
+        RuntimeError: Unsupported datatype
+    Returns:
+        ctypes.c_char_p: string representation of the dtype
+    """
+    if dtype in [np.int8, torch.int8]:
+        str_dtype = "int8"
+    elif dtype in [np.uint8, int4, torch.uint8]:
+        # u8 represents packed i4 dtypes
+        str_dtype = "int4"
+    elif dtype in [np.int16, torch.int16]:
+        str_dtype = "int16"
+    elif dtype in [np.int32, torch.int32]:
+        str_dtype = "int32"
+    elif dtype in [np.int64, torch.int64]:
+        str_dtype = "int64"
+    elif dtype in [np.float16, torch.float16]:
+        str_dtype = "float16"
+    elif dtype in [np.float32, torch.float32]:
+        str_dtype = "float32"
+    elif dtype in [np.float64, torch.float64]:
+        str_dtype = "float64"
+    elif dtype in [bfloat16, torch.bfloat16]:
+        str_dtype = "bfloat16"
+    else:
+        raise RuntimeError(f"DType is not supported {dtype}")
+    return ctypes.c_char_p(str_dtype.encode())
+
+
 float16 = NPUDtype(
     "fp16",
     16,
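Note: the dtypes.py hunk adds a module-level helper that maps numpy/torch dtypes (plus the library's int4 and bfloat16 aliases) to a ctypes.c_char_p carrying the backend's dtype name, with uint8 treated as packed int4. A hedged usage sketch; the import path follows from the file shown above:

    import numpy as np

    from intel_npu_acceleration_library.dtypes import get_backend_dtype

    print(get_backend_dtype(np.float16).value)  # b'float16'
    print(get_backend_dtype(np.uint8).value)    # b'int4' (u8 is the packed i4 representation)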
intel_npu_acceleration_library/external/openvino/__init__.py
@@ -21,6 +21,7 @@ except ImportError:
 from openvino import runtime as runtime
 from openvino import frontend as frontend
 from openvino import helpers as helpers
+from openvino import experimental as experimental
 from openvino import preprocess as preprocess
 from openvino import utils as utils
 from openvino import properties as properties
intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py
@@ -18,3 +18,4 @@ from openvino._pyopenvino._offline_transformations import compress_model_transfo
 from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
 from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
 from openvino._pyopenvino._offline_transformations import paged_attention_transformation
+from openvino._pyopenvino._offline_transformations import stateful_to_stateless_transformation
intel_npu_acceleration_library/external/openvino/experimental/__init__.py (new file)
@@ -0,0 +1,14 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Package: openvino
+This module provides access to experimental functionality that is subject to change without prior notice.
+"""
+
+# flake8: noqa
+
+from openvino._pyopenvino.experimental import evaluate_as_partial_shape
+from openvino._pyopenvino.experimental import evaluate_both_bounds
+from openvino._pyopenvino.experimental import set_element_type
+from openvino._pyopenvino.experimental import set_tensor_type
intel_npu_acceleration_library/external/openvino/frontend/jax/__init__.py (new file)
@@ -0,0 +1,15 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Package: openvino
+Low level wrappers for the FrontEnd C++ API.
+"""
+
+# flake8: noqa
+
+try:
+    from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
+except ImportError as err:
+    raise ImportError("OpenVINO JAX frontend is not available, please make sure the frontend is built."
+                      "{}".format(err))
intel_npu_acceleration_library/external/openvino/frontend/jax/jaxpr_decoder.py (new file)
@@ -0,0 +1,293 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# flake8: noqa
+# mypy: ignore-errors
+
+import jax.core
+from openvino.frontend.jax.py_jax_frontend import _FrontEndJaxDecoder as Decoder
+from openvino.runtime import PartialShape, Type as OVType, OVAny
+from openvino.frontend.jax.utils import jax_array_to_ov_const, get_ov_type_for_value, \
+    ivalue_to_constant, param_to_constants
+
+import jax
+import numpy as np
+
+from typing import List
+import logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+class JaxprPythonDecoder (Decoder):
+    '''
+    The jaxpr decoder uses Jaxpr to get graph information from a jax module.
+    It takes use of the following parts.
+
+    - `ClosedJaxpr`: the jaxpr object that contains the jaxpr and literals.
+        - `Jaxpr`: the jaxpr object that contains the invars, outvars, and eqns.
+            - `JaxEqns`: A list of jaxpr equations, which contains the information of the operation.
+                - `Primitive`: the operation that is used in the equation.
+                - `invars`: the input variables of the equation.
+                    - `aval`: the abstract value.
+                - `outvars`: the output variables of the equation.
+                    - `aval`: the abstract value.
+                - `params`: the named params of this equation.
+            - `invars`: the inputs of the model (traced graph).
+                - `aval`: the abstract value.
+            - `outvars`: the outputs of the model (traced graph).
+                - `aval`: the abstract value.
+            - `constvars`: the constant variables used in this model.
+                - `aval`: the abstract value.
+        - `Literal`: the literal object that contains the value of the constants.
+    '''
+
+    def __init__(self, jaxpr, name=None, literals=None):
+        '''
+        Inputs:
+        - jaxpr: for users, `ClosedJaxpr` is expected here. See https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L197
+        - name: the name for the model.
+        - literals: the literals (constants) that are used in the model.
+        '''
+        Decoder.__init__(self)
+
+        if isinstance(jaxpr, (jax.core.JaxprEqn, jax.core.Jaxpr)):
+            self.jaxpr = jaxpr
+        elif isinstance(jaxpr, jax.core.ClosedJaxpr):
+            # Take the `Jaxpr` from `ClosedJaxpr`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
+            self.jaxpr = jaxpr.jaxpr
+            # Literal should be a `Jax.core.Var`, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L85
+            self.literals = jaxpr.literals
+        else:
+            raise ValueError(f"Unexpected type of jaxpr: {type(jaxpr)}")
+        self.name = name
+        if self.name is None:
+            self.name = "jax_module"
+        if literals is not None:
+            self.literals = literals
+
+        self.params = {}
+        if hasattr(self.jaxpr, 'params') and isinstance(self.jaxpr.params, dict):
+            for k in self.jaxpr.params.keys():
+                converted = self.convert_param_to_constant_node(self.jaxpr, k)
+                if converted is not None:
+                    self.params.update(converted)
+
+        # TODO: this implementation may lead to memory increasing. Any better solution?
+        self.m_decoders = []
+
+    def inputs(self) -> List[int]:
+        if isinstance(self.jaxpr, jax.core.JaxprEqn):
+            idx = 0
+            res = []
+            for inp in self.jaxpr.invars:
+                if isinstance(inp, jax.core.Literal):
+                    res.append(self.literals[idx].output(0))
+                    idx += 1
+                else:
+                    res.append(id(inp))
+            return res
+        else:
+            return [id(v) for v in self.jaxpr.invars]
+
+    def input(self, idx: int) -> int:
+        return id(self.jaxpr.invars[idx])
+
+    def get_input_shape(self, index):
+        return PartialShape(self.jaxpr.invars[index].aval.shape)
+
+    def get_input_signature_name(self, index) -> str:
+        return "jaxpr_invar_" + str(index)
+
+    def get_input_type(self, index) -> OVType:
+        return get_ov_type_for_value(self.jaxpr.invars[index])
+
+    def get_named_param(self, name):
+        '''
+        Get the object id of the named parameter by the name.
+        '''
+        return self.params[name].output(0)
+
+    def get_named_param_as_constant(self, name):
+        '''
+        The named parameter in JAX is a python object but we want to use its value in cpp.
+        Therefore this API is used to get the named parameter as a constant, which can be used
+        to extract the value of it in cpp-level.
+        '''
+        return self.params[name].as_constant()
+
+    def get_param_names(self):
+        '''
+        In JAX, the named parameters may exist in `params` attribute of `JaxEqn`.
+        For example, the `jax.lax.cat` operation has a named parameter `dim`,
+        which is used to indicate the dimension to concatenate the tensors.
+
+        Here we return the names of all the named params that appear in the model for the current `JaxEqn`.
+        '''
+        return list(self.params.keys())
+
+    def get_output_type(self, index) -> OVType:
+        return get_ov_type_for_value(self.jaxpr.outvars[index])
+
+    def get_output_name(self, index) -> str:
+        return "jaxpr_outvar_" + str(index)
+
+    def get_output_shape(self, index):
+        return PartialShape(self.jaxpr.outvars[index].aval.shape)
+
+    def visit_subgraph(self, node_visitor) -> None:
+        if isinstance(self.jaxpr, jax.core.JaxprEqn):
+            return
+        for _, decoder in self.params.items():
+            self.m_decoders.append(decoder)
+            node_visitor(decoder)
+        for idx, node in enumerate(self.jaxpr.constvars):
+            decoder = self.convert_literal_to_constant_node(
+                literal=self.literals[idx],
+                name=self.name + "/" + f"const({id(node)})",
+                output_id=id(node)
+            )
+            self.m_decoders.append(decoder)
+            node_visitor(decoder)
+        # Visit every `JaxEqn` in the jaxpr, see https://github.com/google/jax/blob/jaxlib-v0.4.29/jax/_src/core.py#L285
+        for node in self.jaxpr.eqns:
+            literal_decoders = []
+            for inp in node.invars:
+                if isinstance(inp, jax.core.Literal):
+                    literal_decoder = self.convert_literal_to_constant_node(inp)
+                    literal_decoders.append(literal_decoder)
+                    node_visitor(literal_decoder)
+            decoder = JaxprPythonDecoder(node, name=self.name + "/" + node.primitive.name, literals=literal_decoders)
+            self.m_decoders.append(decoder)
+            node_visitor(decoder)
+
+    def get_op_type(self) -> str:
+        if isinstance(self.jaxpr, jax.core.JaxprEqn):
+            return self.jaxpr.primitive.name
+        else:
+            return "root"
+
+    def outputs(self) -> List[int]:
+        return [id(v) for v in self.jaxpr.outvars]
+
+    def output(self, idx: int) -> int:
+        return id(self.jaxpr.outvars[idx])
+
+    def num_inputs(self) -> int:
+        return len(self.jaxpr.invars)
+
+    def num_outputs(self) -> int:
+        return len(self.jaxpr.outvars)
+
+    def as_constant(self):
+        if self.get_op_type() == 'constant':
+            value = self.literals
+            # TODO: dig out how to share the memory.
+            # Currently, using shared_memory will raise `ValueError: array is not writeable`
+            ov_const = jax_array_to_ov_const(value, shared_memory=False)
+            return ov_const.outputs()
+        else:
+            raise ValueError("This is not a constant node so it cannot be converted to a constant.")
+
+    @staticmethod
+    def convert_param_to_constant_node(jaxpr, param) -> dict:
+        assert hasattr(jaxpr, 'params'), "The jaxpr does not have params."
+        if hasattr(jaxpr, 'primitive'):
+            param_map = param_to_constants(jaxpr.primitive.name, param, jaxpr, shared_memory=False)
+            res = {}
+            for name, constant in param_map.items():
+                if constant is not None:
+                    res[name] = _JaxprPythonConstantDecoder(constant=constant)
+        else:
+            constant = ivalue_to_constant(jaxpr.params[param], shared_memory=False)
+            res = {param: _JaxprPythonConstantDecoder(constant=constant)} if constant is not None else {}
+        return res
+
+    @staticmethod
+    def convert_literal_to_constant_node(literal, name=None, output_id=None):
+        if isinstance(literal, jax.core.Literal):
+            constant = ivalue_to_constant(literal.val, shared_memory=False)
+        elif isinstance(literal, (jax.Array, np.ndarray)):
+            constant = ivalue_to_constant(literal, shared_memory=False)
+        else:
+            raise TypeError(f"The input should be a literal or jax array, but got {type(literal)}.")
+        return _JaxprPythonConstantDecoder(constant=constant, name=name, output_id=output_id)
+
+class _JaxprPythonConstantDecoder (Decoder):
+    def __init__(self, name=None, constant=None, output_id=None):
+        '''
+        A decoder specially for constants and named parameters.
+
+        Inputs:
+        - name: the name for the model.
+        - literals: the literals (constants) that are used in the model.
+        - output_id: the id specified for this decoder's output. If none, use `id(self.constant)`.
+        '''
+        Decoder.__init__(self)
+
+        self.name = name
+        self.constant = constant
+        self.output_id = id(self.constant) if output_id is None else output_id
+
+    def inputs(self) -> List[int]:
+        return []
+
+    def input(self, idx: int) -> int:
+        raise ValueError("This is a constant node so it does not have input.")
+
+    def get_input_shape(self, index):
+        raise ValueError("This is a constant node so it does not have input shape.")
+
+    def get_input_signature_name(self, index) -> str:
+        raise ValueError("This is a constant node so it does not have input signature name.")
+
+    def get_input_type(self, index) -> OVType:
+        raise ValueError("This is a constant node so it does not have input type.")
+
+    def get_named_param(self, name):
+        raise ValueError("This is a constant node so it does not have named param.")
+
+    def get_named_param_as_constant(self, name):
+        raise ValueError("This is a constant node so it does not have named param.")
+
+    def get_param_names(self):
+        '''
+        In JAX, the named parameters may exist in `params` attribute of `JaxEqn`.
+        For example, the `jax.lax.cat` operation has a named parameter `dim`,
+        which is used to indicate the dimension to concatenate the tensors.
+
+        However, `_JaxprPythonConstantDecoder` is already a named param or a constant.
+        So it will never have a named param.
+        '''
+        return []
+
+    def get_output_type(self, index) -> OVType:
+        assert len(self.constant) == 1
+        return OVAny(self.constant[0].element_type)
+
+    def get_output_name(self, index) -> str:
+        return "jaxpr_outvar_" + str(index)
+
+    def get_output_shape(self, index):
+        assert len(self.constant) == 1
+        return PartialShape(self.constant[0].shape)
+
+    def visit_subgraph(self, node_visitor) -> None:
+        return
+
+    def get_op_type(self) -> str:
+        return "constant"
+
+    def outputs(self) -> List[int]:
+        return [self.output_id]
+
+    def output(self, idx: int) -> int:
+        return self.output_id
+
+    def num_inputs(self) -> int:
+        return 0
+
+    def num_outputs(self) -> int:
+        return 1
+
+    def as_constant(self):
+        return self.constant
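Note: jaxpr_decoder.py is the Python half of the new OpenVINO JAX frontend. JaxprPythonDecoder walks a ClosedJaxpr (literals, equations, constvars) and exposes it through the Decoder interface, while _JaxprPythonConstantDecoder represents constants and named params. A hedged sketch of feeding it a jaxpr (requires an OpenVINO build with the JAX frontend, since the Decoder base class comes from py_jax_frontend):

    import jax
    import jax.numpy as jnp

    from openvino.frontend.jax.jaxpr_decoder import JaxprPythonDecoder

    # jax.make_jaxpr returns the ClosedJaxpr that the decoder's __init__ documents as its expected input.
    closed_jaxpr = jax.make_jaxpr(lambda x: jnp.dot(x, x.T) + 1.0)(jnp.ones((2, 3)))
    decoder = JaxprPythonDecoder(closed_jaxpr, name="toy_model")
    print(decoder.get_op_type(), decoder.num_inputs(), decoder.num_outputs())  # root 1 1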
intel_npu_acceleration_library/external/openvino/frontend/jax/passes.py (new file)
@@ -0,0 +1,65 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# flake8: noqa
+# mypy: ignore-errors
+
+from enum import Enum
+from jax.lax import ConvDimensionNumbers
+
+def enum_values_pass(value):
+    if isinstance(value, Enum):
+        return value.value
+    return value
+
+
+def conv_dimension_numbers_pass(value):
+    if isinstance(value, ConvDimensionNumbers):
+        return [
+            list(value.lhs_spec),
+            list(value.rhs_spec),
+            list(value.out_spec)
+        ]
+    return value
+
+
+def filter_element(value):
+    passes = [enum_values_pass]
+    for pass_ in passes:
+        value = pass_(value)
+    return value
+
+
+def filter_ivalue(value):
+    passes = [conv_dimension_numbers_pass]
+    for pass_ in passes:
+        value = pass_(value)
+    return value
+
+
+def dot_general_param_pass(param_name: str, jax_eqn):
+    param = jax_eqn.params[param_name]
+    res = {}
+    if param_name == 'dimension_numbers':
+        contract_dimensions = param[0]
+        assert len(contract_dimensions) == 2
+        res['contract_dimensions'] = [list(contract_dimensions[0]), list(contract_dimensions[1])]
+
+        batch_dimensions = param[1]
+        assert len(batch_dimensions) == 2
+        lhs_length = len(batch_dimensions[0])
+        rhs_length = len(batch_dimensions[1])
+        assert lhs_length == rhs_length
+        if lhs_length > 0:
+            res['batch_dimensions'] = [list(batch_dimensions[0]), list(batch_dimensions[1])]
+    return res
+
+# mapping from primitive to pass
+param_passes = {
+    'dot_general': dot_general_param_pass,
+}
+
+def filter_param(primitive: str, param_name: str, jax_eqn):
+    if primitive in param_passes:
+        return param_passes[primitive](param_name, jax_eqn)
+    return {param_name: jax_eqn.params[param_name]}
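Note: passes.py normalises JAX equation parameters before they cross into the C++ frontend: enum values are unwrapped, ConvDimensionNumbers are flattened into nested lists, and dot_general's dimension_numbers are split into contract/batch dimensions. A hedged sketch grounded in the functions above (Precision is a hypothetical enum used only for illustration):

    from enum import Enum

    from jax.lax import ConvDimensionNumbers
    from openvino.frontend.jax.passes import filter_element, filter_ivalue

    class Precision(Enum):  # hypothetical enum, purely for illustration
        HIGHEST = 2

    dims = ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
    print(filter_element(Precision.HIGHEST))  # 2
    print(filter_ivalue(dims))                # [[0, 3, 1, 2], [3, 2, 0, 1], [0, 3, 1, 2]]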