bigdl-core-npu 2.6.0b20250114__cp311-cp311-win_amd64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. bigdl-core-npu/__init__.py +0 -0
  2. bigdl-core-npu/include/common.h +96 -0
  3. bigdl-core-npu/include/npu_llm.h +74 -0
  4. bigdl-core-npu/npu_llm.dll +0 -0
  5. bigdl-core-npu/npu_llm.lib +0 -0
  6. bigdl_core_npu-2.6.0b20250114.dist-info/METADATA +44 -0
  7. bigdl_core_npu-2.6.0b20250114.dist-info/RECORD +234 -0
  8. bigdl_core_npu-2.6.0b20250114.dist-info/WHEEL +5 -0
  9. bigdl_core_npu-2.6.0b20250114.dist-info/top_level.txt +2 -0
  10. intel_npu_acceleration_library/__init__.py +24 -0
  11. intel_npu_acceleration_library/_version.py +6 -0
  12. intel_npu_acceleration_library/backend/__init__.py +37 -0
  13. intel_npu_acceleration_library/backend/base.py +250 -0
  14. intel_npu_acceleration_library/backend/bindings.py +383 -0
  15. intel_npu_acceleration_library/backend/compression.py +24 -0
  16. intel_npu_acceleration_library/backend/convolution.py +58 -0
  17. intel_npu_acceleration_library/backend/factory.py +1161 -0
  18. intel_npu_acceleration_library/backend/linear.py +60 -0
  19. intel_npu_acceleration_library/backend/matmul.py +59 -0
  20. intel_npu_acceleration_library/backend/mlp.py +58 -0
  21. intel_npu_acceleration_library/backend/ops.py +142 -0
  22. intel_npu_acceleration_library/backend/qlinear.py +75 -0
  23. intel_npu_acceleration_library/backend/qmatmul.py +66 -0
  24. intel_npu_acceleration_library/backend/runtime.py +215 -0
  25. intel_npu_acceleration_library/backend/sdpa.py +107 -0
  26. intel_npu_acceleration_library/backend/tensor.py +1120 -0
  27. intel_npu_acceleration_library/backend/utils.py +70 -0
  28. intel_npu_acceleration_library/compiler.py +194 -0
  29. intel_npu_acceleration_library/device.py +230 -0
  30. intel_npu_acceleration_library/dtypes.py +155 -0
  31. intel_npu_acceleration_library/external/openvino/__init__.py +72 -0
  32. intel_npu_acceleration_library/external/openvino/_offline_transformations/__init__.py +21 -0
  33. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp310-win_amd64.pyd +0 -0
  34. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp311-win_amd64.pyd +0 -0
  35. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp312-win_amd64.pyd +0 -0
  36. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp38-win_amd64.pyd +0 -0
  37. intel_npu_acceleration_library/external/openvino/_pyopenvino.cp39-win_amd64.pyd +0 -0
  38. intel_npu_acceleration_library/external/openvino/experimental/__init__.py +14 -0
  39. intel_npu_acceleration_library/external/openvino/frontend/__init__.py +34 -0
  40. intel_npu_acceleration_library/external/openvino/frontend/frontend.py +44 -0
  41. intel_npu_acceleration_library/external/openvino/frontend/jax/__init__.py +15 -0
  42. intel_npu_acceleration_library/external/openvino/frontend/jax/jaxpr_decoder.py +293 -0
  43. intel_npu_acceleration_library/external/openvino/frontend/jax/passes.py +65 -0
  44. intel_npu_acceleration_library/external/openvino/frontend/jax/utils.py +182 -0
  45. intel_npu_acceleration_library/external/openvino/frontend/onnx/__init__.py +15 -0
  46. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp310-win_amd64.pyd +0 -0
  47. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp311-win_amd64.pyd +0 -0
  48. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp312-win_amd64.pyd +0 -0
  49. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp38-win_amd64.pyd +0 -0
  50. intel_npu_acceleration_library/external/openvino/frontend/onnx/py_onnx_frontend.cp39-win_amd64.pyd +0 -0
  51. intel_npu_acceleration_library/external/openvino/frontend/paddle/__init__.py +15 -0
  52. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp310-win_amd64.pyd +0 -0
  53. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp311-win_amd64.pyd +0 -0
  54. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp312-win_amd64.pyd +0 -0
  55. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp38-win_amd64.pyd +0 -0
  56. intel_npu_acceleration_library/external/openvino/frontend/paddle/py_paddle_frontend.cp39-win_amd64.pyd +0 -0
  57. intel_npu_acceleration_library/external/openvino/frontend/pytorch/__init__.py +19 -0
  58. intel_npu_acceleration_library/external/openvino/frontend/pytorch/fx_decoder.py +370 -0
  59. intel_npu_acceleration_library/external/openvino/frontend/pytorch/gptq.py +180 -0
  60. intel_npu_acceleration_library/external/openvino/frontend/pytorch/module_extension.py +39 -0
  61. intel_npu_acceleration_library/external/openvino/frontend/pytorch/patch_model.py +118 -0
  62. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp310-win_amd64.pyd +0 -0
  63. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp311-win_amd64.pyd +0 -0
  64. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp312-win_amd64.pyd +0 -0
  65. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp38-win_amd64.pyd +0 -0
  66. intel_npu_acceleration_library/external/openvino/frontend/pytorch/py_pytorch_frontend.cp39-win_amd64.pyd +0 -0
  67. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend.py +131 -0
  68. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/backend_utils.py +85 -0
  69. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/compile.py +141 -0
  70. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/decompositions.py +116 -0
  71. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/execute.py +189 -0
  72. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/op_support.py +290 -0
  73. intel_npu_acceleration_library/external/openvino/frontend/pytorch/torchdynamo/partition.py +126 -0
  74. intel_npu_acceleration_library/external/openvino/frontend/pytorch/ts_decoder.py +568 -0
  75. intel_npu_acceleration_library/external/openvino/frontend/pytorch/utils.py +258 -0
  76. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/__init__.py +16 -0
  77. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/graph_iterator.py +116 -0
  78. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/node_decoder.py +219 -0
  79. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp310-win_amd64.pyd +0 -0
  80. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp311-win_amd64.pyd +0 -0
  81. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp312-win_amd64.pyd +0 -0
  82. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp38-win_amd64.pyd +0 -0
  83. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/py_tensorflow_frontend.cp39-win_amd64.pyd +0 -0
  84. intel_npu_acceleration_library/external/openvino/frontend/tensorflow/utils.py +481 -0
  85. intel_npu_acceleration_library/external/openvino/helpers/__init__.py +6 -0
  86. intel_npu_acceleration_library/external/openvino/helpers/packing.py +87 -0
  87. intel_npu_acceleration_library/external/openvino/preprocess/README.md +60 -0
  88. intel_npu_acceleration_library/external/openvino/preprocess/__init__.py +28 -0
  89. intel_npu_acceleration_library/external/openvino/preprocess/torchvision/__init__.py +15 -0
  90. intel_npu_acceleration_library/external/openvino/preprocess/torchvision/preprocess_converter.py +47 -0
  91. intel_npu_acceleration_library/external/openvino/preprocess/torchvision/requirements.txt +5 -0
  92. intel_npu_acceleration_library/external/openvino/preprocess/torchvision/torchvision_preprocessing.py +347 -0
  93. intel_npu_acceleration_library/external/openvino/properties/__init__.py +22 -0
  94. intel_npu_acceleration_library/external/openvino/properties/_properties.py +55 -0
  95. intel_npu_acceleration_library/external/openvino/properties/device/__init__.py +14 -0
  96. intel_npu_acceleration_library/external/openvino/properties/hint/__init__.py +15 -0
  97. intel_npu_acceleration_library/external/openvino/properties/intel_auto/__init__.py +12 -0
  98. intel_npu_acceleration_library/external/openvino/properties/intel_cpu/__init__.py +8 -0
  99. intel_npu_acceleration_library/external/openvino/properties/intel_gpu/__init__.py +12 -0
  100. intel_npu_acceleration_library/external/openvino/properties/intel_gpu/hint/__init__.py +11 -0
  101. intel_npu_acceleration_library/external/openvino/properties/log/__init__.py +11 -0
  102. intel_npu_acceleration_library/external/openvino/properties/streams/__init__.py +11 -0
  103. intel_npu_acceleration_library/external/openvino/runtime/__init__.py +85 -0
  104. intel_npu_acceleration_library/external/openvino/runtime/exceptions.py +17 -0
  105. intel_npu_acceleration_library/external/openvino/runtime/ie_api.py +631 -0
  106. intel_npu_acceleration_library/external/openvino/runtime/op/__init__.py +19 -0
  107. intel_npu_acceleration_library/external/openvino/runtime/op/util/__init__.py +22 -0
  108. intel_npu_acceleration_library/external/openvino/runtime/opset1/__init__.py +112 -0
  109. intel_npu_acceleration_library/external/openvino/runtime/opset1/ops.py +3068 -0
  110. intel_npu_acceleration_library/external/openvino/runtime/opset10/__init__.py +179 -0
  111. intel_npu_acceleration_library/external/openvino/runtime/opset10/ops.py +173 -0
  112. intel_npu_acceleration_library/external/openvino/runtime/opset11/__init__.py +179 -0
  113. intel_npu_acceleration_library/external/openvino/runtime/opset11/ops.py +107 -0
  114. intel_npu_acceleration_library/external/openvino/runtime/opset12/__init__.py +180 -0
  115. intel_npu_acceleration_library/external/openvino/runtime/opset12/ops.py +120 -0
  116. intel_npu_acceleration_library/external/openvino/runtime/opset13/__init__.py +188 -0
  117. intel_npu_acceleration_library/external/openvino/runtime/opset13/ops.py +398 -0
  118. intel_npu_acceleration_library/external/openvino/runtime/opset14/__init__.py +190 -0
  119. intel_npu_acceleration_library/external/openvino/runtime/opset14/ops.py +171 -0
  120. intel_npu_acceleration_library/external/openvino/runtime/opset15/__init__.py +17 -0
  121. intel_npu_acceleration_library/external/openvino/runtime/opset15/ops.py +276 -0
  122. intel_npu_acceleration_library/external/openvino/runtime/opset2/__init__.py +118 -0
  123. intel_npu_acceleration_library/external/openvino/runtime/opset2/ops.py +216 -0
  124. intel_npu_acceleration_library/external/openvino/runtime/opset3/__init__.py +134 -0
  125. intel_npu_acceleration_library/external/openvino/runtime/opset3/ops.py +638 -0
  126. intel_npu_acceleration_library/external/openvino/runtime/opset4/__init__.py +145 -0
  127. intel_npu_acceleration_library/external/openvino/runtime/opset4/ops.py +464 -0
  128. intel_npu_acceleration_library/external/openvino/runtime/opset5/__init__.py +152 -0
  129. intel_npu_acceleration_library/external/openvino/runtime/opset5/ops.py +372 -0
  130. intel_npu_acceleration_library/external/openvino/runtime/opset6/__init__.py +154 -0
  131. intel_npu_acceleration_library/external/openvino/runtime/opset6/ops.py +215 -0
  132. intel_npu_acceleration_library/external/openvino/runtime/opset7/__init__.py +158 -0
  133. intel_npu_acceleration_library/external/openvino/runtime/opset7/ops.py +169 -0
  134. intel_npu_acceleration_library/external/openvino/runtime/opset8/__init__.py +169 -0
  135. intel_npu_acceleration_library/external/openvino/runtime/opset8/ops.py +787 -0
  136. intel_npu_acceleration_library/external/openvino/runtime/opset9/__init__.py +175 -0
  137. intel_npu_acceleration_library/external/openvino/runtime/opset9/ops.py +341 -0
  138. intel_npu_acceleration_library/external/openvino/runtime/opset_utils.py +22 -0
  139. intel_npu_acceleration_library/external/openvino/runtime/passes/__init__.py +19 -0
  140. intel_npu_acceleration_library/external/openvino/runtime/passes/graph_rewrite.py +33 -0
  141. intel_npu_acceleration_library/external/openvino/runtime/passes/manager.py +26 -0
  142. intel_npu_acceleration_library/external/openvino/runtime/properties/__init__.py +40 -0
  143. intel_npu_acceleration_library/external/openvino/runtime/properties/hint/__init__.py +25 -0
  144. intel_npu_acceleration_library/external/openvino/runtime/utils/__init__.py +7 -0
  145. intel_npu_acceleration_library/external/openvino/runtime/utils/broadcasting.py +44 -0
  146. intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/__init__.py +8 -0
  147. intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/data_dispatcher.py +447 -0
  148. intel_npu_acceleration_library/external/openvino/runtime/utils/data_helpers/wrappers.py +148 -0
  149. intel_npu_acceleration_library/external/openvino/runtime/utils/decorators.py +156 -0
  150. intel_npu_acceleration_library/external/openvino/runtime/utils/input_validation.py +133 -0
  151. intel_npu_acceleration_library/external/openvino/runtime/utils/node_factory.py +127 -0
  152. intel_npu_acceleration_library/external/openvino/runtime/utils/reduction.py +25 -0
  153. intel_npu_acceleration_library/external/openvino/runtime/utils/types.py +175 -0
  154. intel_npu_acceleration_library/external/openvino/tools/__init__.py +4 -0
  155. intel_npu_acceleration_library/external/openvino/tools/benchmark/__init__.py +3 -0
  156. intel_npu_acceleration_library/external/openvino/tools/benchmark/benchmark.py +186 -0
  157. intel_npu_acceleration_library/external/openvino/tools/benchmark/main.py +695 -0
  158. intel_npu_acceleration_library/external/openvino/tools/benchmark/parameters.py +199 -0
  159. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/__init__.py +3 -0
  160. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/constants.py +26 -0
  161. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/inputs_filling.py +482 -0
  162. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/logging.py +8 -0
  163. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/statistics_report.py +296 -0
  164. intel_npu_acceleration_library/external/openvino/tools/benchmark/utils/utils.py +836 -0
  165. intel_npu_acceleration_library/external/openvino/tools/ovc/__init__.py +20 -0
  166. intel_npu_acceleration_library/external/openvino/tools/ovc/__main__.py +10 -0
  167. intel_npu_acceleration_library/external/openvino/tools/ovc/cli_parser.py +633 -0
  168. intel_npu_acceleration_library/external/openvino/tools/ovc/convert.py +102 -0
  169. intel_npu_acceleration_library/external/openvino/tools/ovc/convert_data_type.py +82 -0
  170. intel_npu_acceleration_library/external/openvino/tools/ovc/convert_impl.py +550 -0
  171. intel_npu_acceleration_library/external/openvino/tools/ovc/environment_setup_utils.py +50 -0
  172. intel_npu_acceleration_library/external/openvino/tools/ovc/error.py +49 -0
  173. intel_npu_acceleration_library/external/openvino/tools/ovc/get_ov_update_message.py +16 -0
  174. intel_npu_acceleration_library/external/openvino/tools/ovc/help.py +45 -0
  175. intel_npu_acceleration_library/external/openvino/tools/ovc/logger.py +91 -0
  176. intel_npu_acceleration_library/external/openvino/tools/ovc/main.py +40 -0
  177. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/__init__.py +2 -0
  178. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/analysis.py +46 -0
  179. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/check_config.py +57 -0
  180. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/extractor.py +447 -0
  181. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/jax_frontend_utils.py +19 -0
  182. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/layout_utils.py +73 -0
  183. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/moc_emit_ir.py +32 -0
  184. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/offline_transformations.py +107 -0
  185. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/paddle_frontend_utils.py +83 -0
  186. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pipeline.py +298 -0
  187. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/preprocessing.py +220 -0
  188. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +214 -0
  189. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/shape_utils.py +109 -0
  190. intel_npu_acceleration_library/external/openvino/tools/ovc/moc_frontend/type_utils.py +82 -0
  191. intel_npu_acceleration_library/external/openvino/tools/ovc/ovc.py +13 -0
  192. intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_params.py +6 -0
  193. intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_stub.py +28 -0
  194. intel_npu_acceleration_library/external/openvino/tools/ovc/telemetry_utils.py +118 -0
  195. intel_npu_acceleration_library/external/openvino/tools/ovc/utils.py +196 -0
  196. intel_npu_acceleration_library/external/openvino/tools/ovc/version.py +80 -0
  197. intel_npu_acceleration_library/external/openvino/torch/__init__.py +5 -0
  198. intel_npu_acceleration_library/external/openvino/utils.py +115 -0
  199. intel_npu_acceleration_library/functional/__init__.py +8 -0
  200. intel_npu_acceleration_library/functional/scaled_dot_product_attention.py +47 -0
  201. intel_npu_acceleration_library/lib/Release/cache.json +113732 -0
  202. intel_npu_acceleration_library/lib/Release/intel_npu_acceleration_library.dll +0 -0
  203. intel_npu_acceleration_library/lib/Release/openvino.dll +0 -0
  204. intel_npu_acceleration_library/lib/Release/openvino_auto_batch_plugin.dll +0 -0
  205. intel_npu_acceleration_library/lib/Release/openvino_auto_plugin.dll +0 -0
  206. intel_npu_acceleration_library/lib/Release/openvino_c.dll +0 -0
  207. intel_npu_acceleration_library/lib/Release/openvino_hetero_plugin.dll +0 -0
  208. intel_npu_acceleration_library/lib/Release/openvino_intel_cpu_plugin.dll +0 -0
  209. intel_npu_acceleration_library/lib/Release/openvino_intel_gpu_plugin.dll +0 -0
  210. intel_npu_acceleration_library/lib/Release/openvino_intel_npu_plugin.dll +0 -0
  211. intel_npu_acceleration_library/lib/Release/openvino_ir_frontend.dll +0 -0
  212. intel_npu_acceleration_library/lib/Release/openvino_onnx_frontend.dll +0 -0
  213. intel_npu_acceleration_library/lib/Release/openvino_paddle_frontend.dll +0 -0
  214. intel_npu_acceleration_library/lib/Release/openvino_pytorch_frontend.dll +0 -0
  215. intel_npu_acceleration_library/lib/Release/openvino_tensorflow_frontend.dll +0 -0
  216. intel_npu_acceleration_library/lib/Release/openvino_tensorflow_lite_frontend.dll +0 -0
  217. intel_npu_acceleration_library/lib/Release/tbb12.dll +0 -0
  218. intel_npu_acceleration_library/lib/Release/tbb12_debug.dll +0 -0
  219. intel_npu_acceleration_library/lib/Release/tbbbind_2_5.dll +0 -0
  220. intel_npu_acceleration_library/lib/Release/tbbbind_2_5_debug.dll +0 -0
  221. intel_npu_acceleration_library/lib/Release/tbbmalloc.dll +0 -0
  222. intel_npu_acceleration_library/lib/Release/tbbmalloc_debug.dll +0 -0
  223. intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy.dll +0 -0
  224. intel_npu_acceleration_library/lib/Release/tbbmalloc_proxy_debug.dll +0 -0
  225. intel_npu_acceleration_library/modelling.py +150 -0
  226. intel_npu_acceleration_library/nn/__init__.py +20 -0
  227. intel_npu_acceleration_library/nn/autograd.py +68 -0
  228. intel_npu_acceleration_library/nn/conv.py +257 -0
  229. intel_npu_acceleration_library/nn/functional.py +1207 -0
  230. intel_npu_acceleration_library/nn/linear.py +162 -0
  231. intel_npu_acceleration_library/nn/llm.py +417 -0
  232. intel_npu_acceleration_library/nn/module.py +393 -0
  233. intel_npu_acceleration_library/optimizations.py +157 -0
  234. intel_npu_acceleration_library/quantization.py +174 -0
intel_npu_acceleration_library/backend/tensor.py
@@ -0,0 +1,1120 @@
1
+ #
2
+ # Copyright © 2024 Intel Corporation
3
+ # SPDX-License-Identifier: Apache 2.0
4
+ #
5
+
6
+ from intel_npu_acceleration_library.backend import lib as backend_lib
7
+ from typing import Sequence, Any, Optional, MutableMapping, Union
8
+ from intel_npu_acceleration_library.dtypes import (
9
+ float16,
10
+ bfloat16,
11
+ float32,
12
+ float64,
13
+ int4,
14
+ int8,
15
+ int16,
16
+ int32,
17
+ int64,
18
+ NPUDtype,
19
+ get_backend_dtype,
20
+ )
21
+ from dataclasses import dataclass
22
+ import functools
23
+ from math import prod
24
+ import numpy as np
25
+ import ctypes
26
+ import torch
27
+
28
+
29
+ class RemoteTensor(torch.Tensor):
30
+ """
31
+ Represent a remote tensor object.
32
+ Attrs:
33
+ _remote_tensor (ctypes._Pointer): The pointer to the underlying remote tensor.
34
+ Methods:
35
+ from_torch(x: torch.Tensor): Create a remote tensor from a torch tensor.
36
+ """
37
+
38
+ _remote_tensor = None
39
+
40
+ @staticmethod
41
+ def __new__(cls, x: Any, remote_tensor: ctypes._Pointer, *args: Any, **kwargs: Any):
42
+ """
43
+ Create a new remote tensor object.
44
+ Args:
45
+ x (Any): tensor input
46
+ remote_tensor (ctypes._Pointer): remote tensor pointer
47
+ args (Any): additional arguments
48
+ kwargs (Any): additional keyword arguments
49
+ Returns:
50
+ RemoteTensor: a RemoteTensor object
51
+ """
52
+ return super().__new__(cls, x, *args, **kwargs)
53
+
54
+ def __init__(self, x: Any, remote_tensor: ctypes._Pointer):
55
+ """
56
+ Initialize the remote tensor object.
57
+ Args:
58
+ x (Any): tensor input
59
+ remote_tensor (ctypes._Pointer): remote tensor pointer
60
+ """
61
+ self._remote_tensor = remote_tensor
62
+
63
+ # def __del__(self):
64
+ # if self._remote_tensor and backend_lib:
65
+ # backend_lib.del_remote_tensor(self._remote_tensor)
66
+
67
+ @staticmethod
68
+ def from_torch(x: torch.Tensor) -> "RemoteTensor":
69
+ """
70
+ Create a remote tensor from a torch tensor.
71
+ Args:
72
+ x (torch.Tensor): The torch tensor.
73
+ Returns:
74
+ RemoteTensor: The remote tensor.
75
+ """
76
+ shape_arr = np.array(x.shape, dtype=np.uint32)
77
+ dtype_str = get_backend_dtype(x.dtype)
78
+ p = ctypes.cast(x.data_ptr(), ctypes.c_void_p)
79
+
80
+ rt = backend_lib.to_npu(shape_arr.size, shape_arr, dtype_str, p)
81
+
82
+ pointer = ctypes.cast(
83
+ backend_lib.remote_tensor_data(rt),
84
+ ctypes.POINTER(ctypes.c_uint8),
85
+ )
86
+
87
+ arr = (pointer._type_ * prod(x.shape) * x.element_size()).from_address(
88
+ ctypes.addressof(pointer.contents)
89
+ )
90
+
91
+ pt_tensor = torch.frombuffer(arr, dtype=x.dtype).view(*x.shape)
92
+
93
+ return RemoteTensor(pt_tensor, rt)
94
+
95
+
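The from_torch path above hands a host tensor to the NPU runtime and then re-wraps the backend-owned allocation as an ordinary torch view: the raw pointer coming back from the backend is wrapped in a ctypes array of the right byte size and handed to torch.frombuffer, so no copy is made. Below is a minimal, self-contained sketch of that pointer-to-tensor mapping; it uses a plain host-side ctypes buffer as a stand-in for the backend_lib.to_npu / remote_tensor_data calls, which are NPU-specific.

import ctypes
import numpy as np
import torch

shape = (2, 3)
src = torch.arange(6, dtype=torch.float16).reshape(shape)

# Stand-in for the void* returned by backend_lib.remote_tensor_data(rt):
# here it is simply a host-side copy of the tensor's bytes.
raw = ctypes.create_string_buffer(src.numpy().tobytes())
pointer = ctypes.cast(raw, ctypes.POINTER(ctypes.c_uint8))

# Wrap the pointed-to memory in a ctypes array covering every byte...
num_bytes = int(np.prod(shape)) * src.element_size()
arr = (ctypes.c_uint8 * num_bytes).from_address(ctypes.addressof(pointer.contents))

# ...and view it as a torch tensor of the original dtype and shape, zero-copy.
view = torch.frombuffer(arr, dtype=src.dtype).view(*shape)
assert torch.equal(view, src)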
96
+ @dataclass
97
+ class Tensor:
98
+ """
99
+ Represents a tensor object.
100
+
101
+ Attrs:
102
+ factory (NNFactory): The factory object used to create the tensor.
103
+ node (ctypes._Pointer): The pointer to the underlying tensor node.
104
+ shape (Sequence[int]): The shape of the tensor.
105
+ dtype (NPUDtype): The data type of the tensor.
106
+ T (Tensor): The transpose of the tensor.
107
+
108
+ Methods:
109
+ __add__(self, other): Adds two tensors element-wise.
110
+ __sub__(self, other): Subtracts two tensors element-wise.
111
+ __mul__(self, other): Multiplies two tensors element-wise.
112
+ __truediv__(self, other): Divides two tensors element-wise.
113
+ __neg__(self): Negates the tensor.
114
+ __repr__(self): Returns a string representation of the tensor.
115
+ __str__(self): Returns a string representation of the tensor.
116
+ __len__(self): Returns the total number of elements in the tensor.
117
+ T: Returns the transpose of the tensor.
118
+ squeeze(self): Removes dimensions of size 1 from the tensor.
119
+ unsqueeze(self, axis): Adds a dimension of size 1 to the tensor.
120
+ __matmul__(self, other): Performs matrix multiplication between two tensors.
121
+ acos(self): Applies acos function to the tensor.
122
+ asin(self): Applies asin function to the tensor.
123
+ atan(self): Applies atan function to the tensor.
124
+ acosh(self): Applies acosh function to the tensor.
125
+ asinh(self): Applies asinh function to the tensor.
126
+ atanh(self): Applies atanh function to the tensor.
127
+ cosh(self): Applies cosh function to the tensor.
128
+ sinh(self): Applies sinh function to the tensor.
129
+ tanh(self): Applies tanh function to the tensor.
130
+ cos(self): Applies cos function to the tensor.
131
+ sin(self): Applies sin function to the tensor.
132
+ tan(self): Applies tan function to the tensor.
133
+ ceiling(self): Applies ceil function to the tensor.
134
+ clamp(self, min, max): Applies clamp function to the tensor.
135
+ elu(self, alpha): Applies elu function to the tensor.
136
+ erf(self): Applies erf function to the tensor.
137
+ exp(self): Applies exponential function to the tensor.
138
+ floor(self): Applies floor function to the tensor.
139
+ grn(self, bias): Applies grn function to the tensor.
140
+ hsigmoid(self): Applies hsigmoid function to the tensor.
141
+ hswish(self): Applies hswish function to the tensor.
142
+ log(self): Applies log function to the tensor.
143
+ mish(self): Applies mish function to the tensor.
144
+ relu(self): Applies relu function to the tensor.
145
+ round(self): Applies round function to the tensor.
146
+ sigmoid(self): Applies sigmoid function to the tensor.
147
+ sign(self): Applies sign function to the tensor.
148
+ softmax(self, dim): Applies softmax function to the tensor.
149
+ softplus(self): Applies softplus function to the tensor.
150
+ sqrt(self): Applies sqrt function to the tensor.
151
+ max(self, dim, keep_dims): Returns the reduced max tensor.
152
+ mean(self, dim, keep_dims, dtype): Returns the reduced mean tensor.
153
+ min(self, dim, keep_dims): Returns the reduced min tensor.
154
+ prod(self, dim, keep_dims, dtype): Returns the reduced product tensor.
155
+ sum(self, dim, keep_dims, dtype): Returns the reduced sum tensor.
156
+ """
157
+
158
+ factory: "NNFactory" # type: ignore # noqa: F821
159
+ node: ctypes._Pointer
160
+ output_idx: int
161
+
162
+ @property
163
+ def shape(self) -> Sequence[int]:
164
+ """
165
+ Returns the shape of the tensor.
166
+
167
+ Returns:
168
+ Sequence[int]: The shape of the tensor.
169
+ """
170
+ shape_size = backend_lib.op_shape_size(self.node, self.output_idx)
171
+ return [backend_lib.op_shape(self.node, i, self.output_idx) for i in range(shape_size)]
172
+
173
+ @property
174
+ def dtype(self) -> NPUDtype:
175
+ """
176
+ Returns the data type of the tensor.
177
+
178
+ Returns:
179
+ type: The data type of the tensor.
180
+ """
181
+ dtype_int = backend_lib.op_dtype(self.node, self.output_idx)
182
+
183
+ if dtype_int == 2:
184
+ return np.bool_
185
+ elif dtype_int == 3:
186
+ return bfloat16
187
+ elif dtype_int == 4:
188
+ return float16
189
+ elif dtype_int == 5:
190
+ return float32
191
+ elif dtype_int == 6:
192
+ return float64
193
+ elif dtype_int == 7:
194
+ return int4
195
+ elif dtype_int == 8:
196
+ return int8
197
+ elif dtype_int == 9:
198
+ return int16
199
+ elif dtype_int == 10:
200
+ return int32
201
+ elif dtype_int == 11:
202
+ return int64
203
+ else:
204
+ raise RuntimeError("Unsupported dtype")
205
+
206
+ def dim(self) -> int:
207
+ """
208
+ Return the number of dimensions of the tensor.
209
+
210
+ Returns:
211
+ int: The number of dimensions of the tensor.
212
+ """
213
+ return len(self.shape)
214
+
215
+ def size(self, dim=None) -> Union[int, Sequence[int]]:
216
+ """
217
+ Return the size of the tensor.
218
+
219
+ Args:
220
+ dim (int, optional): The dimension to return the size of. Defaults to None.
221
+
222
+ Returns:
223
+ Union[int, Sequence[int]]: The size of the tensor.
224
+ """
225
+ if dim is None:
226
+ return torch.Size(self.shape)
227
+ return self.shape[dim]
228
+
229
+ def __add__(self, other) -> "Tensor":
230
+ """
231
+ Add two tensors element-wise.
232
+
233
+ Args:
234
+ other (Tensor): The tensor to be added.
235
+
236
+ Returns:
237
+ Tensor: The result of the addition.
238
+ """
239
+ if isinstance(other, (int, float)):
240
+ other = self.factory.constant(
241
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
242
+ )
243
+ return generate_op([self, other], "eltwise_add")
244
+
245
+ def __sub__(self, other) -> "Tensor":
246
+ """
247
+ Subtract two tensors element-wise.
248
+
249
+ Args:
250
+ other (Tensor): The tensor to be subtracted.
251
+
252
+ Returns:
253
+ Tensor: The result of the subtraction.
254
+ """
255
+ if isinstance(other, (int, float)):
256
+ other = self.factory.constant(
257
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
258
+ )
259
+ return generate_op([self, -other], "eltwise_add")
260
+
261
+ def __mul__(self, other) -> "Tensor":
262
+ """
263
+ Multiply two tensors element-wise.
264
+
265
+ Args:
266
+ other (Tensor): The tensor to be multiplied.
267
+
268
+ Returns:
269
+ Tensor: The result of the multiplication.
270
+ """
271
+ if isinstance(other, (int, float)):
272
+ other = self.factory.constant(
273
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
274
+ )
275
+ return generate_op([self, other], "eltwise_mul")
276
+
277
+ def __truediv__(self, other) -> "Tensor":
278
+ """
279
+ Divide two tensors element-wise.
280
+
281
+ Args:
282
+ other (Tensor): The tensor to divide by.
283
+
284
+ Returns:
285
+ Tensor: The result of the division.
286
+ """
287
+ if isinstance(other, (int, float)):
288
+ other = self.factory.constant(
289
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
290
+ )
291
+ return generate_op([self, other], "eltwise_div")
292
+
293
+ def __radd__(self, other) -> "Tensor":
294
+ """
295
+ Add two tensors element-wise.
296
+
297
+ Args:
298
+ other (Tensor): The tensor to be added.
299
+
300
+ Returns:
301
+ Tensor: The result of the addition.
302
+ """
303
+ if isinstance(other, (int, float)):
304
+ other = self.factory.constant(
305
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
306
+ )
307
+ return generate_op([other, self], "eltwise_add")
308
+
309
+ def __rsub__(self, other) -> "Tensor":
310
+ """
311
+ Subtract two tensors element-wise.
312
+
313
+ Args:
314
+ other (Tensor): The tensor to be subtracted.
315
+
316
+ Returns:
317
+ Tensor: The result of the subtraction.
318
+ """
319
+ if isinstance(other, (int, float)):
320
+ other = self.factory.constant(
321
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
322
+ )
323
+ return generate_op([other, -self], "eltwise_add")
324
+
325
+ def __rmul__(self, other) -> "Tensor":
326
+ """
327
+ Multiply two tensors element-wise.
328
+
329
+ Args:
330
+ other (Tensor): The tensor to be multiplied.
331
+
332
+ Returns:
333
+ Tensor: The result of the multiplication.
334
+ """
335
+ if isinstance(other, (int, float)):
336
+ other = self.factory.constant(
337
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
338
+ )
339
+ return generate_op([other, self], "eltwise_mul")
340
+
341
+ def __rtruediv__(self, other) -> "Tensor":
342
+ """
343
+ Divide two tensors element-wise.
344
+
345
+ Args:
346
+ other (Tensor): The tensor to be divided.
347
+
348
+ Returns:
349
+ Tensor: The result of the division.
350
+ """
351
+ if isinstance(other, (int, float)):
352
+ other = self.factory.constant(
353
+ torch.tensor([other], dtype=self.dtype.torch_dtype)
354
+ )
355
+ return generate_op([other, self], "eltwise_div")
356
+
357
+ def __neg__(self) -> "Tensor":
358
+ """
359
+ Negate the tensor.
360
+
361
+ Returns:
362
+ Tensor: The negated tensor.
363
+ """
364
+ return generate_op([self], "negative")
365
+
366
+ def __repr__(self) -> str:
367
+ """
368
+ Return a string representation of the tensor.
369
+
370
+ Returns:
371
+ str: The string representation of the tensor.
372
+ """
373
+ return f"Tensor({self.shape}, {self.dtype})"
374
+
375
+ def __str__(self) -> str:
376
+ """
377
+ Return a string representation of the tensor.
378
+
379
+ Returns:
380
+ str: The string representation of the tensor.
381
+ """
382
+ return f"Tensor({self.shape}, {self.dtype})"
383
+
384
+ def __len__(self) -> int:
385
+ """
386
+ Return the total number of elements in the tensor.
387
+
388
+ Returns:
389
+ int: The total number of elements in the tensor.
390
+ """
391
+ return np.prod(self.shape)
392
+
393
+ def __getitem__(self, key) -> "Tensor":
394
+ """
395
+ Return a slice of the tensor.
396
+
397
+ Args:
398
+ key: The slice key.
399
+
400
+ Raises:
401
+ ValueError: If the slice key is invalid.
402
+
403
+ Returns:
404
+ Tensor: The sliced tensor.
405
+ """
406
+ shape_len = len(self.shape)
407
+
408
+ begin, end, stride = [], [], []
409
+ if isinstance(key, slice):
410
+ key = (key,)
411
+ if not isinstance(key, tuple):
412
+ raise ValueError(
413
+ f"Invalid slice key: must be a tuple instead of {type(key)}"
414
+ )
415
+
416
+ if any(k is Ellipsis for k in key):
417
+ # if ellipsis is at the start
418
+ if key[0] is Ellipsis:
419
+ key = tuple([slice(None)] * (shape_len - len(key) + 1)) + key[1:]
420
+ # if ellipsis is at the end
421
+ if key[-1] is Ellipsis:
422
+ key = key[:-1] + tuple([slice(None)] * (shape_len - len(key) + 1))
423
+ # if ellipsis is in the middle
424
+ if any(k is Ellipsis for k in key):
425
+ raise ValueError("Ellipsis must be at the start or end of the slice")
426
+
427
+ if len(key) != shape_len or len(key) < 1:
428
+ raise ValueError(f"Invalid slice key: {key}")
429
+
430
+ def get_index(idx: int, shape: int) -> int:
431
+ """
432
+ Get the index of the slice.
433
+
434
+ Args:
435
+ idx (int): The index of the slice.
436
+ shape (int): The shape of the tensor.
437
+
438
+ Raises:
439
+ IndexError: If the index is out of bounds.
440
+
441
+ Returns:
442
+ int: The index of the slice.
443
+ """
444
+ if idx < 0:
445
+ idx += shape
446
+ if idx < 0 or idx > shape:
447
+ raise IndexError(f"Index {idx} out of bounds for shape {shape}")
448
+ return idx
449
+
450
+ for i, k in enumerate(key):
451
+ if isinstance(k, slice):
452
+ begin.append(get_index(k.start or 0, self.shape[i]))
453
+ end.append(get_index(k.stop or self.shape[i], self.shape[i]))
454
+ stride.append(k.step or 1)
455
+ elif k is None:
456
+ begin.append(0)
457
+ end.append(self.shape[i])
458
+ stride.append(1)
459
+ else:
460
+ begin.append(k)
461
+ end.append(k + 1)
462
+ stride.append(1)
463
+
464
+ if any(s <= 0 for s in stride):
465
+ raise ValueError("Stride must be positive")
466
+
467
+ return generate_op([self], "slice", begin, end, stride)
468
+
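For a concrete picture of the normalization above (illustrative only, not part of the package): on a tensor of shape (4, 8, 16), the key used in x[..., 2] is expanded and lowered to a single strided "slice" op with the begin/end/stride triples computed below.

# Mirrors the begin/end/stride computation of __getitem__ for x[..., 2]
# on a tensor of shape (4, 8, 16).
shape = (4, 8, 16)
key = (Ellipsis, 2)

# An ellipsis at the start expands into full slices for the leading dims.
key = tuple([slice(None)] * (len(shape) - len(key) + 1)) + key[1:]

begin, end, stride = [], [], []
for i, k in enumerate(key):
    if isinstance(k, slice):
        begin.append(k.start or 0)
        end.append(k.stop or shape[i])
        stride.append(k.step or 1)
    else:  # an integer picks a single index along that dimension
        begin.append(k)
        end.append(k + 1)
        stride.append(1)

print(begin, end, stride)  # [0, 0, 2] [4, 8, 3] [1, 1, 1]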
469
+ @property
470
+ def T(self) -> "Tensor":
471
+ """
472
+ Return the transpose of the tensor.
473
+
474
+ Returns:
475
+ Tensor: The transposed tensor.
476
+ """
477
+ input_order = list(range(len(self.shape)))
478
+ input_order[-1], input_order[-2] = input_order[-2], input_order[-1]
479
+ return generate_op([self], "transpose", input_order)
480
+
481
+ def transpose(self, dim0: int, dim1: int) -> "Tensor":
482
+ """
483
+ Return the transpose of the tensor.
484
+
485
+ Args:
486
+ dim0 (int): The first dimension to transpose.
487
+ dim1 (int): The second dimension to transpose.
488
+
489
+ Returns:
490
+ Tensor: The transposed tensor.
491
+ """
492
+ input_order = list(range(len(self.shape)))
493
+ input_order[dim0], input_order[dim1] = input_order[dim1], input_order[dim0]
494
+
495
+ return generate_op([self], "transpose", input_order)
496
+
497
+ def permute(self, *input_order: int) -> "Tensor":
498
+ """
499
+ Return a tensor with its dimensions permuted.
500
+
501
+ Args:
502
+ input_order (Sequence[int]): The order of the dimensions in the permuted tensor.
503
+
504
+ Returns:
505
+ Tensor: The permuted tensor.
506
+ """
507
+ return generate_op([self], "transpose", input_order)
508
+
509
+ def reshape(self, *shape: Union[int, Sequence[int]]) -> "Tensor":
510
+ """
511
+ Return a tensor with the specified shape.
512
+
513
+ Args:
514
+ shape (Union[int, Sequence[int]]): The new shape of the tensor.
515
+
516
+ Returns:
517
+ Tensor: The reshaped tensor.
518
+ """
519
+ if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
520
+ shape = shape[0] # type: ignore
521
+ return generate_op([self], "reshape", shape)
522
+
523
+ def view(self, *shape: Union[Sequence[int], int]) -> "Tensor":
524
+ """
525
+ Return a view of the tensor with the specified shape.
526
+
527
+ Args:
528
+ shape (Union[Sequence[int], int]): The new shape of the tensor.
529
+
530
+ Returns:
531
+ Tensor: The reshaped tensor.
532
+ """
533
+ if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
534
+ shape = shape[0] # type: ignore
535
+
536
+ return self.reshape(*shape)
537
+
538
+ def flatten(self, start_dim=0, end_dim=-1) -> "Tensor":
539
+ """
540
+ Flatten the tensor.
541
+
542
+ Args:
543
+ start_dim (int): The first dim to flatten. Defaults to 0.
544
+ end_dim (int): The last dim to flatten. Defaults to -1.
545
+
546
+ Returns:
547
+ Tensor: The flattened tensor.
548
+ """
549
+ if end_dim < 0:
550
+ end_dim = len(self.shape) + end_dim + 1
551
+
552
+ flattened_dim = self.shape[start_dim:end_dim]
553
+ size = int(np.prod(flattened_dim))
554
+ new_shape = list(self.shape[:start_dim]) + [size] + list(self.shape[end_dim:])
555
+
556
+ return self.reshape(*new_shape)
557
+
558
+ def squeeze(self) -> "Tensor":
559
+ """
560
+ Remove dimensions of size 1 from the tensor.
561
+
562
+ Returns:
563
+ Tensor: The squeezed tensor.
564
+ """
565
+ return generate_op([self], "squeeze")
566
+
567
+ def unsqueeze(self, axis) -> "Tensor":
568
+ """
569
+ Add a dimension of size 1 to the tensor.
570
+
571
+ Args:
572
+ axis (int): The axis along which to add the dimension.
573
+
574
+ Returns:
575
+ Tensor: The unsqueezed tensor.
576
+ """
577
+ return generate_op([self], "unsqueeze", axis)
578
+
579
+ def __matmul__(self, other) -> "Tensor":
580
+ """
581
+ Perform matrix multiplication between two tensors.
582
+
583
+ Args:
584
+ other (Tensor): The tensor to be multiplied.
585
+
586
+ Returns:
587
+ Tensor: The result of the matrix multiplication.
588
+ """
589
+ return generate_op([self, other], "matmul", False, False)
590
+
591
+ def acos(self) -> "Tensor":
592
+ """
593
+ Apply the acos function to the tensor.
594
+
595
+ Returns:
596
+ Tensor: The result of applying the acos function.
597
+ """
598
+ return torch.acos(self)
599
+
600
+ def asin(self) -> "Tensor":
601
+ """
602
+ Apply the asin function to the tensor.
603
+
604
+ Returns:
605
+ Tensor: The result of applying the asin function.
606
+ """
607
+ return torch.asin(self)
608
+
609
+ def atan(self) -> "Tensor":
610
+ """
611
+ Apply the atan function to the tensor.
612
+
613
+ Returns:
614
+ Tensor: The result of applying the atan function.
615
+ """
616
+ return torch.atan(self)
617
+
618
+ def acosh(self) -> "Tensor":
619
+ """
620
+ Apply the acosh function to the tensor.
621
+
622
+ Returns:
623
+ Tensor: The result of applying the acosh function.
624
+ """
625
+ return torch.acosh(self)
626
+
627
+ def asinh(self) -> "Tensor":
628
+ """
629
+ Apply the asinh function to the tensor.
630
+
631
+ Returns:
632
+ Tensor: The result of applying the asinh function.
633
+ """
634
+ return torch.asinh(self)
635
+
636
+ def atanh(self) -> "Tensor":
637
+ """
638
+ Apply the atanh function to the tensor.
639
+
640
+ Returns:
641
+ Tensor: The result of applying the atanh function.
642
+ """
643
+ return torch.atanh(self)
644
+
645
+ def cosh(self) -> "Tensor":
646
+ """
647
+ Apply the cosh function to the tensor.
648
+
649
+ Returns:
650
+ Tensor: The result of applying the cosh function.
651
+ """
652
+ return torch.cosh(self)
653
+
654
+ def sinh(self) -> "Tensor":
655
+ """
656
+ Apply the sinh function to the tensor.
657
+
658
+ Returns:
659
+ Tensor: The result of applying the sinh function.
660
+ """
661
+ return torch.sinh(self)
662
+
663
+ def tanh(self) -> "Tensor":
664
+ """
665
+ Apply the tanh function to the tensor.
666
+
667
+ Returns:
668
+ Tensor: The result of applying the tanh function.
669
+ """
670
+ return torch.tanh(self)
671
+
672
+ def cos(self) -> "Tensor":
673
+ """
674
+ Apply the cos function to the tensor.
675
+
676
+ Returns:
677
+ Tensor: The result of applying the cos function.
678
+ """
679
+ return torch.cos(self)
680
+
681
+ def sin(self) -> "Tensor":
682
+ """
683
+ Apply the sin function to the tensor.
684
+
685
+ Returns:
686
+ Tensor: The result of applying the sin function.
687
+ """
688
+ return torch.sin(self)
689
+
690
+ def tan(self) -> "Tensor":
691
+ """
692
+ Apply the tan function to the tensor.
693
+
694
+ Returns:
695
+ Tensor: The result of applying the tan function.
696
+ """
697
+ return torch.tan(self)
698
+
699
+ def ceiling(self) -> "Tensor":
700
+ """
701
+ Apply the ceiling function to the tensor.
702
+
703
+ Returns:
704
+ Tensor: The result of applying the ceiling function.
705
+ """
706
+ return generate_op([self], "ceiling")
707
+
708
+ def clamp(self, min=None, max=None) -> "Tensor":
709
+ """
710
+ Apply the clamp function to the tensor.
711
+
712
+ Args:
713
+ min (int, float): The lower-bound of the range to be clamped
714
+ max (int, float): The upper-bound of the range to be clamped
715
+
716
+ Returns:
717
+ Tensor: The result of applying the clamp function.
718
+ """
719
+ return torch.clamp(self, min=min, max=max)
720
+
721
+ def elu(self, alpha: float = 1.0) -> "Tensor":
722
+ """
723
+ Apply the elu function to the tensor.
724
+
725
+ Args:
726
+ alpha (float): The alpha value. Defaults to 1.0.
727
+
728
+ Returns:
729
+ Tensor: The result of applying the elu function.
730
+ """
731
+ return generate_op([self], "elu", alpha)
732
+
733
+ def erf(self) -> "Tensor":
734
+ """
735
+ Apply the erf function to the tensor.
736
+
737
+ Returns:
738
+ Tensor: The result of applying the erf function.
739
+ """
740
+ return torch.erf(self)
741
+
742
+ def exp(self) -> "Tensor":
743
+ """
744
+ Apply the exp function to the tensor.
745
+
746
+ Returns:
747
+ Tensor: The result of applying the exp function.
748
+ """
749
+ return torch.exp(self)
750
+
751
+ def floor(self) -> "Tensor":
752
+ """
753
+ Apply the floor function to the tensor.
754
+
755
+ Returns:
756
+ Tensor: The result of applying the floor function.
757
+ """
758
+ return torch.floor(self)
759
+
760
+ def grn(self, bias: float = 1e-12) -> "Tensor":
761
+ """
762
+ Apply the grn function to the tensor.
763
+
764
+ Args:
765
+ bias (float): The bias value. Defaults to 1e-12.
766
+
767
+ Returns:
768
+ Tensor: The result of applying the grn function.
769
+ """
770
+ return generate_op([self], "grn", bias)
771
+
772
+ def hsigmoid(self) -> "Tensor":
773
+ """
774
+ Apply the hsigmoid function to the tensor.
775
+
776
+ Returns:
777
+ Tensor: The result of applying the hsigmoid function.
778
+ """
779
+ return generate_op([self], "hsigmoid")
780
+
781
+ def hswish(self) -> "Tensor":
782
+ """
783
+ Apply the hswish function to the tensor.
784
+
785
+ Returns:
786
+ Tensor: The result of applying the hswish function.
787
+ """
788
+ return generate_op([self], "hswish")
789
+
790
+ def log(self) -> "Tensor":
791
+ """
792
+ Apply the log function to the tensor.
793
+
794
+ Returns:
795
+ Tensor: The result of applying the log function.
796
+ """
797
+ return torch.log(self)
798
+
799
+ def mish(self) -> "Tensor":
800
+ """
801
+ Apply the mish function to the tensor.
802
+
803
+ Returns:
804
+ Tensor: The result of applying the mish function.
805
+ """
806
+ return generate_op([self], "mish")
807
+
808
+ def relu(self) -> "Tensor":
809
+ """
810
+ Apply the relu function to the tensor.
811
+
812
+ Returns:
813
+ Tensor: The result of applying the relu function.
814
+ """
815
+ return generate_op([self], "relu")
816
+
817
+ def round(self) -> "Tensor":
818
+ """
819
+ Apply the round function to the tensor.
820
+
821
+ Returns:
822
+ Tensor: The result of applying the round function.
823
+ """
824
+ return torch.round(self)
825
+
826
+ def sigmoid(self) -> "Tensor":
827
+ """
828
+ Apply the sigmoid function to the tensor.
829
+
830
+ Returns:
831
+ Tensor: The result of applying the sigmoid function.
832
+ """
833
+ return generate_op([self], "sigmoid")
834
+
835
+ def sign(self) -> "Tensor":
836
+ """
837
+ Apply the sign function to the tensor.
838
+
839
+ Returns:
840
+ Tensor: The result of applying the sign function.
841
+ """
842
+ return torch.sign(self)
843
+
844
+ def softmax(self, dim) -> "Tensor":
845
+ """
846
+ Apply the softmax function to the tensor.
847
+
848
+ Args:
849
+ dim (int): The dimension to apply softmax.
850
+
851
+ Returns:
852
+ Tensor: The result of applying the softmax function.
853
+ """
854
+ return torch.nn.functional.softmax(self, dim=dim)
855
+
856
+ def softplus(self) -> "Tensor":
857
+ """
858
+ Apply the softplus function to the tensor.
859
+
860
+ Returns:
861
+ Tensor: The result of applying the softplus function.
862
+ """
863
+ return generate_op([self], "softplus")
864
+
865
+ def sqrt(self) -> "Tensor":
866
+ """
867
+ Apply the sqrt function to the tensor.
868
+
869
+ Returns:
870
+ Tensor: The result of applying the sqrt function.
871
+ """
872
+ return torch.sqrt(self)
873
+
874
+ def max(
875
+ self, dim: Optional[int] = None, keep_dims: Optional[bool] = False
876
+ ) -> "Tensor":
877
+ """
878
+ Return the reduced max tensor.
879
+
880
+ Args:
881
+ dim (Optional[int], optional): The dim to reduce. Default is None, and all dimensions are reduced.
882
+ keep_dims (Optional[bool], optional): If True, the reduced dimensions are retained with length 1. Defaults to False.
883
+
884
+ Returns:
885
+ Tensor: The result of max reducing operation.
886
+ """
887
+ return generate_op(self, "reduce_max", reduction_axes=dim, keep_dims=keep_dims)
888
+
889
+ def mean(
890
+ self,
891
+ dim: Optional[Union[int, Sequence[int]]] = None,
892
+ keep_dims: Optional[bool] = False,
893
+ dtype: Optional[torch.dtype] = None,
894
+ ) -> "Tensor":
895
+ """
896
+ Return the reduced mean tensor.
897
+
898
+ Args:
899
+ dim (Optional[Union[int, Sequence[int]]], optional): The dim(s) to reduce. Default is None, and all dimensions are reduced.
900
+ keep_dims (Optional[bool], optional): If True, the reduced dimensions are retained with length 1. Defaults to False.
901
+ dtype (Optional[torch.dtype], optional): The data type. Defaults to None.
902
+
903
+ Returns:
904
+ Tensor: The result of mean reducing operation.
905
+ """
906
+ mean = generate_op(self, "reduce_mean", reduction_axes=dim, keep_dims=keep_dims)
907
+ if dtype:
908
+ mean = mean.to(dtype)
909
+ return mean
910
+
911
+ def min(
912
+ self,
913
+ dim: Optional[int] = None,
914
+ keep_dims: Optional[bool] = False,
915
+ ) -> "Tensor":
916
+ """
917
+ Return the reduced min tensor.
918
+
919
+ Args:
920
+ dim (Optional[int], optional): The dim to reduce. Default is None, and all dimensions are reduced.
921
+ keep_dims (Optional[bool], optional): If True, the reduced dimensions are retained with length 1. Defaults to False.
922
+
923
+ Returns:
924
+ Tensor: The result of min reducing operation.
925
+ """
926
+ return generate_op(self, "reduce_min", reduction_axes=dim, keep_dims=keep_dims)
927
+
928
+ def prod(
929
+ self,
930
+ dim: Optional[int] = None,
931
+ keep_dims: Optional[bool] = False,
932
+ dtype: Optional[torch.dtype] = None,
933
+ ) -> "Tensor":
934
+ """
935
+ Return the reduced product tensor.
936
+
937
+ Args:
938
+ dim (Optional[int], optional): The dim to reduce. Default is None, and all dimensions are reduced.
939
+ keep_dims (Optional[bool], optional): If True, the reduced dimensions are retained with length 1. Defaults to False.
940
+ dtype (Optional[torch.dtype], optional): The data type. Defaults to None.
941
+
942
+ Returns:
943
+ Tensor: The result of product reducing operation.
944
+ """
945
+ prod = generate_op(self, "reduce_prod", reduction_axes=dim, keep_dims=keep_dims)
946
+ if dtype:
947
+ prod = prod.to(dtype)
948
+ return prod
949
+
950
+ def sum(
951
+ self,
952
+ dim: Optional[Union[int, Sequence[int]]] = None,
953
+ keep_dims: Optional[bool] = False,
954
+ dtype: Optional[torch.dtype] = None,
955
+ ) -> "Tensor":
956
+ """
957
+ Return the reduced sum tensor.
958
+
959
+ Args:
960
+ dim (Optional[Union[int, Sequence[int]]], optional): The dim(s) to reduce. Default is None, and all dimensions are reduced.
961
+ keep_dims (Optional[bool], optional): If True, the reduced dimensions are retained with length 1. Defaults to False.
962
+ dtype (Optional[torch.dtype], optional): The data type. Defaults to None.
963
+
964
+ Returns:
965
+ Tensor: The result of sum reducing operation.
966
+ """
967
+ sum = generate_op(self, "reduce_sum", reduction_axes=dim, keep_dims=keep_dims)
968
+ if dtype:
969
+ sum = sum.to(dtype)
970
+ return sum
971
+
972
+ def chunk(
973
+ self,
974
+ chunks: int,
975
+ dim: int = 0,
976
+ ) -> Union["Tensor", list]:
977
+ """
978
+ Return the list of tensor chunks.
979
+
980
+ Args:
981
+ chunks (int): The number of chunks to return.
982
+ dim (int): The dimension along which to split the tensor. Default is 0.
983
+
984
+ Returns:
985
+ Union["Tensor", list]: The resulting list of split tensors or a single tensor.
986
+
987
+ Raises:
988
+ ValueError: The input chunks value is not valid.
989
+ """
990
+ if chunks <= 0:
991
+ raise ValueError("The input chunks value is not valid.")
992
+ if chunks == 1:
993
+ return self
994
+ tensors = []
995
+ remainder = self.shape[dim] % chunks
996
+ chunk_size = self.shape[dim] // chunks + (1 if remainder > 0 else 0)
997
+ num_dims = self.dim()
998
+
999
+ start_idx = 0
1000
+ for _ in range(chunks):
1001
+ indexes = [slice(None)] * num_dims
1002
+ end_idx = start_idx + chunk_size
1003
+ end_idx = end_idx if end_idx < self.shape[dim] else self.shape[dim]
1004
+ indexes[dim] = slice(start_idx, end_idx)
1005
+ tensors.append(self.__getitem__(tuple(indexes)))
1006
+ start_idx = end_idx
1007
+ return tensors
1008
+
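To make the chunk-size arithmetic above concrete (an illustrative check, not part of the package): splitting a dimension of size 10 into 3 chunks gives chunk_size 4, and the successive slices cover 4, 4 and 2 elements, which matches torch.chunk.

import torch

dim_size, chunks = 10, 3
remainder = dim_size % chunks
chunk_size = dim_size // chunks + (1 if remainder > 0 else 0)  # -> 4

# Walk the same start/end bookkeeping as Tensor.chunk above.
sizes, start = [], 0
for _ in range(chunks):
    end = min(start + chunk_size, dim_size)
    sizes.append(end - start)
    start = end
print(sizes)  # [4, 4, 2]

# torch.chunk produces the same split sizes.
print([t.shape[0] for t in torch.chunk(torch.empty(dim_size, 2), chunks)])  # [4, 4, 2]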
1009
+ def to(self, dtype: NPUDtype) -> "Tensor":
1010
+ """
1011
+ Convert the tensor to the specified data type.
1012
+
1013
+ Args:
1014
+ dtype (NPUDtype): The data type to convert the tensor to.
1015
+
1016
+ Returns:
1017
+ Tensor: The converted tensor.
1018
+ """
1019
+ return generate_op([self], "to", dtype)
1020
+
1021
+ @classmethod
1022
+ def __torch_function__(
1023
+ cls: Any,
1024
+ func: Any,
1025
+ types: Any,
1026
+ args: Sequence[Any] = (),
1027
+ kwargs: Optional[MutableMapping[Any, Any]] = None,
1028
+ ) -> Any:
1029
+ """Python function to override torch functions for Tensor class.
1030
+
1031
+ Args:
1032
+ func (Any): the function to override.
1033
+ types (Any): the types of the arguments.
1034
+ args (Sequence[Any], optional): the arguments. Defaults to ().
1035
+ kwargs (Optional[MutableMapping[Any, Any]], optional): the keyword arguments. Defaults to None.
1036
+
1037
+ Returns:
1038
+ Any: the result of the function.
1039
+ """
1040
+ if kwargs is None:
1041
+ kwargs = {}
1042
+ if func not in HANDLED_FUNCTIONS or not all(
1043
+ issubclass(t, (torch.Tensor, Tensor)) for t in types
1044
+ ):
1045
+ return NotImplemented
1046
+ return HANDLED_FUNCTIONS[func](*args, **kwargs)
1047
+
1048
+
1049
+ HANDLED_FUNCTIONS: MutableMapping[Any, Any] = {}
1050
+
1051
+
1052
+ def implements(torch_function: Any) -> Any:
1053
+ """Implement a decorator to override torch functions for Tensor class.
1054
+
1055
+ Args:
1056
+ torch_function (Any): the function to override.
1057
+
1058
+ Returns:
1059
+ Any: the result of the function.
1060
+ """
1061
+
1062
+ def decorator(func: Any) -> Any:
1063
+ """Implement a decorator to override torch functions for Tensor class.
1064
+
1065
+ Args:
1066
+ func (Any): the function to override.
1067
+
1068
+ Returns:
1069
+ Any: the result of the function.
1070
+ """
1071
+ functools.update_wrapper(func, torch_function)
1072
+ HANDLED_FUNCTIONS[torch_function] = func
1073
+ return func
1074
+
1075
+ return decorator
1076
+
1077
+
1078
+ def generate_op(
1079
+ tensors: Union[Sequence[Union[Tensor, torch.Tensor]], Union[Tensor, torch.Tensor]],
1080
+ op: str,
1081
+ *args: Any,
1082
+ **kwargs: Any,
1083
+ ) -> "Tensor":
1084
+ """
1085
+ Generate a new tensor by applying the specified operation to a sequence of tensors.
1086
+
1087
+ Args:
1088
+ tensors (Union[Sequence[Union[Tensor, torch.Tensor]], Union[Tensor, torch.Tensor]]): A sequence or a single tensor.
1089
+ op (str): The name of the operation to apply.
1090
+ args (Any): Variable length argument list.
1091
+ kwargs (Any): Arbitrary keyword arguments.
1092
+
1093
+ Returns:
1094
+ Tensor: A new tensor generated by applying the operation to the input tensors.
1095
+
1096
+ Raises:
1097
+ ValueError: If the tensors are not from the same factory.
1098
+
1099
+ """
1100
+ if not isinstance(tensors, (list, tuple)):
1101
+ tensors = [tensors]
1102
+
1103
+ # Check that all tensors are from the same factory
1104
+ if (
1105
+ not len({tensor.factory for tensor in tensors if isinstance(tensor, Tensor)})
1106
+ == 1
1107
+ ):
1108
+ raise ValueError("All tensors must be from the same factory")
1109
+
1110
+ # Get the first factory from the tensors
1111
+ factory = [t for t in tensors if isinstance(t, Tensor)][0].factory
1112
+
1113
+ # Replace the tensors that are not from the factory with constant tensors if they are coming from pytorch
1114
+ tensors = [
1115
+ tensor if isinstance(tensor, Tensor) else factory.constant(tensor)
1116
+ for tensor in tensors
1117
+ ]
1118
+
1119
+ # Create the operation
1120
+ return factory.__getattribute__(op)(*tensors, *args, **kwargs)
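A note on how the plain torch calls above (torch.cos, torch.exp, torch.clamp, torch.nn.functional.softmax, ...) find their way back into this graph builder: because Tensor defines __torch_function__, any torch function invoked with a Tensor argument is looked up in HANDLED_FUNCTIONS, and the implements decorator is what fills that registry. The registrations themselves are presumably made elsewhere in the package (for example in intel_npu_acceleration_library/nn/functional.py from the file list) and typically end in a generate_op call. The sketch below reproduces the same dispatch pattern with toy types so the control flow can be run in isolation; none of these toy names exist in the package.

import functools
import torch

HANDLED = {}

def implements(torch_function):
    # Same registration pattern as the implements decorator above.
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED[torch_function] = func
        return func
    return decorator

class SymbolicTensor:
    # Toy stand-in for the Tensor dataclass above: it holds no data and only
    # records which torch op was requested.
    def __init__(self, shape):
        self.shape = tuple(shape)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED:
            return NotImplemented
        return HANDLED[func](*args, **kwargs)

    def cos(self):
        # Mirrors Tensor.cos above: delegate to torch, which dispatches back
        # through __torch_function__ and lands in the registered handler.
        return torch.cos(self)

@implements(torch.cos)
def symbolic_cos(x):
    # The real handler would build a graph node (generate_op([x], "cos"));
    # this one just reports that the call was intercepted.
    return f"cos node for shape {x.shape}"

x = SymbolicTensor((2, 3))
print(torch.cos(x))  # cos node for shape (2, 3)
print(x.cos())       # same result, entered through the method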