vkdispatch-core 0.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vkdispatch/__init__.py +49 -0
- vkdispatch/__main__.py +4 -0
- vkdispatch/_compat/__init__.py +2 -0
- vkdispatch/_compat/numpy_compat.py +583 -0
- vkdispatch/backends/__init__.py +1 -0
- vkdispatch/backends/dummy_native.py +819 -0
- vkdispatch/backends/pycuda_native.py +1371 -0
- vkdispatch/base/__init__.py +0 -0
- vkdispatch/base/backend.py +103 -0
- vkdispatch/base/brython_utils.py +4 -0
- vkdispatch/base/buffer.py +300 -0
- vkdispatch/base/command_list.py +114 -0
- vkdispatch/base/compute_plan.py +45 -0
- vkdispatch/base/context.py +603 -0
- vkdispatch/base/descriptor_set.py +56 -0
- vkdispatch/base/dtype.py +407 -0
- vkdispatch/base/errors.py +45 -0
- vkdispatch/base/image.py +429 -0
- vkdispatch/base/init.py +676 -0
- vkdispatch/cli.py +25 -0
- vkdispatch/codegen/__init__.py +74 -0
- vkdispatch/codegen/abreviations.py +26 -0
- vkdispatch/codegen/arguments.py +43 -0
- vkdispatch/codegen/backends/__init__.py +3 -0
- vkdispatch/codegen/backends/base.py +204 -0
- vkdispatch/codegen/backends/cuda.py +1603 -0
- vkdispatch/codegen/backends/glsl.py +168 -0
- vkdispatch/codegen/builder.py +371 -0
- vkdispatch/codegen/functions/__init__.py +0 -0
- vkdispatch/codegen/functions/atomic_memory.py +20 -0
- vkdispatch/codegen/functions/base_functions/__init__.py +0 -0
- vkdispatch/codegen/functions/base_functions/arithmetic.py +340 -0
- vkdispatch/codegen/functions/base_functions/arithmetic_comparisons.py +47 -0
- vkdispatch/codegen/functions/base_functions/base_utils.py +123 -0
- vkdispatch/codegen/functions/base_functions/bitwise.py +185 -0
- vkdispatch/codegen/functions/block_synchonization.py +27 -0
- vkdispatch/codegen/functions/builtin_constants.py +98 -0
- vkdispatch/codegen/functions/common_builtins.py +430 -0
- vkdispatch/codegen/functions/complex_numbers.py +38 -0
- vkdispatch/codegen/functions/control_flow.py +91 -0
- vkdispatch/codegen/functions/exponential.py +113 -0
- vkdispatch/codegen/functions/geometric.py +83 -0
- vkdispatch/codegen/functions/index_raveling.py +83 -0
- vkdispatch/codegen/functions/matrix.py +83 -0
- vkdispatch/codegen/functions/printing.py +29 -0
- vkdispatch/codegen/functions/registers.py +83 -0
- vkdispatch/codegen/functions/subgroups.py +31 -0
- vkdispatch/codegen/functions/trigonometry.py +191 -0
- vkdispatch/codegen/functions/type_casting.py +80 -0
- vkdispatch/codegen/functions/utils.py +34 -0
- vkdispatch/codegen/global_builder.py +89 -0
- vkdispatch/codegen/shader_writer.py +93 -0
- vkdispatch/codegen/struct_builder.py +48 -0
- vkdispatch/codegen/variables/__init__.py +0 -0
- vkdispatch/codegen/variables/base_variable.py +82 -0
- vkdispatch/codegen/variables/bound_variables.py +133 -0
- vkdispatch/codegen/variables/variables.py +394 -0
- vkdispatch/execution_pipeline/__init__.py +0 -0
- vkdispatch/execution_pipeline/buffer_builder.py +278 -0
- vkdispatch/execution_pipeline/command_graph.py +282 -0
- vkdispatch/fft/__init__.py +36 -0
- vkdispatch/fft/config.py +172 -0
- vkdispatch/fft/context.py +185 -0
- vkdispatch/fft/cooley_tukey.py +174 -0
- vkdispatch/fft/functions.py +247 -0
- vkdispatch/fft/global_memory_iterators.py +323 -0
- vkdispatch/fft/grid_manager.py +259 -0
- vkdispatch/fft/io_manager.py +166 -0
- vkdispatch/fft/io_proxy.py +51 -0
- vkdispatch/fft/memory_iterators.py +90 -0
- vkdispatch/fft/prime_utils.py +66 -0
- vkdispatch/fft/registers.py +112 -0
- vkdispatch/fft/resources.py +147 -0
- vkdispatch/fft/sdata_manager.py +104 -0
- vkdispatch/fft/shader_factories.py +169 -0
- vkdispatch/fft/src_functions.py +342 -0
- vkdispatch/reduce/__init__.py +8 -0
- vkdispatch/reduce/decorator.py +64 -0
- vkdispatch/reduce/operations.py +64 -0
- vkdispatch/reduce/reduce_function.py +163 -0
- vkdispatch/reduce/stage.py +165 -0
- vkdispatch/shader/__init__.py +0 -0
- vkdispatch/shader/context.py +46 -0
- vkdispatch/shader/decorator.py +54 -0
- vkdispatch/shader/map.py +71 -0
- vkdispatch/shader/shader_function.py +397 -0
- vkdispatch/shader/signature.py +166 -0
- vkdispatch/vkfft/__init__.py +9 -0
- vkdispatch/vkfft/vkfft_dispatcher.py +399 -0
- vkdispatch/vkfft/vkfft_plan.py +112 -0
- vkdispatch_core-0.0.32.dist-info/METADATA +101 -0
- vkdispatch_core-0.0.32.dist-info/RECORD +95 -0
- vkdispatch_core-0.0.32.dist-info/WHEEL +5 -0
- vkdispatch_core-0.0.32.dist-info/licenses/LICENSE +201 -0
- vkdispatch_core-0.0.32.dist-info/top_level.txt +1 -0
vkdispatch/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .base.init import DeviceInfo
|
|
2
|
+
from .base.init import LogLevel
|
|
3
|
+
from .base.init import get_devices
|
|
4
|
+
from .base.init import get_backend
|
|
5
|
+
from .base.init import initialize
|
|
6
|
+
from .base.init import is_initialized
|
|
7
|
+
from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level
|
|
8
|
+
|
|
9
|
+
from .base.dtype import dtype
|
|
10
|
+
from .base.dtype import float32, int32, uint32, complex64
|
|
11
|
+
from .base.dtype import vec2, vec3, vec4, ivec2, ivec3, ivec4, uvec2, uvec3, uvec4
|
|
12
|
+
from .base.dtype import mat2, mat3, mat4
|
|
13
|
+
|
|
14
|
+
from .base.context import get_context, queue_wait_idle, Signal
|
|
15
|
+
from .base.context import get_context_handle
|
|
16
|
+
from .base.context import make_context, select_queue_families, set_dummy_context_params
|
|
17
|
+
from .base.context import is_context_initialized
|
|
18
|
+
|
|
19
|
+
from .base.buffer import asbuffer
|
|
20
|
+
from .base.buffer import Buffer, buffer_u32, buffer_i32, buffer_f32, buffer_c64
|
|
21
|
+
from .base.buffer import asrfftbuffer
|
|
22
|
+
from .base.buffer import RFFTBuffer
|
|
23
|
+
|
|
24
|
+
from .base.image import image_format
|
|
25
|
+
from .base.image import image_type
|
|
26
|
+
from .base.image import image_view_type
|
|
27
|
+
from .base.image import Image
|
|
28
|
+
from .base.image import Image1D
|
|
29
|
+
from .base.image import Image2D
|
|
30
|
+
from .base.image import Image2DArray
|
|
31
|
+
from .base.image import Image3D
|
|
32
|
+
from .base.image import Sampler
|
|
33
|
+
from .base.image import Filter
|
|
34
|
+
from .base.image import AddressMode
|
|
35
|
+
from .base.image import BorderColor
|
|
36
|
+
|
|
37
|
+
from .execution_pipeline.command_graph import CommandGraph, BufferBindInfo, ImageBindInfo
|
|
38
|
+
from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph
|
|
39
|
+
|
|
40
|
+
from .shader.shader_function import ShaderFunction, ShaderSource
|
|
41
|
+
from .shader.context import ShaderContext, shader_context
|
|
42
|
+
from .shader.map import map, MappingFunction
|
|
43
|
+
from .shader.decorator import shader
|
|
44
|
+
|
|
45
|
+
import vkdispatch.vkfft as vkfft
|
|
46
|
+
import vkdispatch.fft as fft
|
|
47
|
+
import vkdispatch.reduce as reduce
|
|
48
|
+
|
|
49
|
+
__version__ = "0.0.32"
|
vkdispatch/_compat/numpy_compat.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
import cmath
|
|
5
|
+
import math
|
|
6
|
+
import struct
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, Iterable, List, Sequence, Tuple
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import numpy as _np
|
|
13
|
+
except Exception: # pragma: no cover - intentionally broad for optional dependency import
|
|
14
|
+
_np = None
|
|
15
|
+
|
|
16
|
+
HAS_NUMPY = _np is not None
|
|
17
|
+
pi = math.pi
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def require_numpy(feature_name: str) -> None:
    """Raise ``RuntimeError`` mentioning ``feature_name`` when numpy could not be imported.

    Returns ``None`` without doing anything when ``HAS_NUMPY`` is true.
    """
    if not HAS_NUMPY:
        message = (
            f"{feature_name} requires numpy, but numpy is not available. "
            "Install numpy or use the bytes-based API."
        )
        raise RuntimeError(message)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def numpy_module():
    """Return the module-level ``_np`` object (the numpy module, or ``None`` if its import failed)."""
    module = _np
    return module
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def prod(values: Iterable[int]) -> int:
    """Return the product of ``values`` as an exact Python ``int``.

    Each element is converted with ``int()`` before multiplying, and the
    product of an empty iterable is ``1``.

    The multiplication is always done with Python's arbitrary-precision
    integers. The previous numpy path (``int(_np.prod(...))``) used
    fixed-width integer arithmetic, which wraps around silently on overflow,
    and it could disagree with the pure-Python fallback for non-integer
    elements. Both environments now produce the same result.
    """
    result = 1
    for value in values:
        result *= int(value)
    return result
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def ceil(value: float) -> float:
    """Return the ceiling of ``value`` as a float, via numpy when available, else ``math.ceil``."""
    impl = _np.ceil if HAS_NUMPY else math.ceil
    return float(impl(value))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def floor(value: float) -> float:
    """Return the floor of ``value`` as a float, via numpy when available, else ``math.floor``."""
    impl = _np.floor if HAS_NUMPY else math.floor
    return float(impl(value))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def trunc(value: float) -> float:
    """Return ``value`` truncated toward zero as a float, via numpy when available, else ``math.trunc``."""
    impl = _np.trunc if HAS_NUMPY else math.trunc
    return float(impl(value))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def round(value: float) -> float:
    """Return ``value`` rounded to an integral float.

    Uses ``numpy.round`` when numpy is available, otherwise the built-in
    ``round`` (reached through ``builtins`` because this function shadows it).
    """
    impl = _np.round if HAS_NUMPY else builtins.round
    return float(impl(value))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def sign(value: float) -> float:
    """Return the sign of ``value`` as ``1.0``, ``-1.0`` or ``0.0``.

    Uses ``numpy.sign`` when numpy is available. The pure-Python fallback
    returns NaN for a NaN input so that both paths agree; previously the
    fallback returned ``0.0`` for NaN.
    """
    if HAS_NUMPY:
        return float(_np.sign(value))

    # NaN is the only value that compares unequal to itself; this check also
    # works for ints too large to convert to float.
    if value != value:
        return float("nan")
    if value > 0:
        return 1.0
    if value < 0:
        return -1.0
    return 0.0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def abs_value(value: Any) -> float:
    """Return the absolute value of ``value`` as a float, via ``numpy.abs`` when available, else ``abs``."""
    impl = _np.abs if HAS_NUMPY else abs
    return float(impl(value))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def minimum(x: float, y: float) -> float:
    """Return the smaller of ``x`` and ``y`` as a float.

    Uses ``numpy.minimum`` when numpy is available. The pure-Python fallback
    returns NaN when either argument is NaN, matching numpy's NaN
    propagation; previously a NaN in ``x`` was dropped and ``y`` returned.
    """
    if HAS_NUMPY:
        return float(_np.minimum(x, y))

    # NaN is the only value that compares unequal to itself.
    if x != x or y != y:
        return float("nan")
    return float(x if x <= y else y)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def maximum(x: float, y: float) -> float:
    """Return the larger of ``x`` and ``y`` as a float.

    Uses ``numpy.maximum`` when numpy is available. The pure-Python fallback
    returns NaN when either argument is NaN, matching numpy's NaN
    propagation; previously a NaN in ``x`` was dropped and ``y`` returned.
    """
    if HAS_NUMPY:
        return float(_np.maximum(x, y))

    # NaN is the only value that compares unequal to itself.
    if x != x or y != y:
        return float("nan")
    return float(x if x >= y else y)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def clip(x: float, min_value: float, max_value: float) -> float:
    """Return ``x`` limited to ``[min_value, max_value]`` as a float.

    Uses ``numpy.clip`` when numpy is available; otherwise applies ``max``
    against the lower bound and then ``min`` against the upper bound.
    """
    if not HAS_NUMPY:
        raised_to_lower = max(x, min_value)
        return float(min(raised_to_lower, max_value))
    return float(_np.clip(x, min_value, max_value))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def mod(x: float, y: float) -> float:
    """Return ``x`` modulo ``y`` as a float, via ``numpy.mod`` when available, else the ``%`` operator."""
    if not HAS_NUMPY:
        return float(x % y)
    return float(_np.mod(x, y))
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def modf(x: float, _unused: Any = None) -> Tuple[float, float]:
    """Split ``x`` into ``(fractional, integral)`` parts, both returned as floats.

    Uses ``numpy.modf`` when numpy is available, otherwise ``math.modf``.
    The second parameter is accepted but never read.
    """
    splitter = _np.modf if HAS_NUMPY else math.modf
    fractional, integral = splitter(x)
    return float(fractional), float(integral)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
    """Linearly interpolate ``fp`` over sample points ``xp`` at position ``x``.

    Delegates to ``numpy.interp`` when numpy is available. The fallback
    clamps to ``fp[0]`` / ``fp[-1]`` outside the range of ``xp`` and otherwise
    interpolates inside the first segment whose right endpoint is >= ``x``.

    Raises:
        ValueError: (fallback only) if ``xp`` and ``fp`` differ in length or
            are empty.
    """
    if HAS_NUMPY:
        return float(_np.interp(x, xp, fp))

    point_count = len(xp)
    if point_count != len(fp):
        raise ValueError("xp and fp must have the same length")
    if point_count == 0:
        raise ValueError("xp and fp must be non-empty")
    if point_count == 1:
        return float(fp[0])

    # Clamp to the end values outside the sampled range.
    if x <= xp[0]:
        return float(fp[0])
    if x >= xp[-1]:
        return float(fp[-1])

    for right in range(1, point_count):
        if x <= xp[right]:
            left = right - 1
            x0, x1 = xp[left], xp[right]
            y0, y1 = fp[left], fp[right]

            # Degenerate segment: avoid dividing by zero.
            if x1 == x0:
                return float(y0)

            t = (x - x0) / (x1 - x0)
            return float(y0 + t * (y1 - y0))

    return float(fp[-1])
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def isnan(value: float) -> bool:
    """Return whether ``value`` is NaN, via ``numpy.isnan`` when available, else ``math.isnan``."""
    if not HAS_NUMPY:
        return math.isnan(value)
    return bool(_np.isnan(value))
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def isinf(value: float) -> bool:
    """Return whether ``value`` is infinite, via ``numpy.isinf`` when available, else ``math.isinf``."""
    if not HAS_NUMPY:
        return math.isinf(value)
    return bool(_np.isinf(value))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def power(x: float, y: float) -> float:
    """Return ``x`` raised to ``y`` as a float, via ``numpy.power`` when available, else ``math.pow``."""
    impl = _np.power if HAS_NUMPY else math.pow
    return float(impl(x, y))
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def exp(value: float) -> float:
    """Return e raised to ``value`` as a float, via ``numpy.exp`` when available, else ``math.exp``."""
    impl = _np.exp if HAS_NUMPY else math.exp
    return float(impl(value))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def exp2(value: float) -> float:
    """Return 2 raised to ``value`` as a float.

    Preference order: ``numpy.exp2``, then ``math.exp2`` when the running
    interpreter provides it, then ``math.pow(2.0, value)``.
    """
    if HAS_NUMPY:
        return float(_np.exp2(value))

    native_exp2 = getattr(math, "exp2", None)
    if native_exp2 is None:
        return float(math.pow(2.0, value))
    return float(native_exp2(value))
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def log(value: float) -> float:
    """Return the natural logarithm of ``value`` as a float, via ``numpy.log`` when available, else ``math.log``."""
    impl = _np.log if HAS_NUMPY else math.log
    return float(impl(value))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def log2(value: float) -> float:
    """Return the base-2 logarithm of ``value`` as a float, via ``numpy.log2`` when available, else ``math.log2``."""
    impl = _np.log2 if HAS_NUMPY else math.log2
    return float(impl(value))
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def sqrt(value: float) -> float:
    """Return the square root of ``value`` as a float, via ``numpy.sqrt`` when available, else ``math.sqrt``."""
    impl = _np.sqrt if HAS_NUMPY else math.sqrt
    return float(impl(value))
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def sin(value: float) -> float:
    """Return the sine of ``value`` as a float, via ``numpy.sin`` when available, else ``math.sin``."""
    impl = _np.sin if HAS_NUMPY else math.sin
    return float(impl(value))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def cos(value: float) -> float:
    """Return the cosine of ``value`` as a float, via ``numpy.cos`` when available, else ``math.cos``."""
    impl = _np.cos if HAS_NUMPY else math.cos
    return float(impl(value))
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def tan(value: float) -> float:
    """Return the tangent of ``value`` as a float, via ``numpy.tan`` when available, else ``math.tan``."""
    impl = _np.tan if HAS_NUMPY else math.tan
    return float(impl(value))
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def arcsin(value: float) -> float:
    """Return the inverse sine of ``value`` as a float, via ``numpy.arcsin`` when available, else ``math.asin``."""
    impl = _np.arcsin if HAS_NUMPY else math.asin
    return float(impl(value))
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def arccos(value: float) -> float:
    """Return the inverse cosine of ``value`` as a float, via ``numpy.arccos`` when available, else ``math.acos``."""
    impl = _np.arccos if HAS_NUMPY else math.acos
    return float(impl(value))
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def arctan(value: float) -> float:
    """Return the inverse tangent of ``value`` as a float, via ``numpy.arctan`` when available, else ``math.atan``."""
    impl = _np.arctan if HAS_NUMPY else math.atan
    return float(impl(value))
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def arctan2(y: float, x: float) -> float:
    """Return the two-argument inverse tangent of ``y`` and ``x`` as a float.

    Uses ``numpy.arctan2`` when numpy is available, otherwise ``math.atan2``.
    """
    impl = _np.arctan2 if HAS_NUMPY else math.atan2
    return float(impl(y, x))
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def sinh(value: float) -> float:
    """Return the hyperbolic sine of ``value`` as a float, via ``numpy.sinh`` when available, else ``math.sinh``."""
    impl = _np.sinh if HAS_NUMPY else math.sinh
    return float(impl(value))
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def cosh(value: float) -> float:
    """Return the hyperbolic cosine of ``value`` as a float, via ``numpy.cosh`` when available, else ``math.cosh``."""
    impl = _np.cosh if HAS_NUMPY else math.cosh
    return float(impl(value))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def tanh(value: float) -> float:
    """Return the hyperbolic tangent of ``value`` as a float, via ``numpy.tanh`` when available, else ``math.tanh``."""
    impl = _np.tanh if HAS_NUMPY else math.tanh
    return float(impl(value))
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def arcsinh(value: float) -> float:
    """Return the inverse hyperbolic sine of ``value`` as a float, via ``numpy.arcsinh`` when available, else ``math.asinh``."""
    impl = _np.arcsinh if HAS_NUMPY else math.asinh
    return float(impl(value))
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def arccosh(value: float) -> float:
    """Return the inverse hyperbolic cosine of ``value`` as a float, via ``numpy.arccosh`` when available, else ``math.acosh``."""
    impl = _np.arccosh if HAS_NUMPY else math.acosh
    return float(impl(value))
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def arctanh(value: float) -> float:
    """Return the inverse hyperbolic tangent of ``value`` as a float, via ``numpy.arctanh`` when available, else ``math.atanh``."""
    impl = _np.arctanh if HAS_NUMPY else math.atanh
    return float(impl(value))
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def dot(x: Any, y: Any) -> float:
    """Return the dot product of ``x`` and ``y`` as a float.

    Uses ``numpy.dot`` when numpy is available. In the fallback, two Python
    scalars are simply multiplied; anything else is treated as a pair of
    iterables whose element-wise products are summed.
    """
    if HAS_NUMPY:
        return float(_np.dot(x, y))

    scalar_types = (int, float, complex)
    both_scalars = isinstance(x, scalar_types) and isinstance(y, scalar_types)
    if both_scalars:
        return float(x * y)

    pairwise_products = (left * right for left, right in zip(x, y))
    return float(sum(pairwise_products))
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def angle(value: complex) -> float:
    """Return the phase angle of ``value`` as a float, via ``numpy.angle`` when available, else ``cmath.phase``."""
    impl = _np.angle if HAS_NUMPY else cmath.phase
    return float(impl(value))
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def exp_complex(value: complex) -> complex:
    """Return e raised to the complex ``value``.

    The numpy result is converted to a built-in ``complex``; the fallback
    returns the ``cmath.exp`` result directly.
    """
    if not HAS_NUMPY:
        return cmath.exp(value)
    return complex(_np.exp(value))
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def is_numpy_integer_scalar(value: Any) -> bool:
    """Return whether numpy classifies ``type(value)`` as a sub-dtype of ``numpy.integer``.

    Always ``False`` when numpy is unavailable.
    """
    if not HAS_NUMPY:
        return False
    return bool(_np.issubdtype(type(value), _np.integer))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def is_integer_scalar(value: Any) -> bool:
    """Return whether ``value`` is a Python ``int`` or passes ``is_numpy_integer_scalar``."""
    if isinstance(value, int):
        return True
    return is_numpy_integer_scalar(value)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def is_numpy_floating_instance(value: Any) -> bool:
    """Return whether ``value`` is an instance of ``numpy.floating``; ``False`` without numpy."""
    if not HAS_NUMPY:
        return False
    return bool(isinstance(value, _np.floating))
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
@dataclass(frozen=True)
class HostDType:
    """Immutable description of a host-side scalar element type."""

    # Identifier used as the lookup key (e.g. "int32").
    name: str
    # Size of one element in bytes.
    itemsize: int
    # ``struct`` format characters for one element, without a byte-order prefix.
    struct_format: str
    # Category label; the instances defined in this module use
    # "int", "uint", "float" and "complex".
    kind: str
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Host-side element types available without numpy.
INT32 = HostDType(name="int32", itemsize=4, struct_format="i", kind="int")
UINT32 = HostDType(name="uint32", itemsize=4, struct_format="I", kind="uint")
FLOAT32 = HostDType(name="float32", itemsize=4, struct_format="f", kind="float")
# A complex64 element is stored as two 32-bit floats (real, imaginary).
COMPLEX64 = HostDType(name="complex64", itemsize=8, struct_format="ff", kind="complex")

# Lookup table from dtype name to its HostDType description.
_HOST_DTYPES = {
    entry.name: entry
    for entry in (INT32, UINT32, FLOAT32, COMPLEX64)
}
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def host_dtype(name: str) -> HostDType:
    """Look up the ``HostDType`` registered under ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``_HOST_DTYPES``.
    """
    if name in _HOST_DTYPES:
        return _HOST_DTYPES[name]
    raise ValueError(f"Unsupported dtype ({name})!")
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def is_host_dtype(value: Any) -> bool:
    """Return whether ``value`` is a ``HostDType`` instance."""
    matches = isinstance(value, HostDType)
    return matches
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def host_dtype_name(dtype: Any) -> str:
    """Return the name string for ``dtype``.

    A ``HostDType`` yields its ``name``; a string is returned unchanged;
    anything else is resolved through ``numpy.dtype(...).name``.

    Raises:
        ValueError: if ``dtype`` is neither a ``HostDType`` nor a string and
            numpy is unavailable.
    """
    if isinstance(dtype, HostDType):
        return dtype.name
    if isinstance(dtype, str):
        return dtype
    if not HAS_NUMPY:
        raise ValueError(f"Unsupported dtype ({dtype})!")
    return str(_np.dtype(dtype).name)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def dtype_itemsize(dtype: Any) -> int:
    """Return the per-element size in bytes for ``dtype``.

    ``HostDType`` instances answer directly; otherwise numpy is consulted
    when available, and the ``_HOST_DTYPES`` table is the last resort.
    """
    if isinstance(dtype, HostDType):
        return dtype.itemsize
    if not HAS_NUMPY:
        resolved_name = host_dtype_name(dtype)
        return host_dtype(resolved_name).itemsize
    return int(_np.dtype(dtype).itemsize)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def dtype_kind(dtype: Any) -> str:
    """Return the kind label ("complex", "uint", "int" or "float") for ``dtype``.

    ``HostDType`` instances answer directly. With numpy available the dtype
    is classified via ``numpy.issubdtype``; if no category matches, or numpy
    is unavailable, the ``_HOST_DTYPES`` table is consulted by name.
    """
    if isinstance(dtype, HostDType):
        return dtype.kind

    if HAS_NUMPY:
        resolved = _np.dtype(dtype)
        # Checked in order: unsignedinteger comes before the broader integer.
        categories = (
            (_np.complexfloating, "complex"),
            (_np.unsignedinteger, "uint"),
            (_np.integer, "int"),
            (_np.floating, "float"),
        )
        for abstract_type, label in categories:
            if _np.issubdtype(resolved, abstract_type):
                return label

    return host_dtype(host_dtype_name(dtype)).kind
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def dtype_struct_format(dtype: Any) -> str:
    """Return the ``struct`` format string for ``dtype``.

    Non-``HostDType`` inputs are resolved by name through ``_HOST_DTYPES``.
    """
    if not isinstance(dtype, HostDType):
        dtype = host_dtype(host_dtype_name(dtype))
    return dtype.struct_format
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
class CompatArray:
    """Minimal array stand-in holding raw bytes plus a dtype and a shape."""

    def __init__(self, buffer: bytes, dtype: HostDType, shape: Tuple[int, ...]):
        """Store a ``bytes`` copy of ``buffer`` with its ``dtype`` and ``shape``.

        ``size`` is the product of the shape's dimensions.
        """
        normalized_shape = tuple(shape)
        self._buffer = bytes(buffer)
        self.dtype = dtype
        self.shape = normalized_shape
        self.size = prod(normalized_shape)

    def reshape(self, shape: Tuple[int, ...]) -> "CompatArray":
        """Return a new ``CompatArray`` over the same bytes with ``shape``.

        Raises:
            ValueError: if ``shape`` implies a different element count.
        """
        target_shape = tuple(shape)
        if prod(target_shape) == self.size:
            return CompatArray(self._buffer, self.dtype, target_shape)
        raise ValueError("Cannot reshape array with mismatched element count")

    def tobytes(self) -> bytes:
        """Return the stored data as ``bytes``."""
        return bytes(self._buffer)

    @property
    def nbytes(self) -> int:
        """Length of the stored byte buffer."""
        return len(self._buffer)

    def __repr__(self) -> str:
        return f"CompatArray(shape={self.shape}, dtype={self.dtype.name}, nbytes={len(self._buffer)})"
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def is_array_like(value: Any) -> bool:
    """Return whether ``value`` is a ``CompatArray`` or, with numpy available, a ``numpy.ndarray``."""
    if isinstance(value, CompatArray):
        return True
    if not HAS_NUMPY:
        return False
    return isinstance(value, _np.ndarray)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def array_shape(value: Any) -> Tuple[int, ...]:
    """Return the shape of a ``numpy.ndarray`` or ``CompatArray`` as a tuple.

    Raises:
        TypeError: for any other type.
    """
    is_ndarray = HAS_NUMPY and isinstance(value, _np.ndarray)
    if is_ndarray or isinstance(value, CompatArray):
        return tuple(value.shape)
    raise TypeError(f"Unsupported array-like value ({type(value)})")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def array_dtype(value: Any) -> Any:
    """Return the ``dtype`` attribute of a ``numpy.ndarray`` or ``CompatArray``.

    Raises:
        TypeError: for any other type.
    """
    is_ndarray = HAS_NUMPY and isinstance(value, _np.ndarray)
    if is_ndarray or isinstance(value, CompatArray):
        return value.dtype
    raise TypeError(f"Unsupported array-like value ({type(value)})")
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def array_nbytes(value: Any) -> int:
    """Return the byte size of a ``numpy.ndarray`` or ``CompatArray``.

    For an ndarray this is ``size * dtype.itemsize``; for a ``CompatArray``
    it is its ``nbytes`` property.

    Raises:
        TypeError: for any other type.
    """
    if HAS_NUMPY and isinstance(value, _np.ndarray):
        element_count = value.size
        return int(element_count * value.dtype.itemsize)
    if not isinstance(value, CompatArray):
        raise TypeError(f"Unsupported array-like value ({type(value)})")
    return value.nbytes
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def as_contiguous_bytes(value: Any) -> bytes:
    """Return the raw bytes of a ``numpy.ndarray`` or ``CompatArray``.

    An ndarray is passed through ``numpy.ascontiguousarray`` before
    ``tobytes()`` is called on it.

    Raises:
        TypeError: for any other type.
    """
    if HAS_NUMPY and isinstance(value, _np.ndarray):
        contiguous = _np.ascontiguousarray(value)
        return contiguous.tobytes()
    if not isinstance(value, CompatArray):
        raise TypeError(f"Unsupported array-like value ({type(value)})")
    return value.tobytes()
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def from_buffer(buffer: bytes, dtype: Any, shape: Tuple[int, ...]):
    """Wrap ``buffer`` as an array of ``dtype`` with the given ``shape``.

    Returns ``numpy.frombuffer(...).reshape(shape)`` when numpy is available,
    otherwise a ``CompatArray``.
    """
    name = host_dtype_name(dtype)

    if not HAS_NUMPY:
        return CompatArray(buffer, host_dtype(name), tuple(shape))

    flat_view = _np.frombuffer(buffer, dtype=_np.dtype(name))
    return flat_view.reshape(shape)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def ensure_bytes(value: Any) -> bytes:
    """Convert a ``bytes``, ``bytearray`` or ``memoryview`` into ``bytes``.

    A ``bytes`` input is returned as-is.

    Raises:
        TypeError: for any other type.
    """
    if isinstance(value, memoryview):
        return value.tobytes()
    if isinstance(value, bytearray):
        return bytes(value)
    if not isinstance(value, bytes):
        raise TypeError(f"Unsupported bytes-like object ({type(value)})")
    return value
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def is_bytes_like(value: Any) -> bool:
    """Return whether ``value`` is a ``bytes``, ``bytearray`` or ``memoryview`` instance."""
    accepted_types = (bytes, bytearray, memoryview)
    return isinstance(value, accepted_types)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def flatten(value: Any) -> List[Any]:
    """Return the elements of ``value`` as a flat list.

    A ``CompatArray`` is decoded with ``unpack_values``; a ``numpy.ndarray``
    is reshaped to one dimension and converted with ``tolist()``; lists and
    tuples are flattened recursively; any other value becomes ``[value]``.
    """
    if isinstance(value, CompatArray):
        return unpack_values(value.tobytes(), value.dtype)

    if HAS_NUMPY and isinstance(value, _np.ndarray):
        return value.reshape(-1).tolist()

    if not isinstance(value, (list, tuple)):
        return [value]

    return [leaf for item in value for leaf in flatten(item)]
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def _coerce_scalar(value: Any, dtype: Any):
    """Convert ``value`` to the Python scalar type matching ``dtype``'s kind.

    "float" yields ``float``, "int"/"uint" yield ``int``, and "complex"
    yields ``complex`` (a two-item list or tuple is read as ``(real, imag)``).

    Raises:
        ValueError: for an unknown kind, or a complex list/tuple whose length
            is not 2.
    """
    kind = dtype_kind(dtype)

    if kind == "float":
        return float(value)
    if kind in ("int", "uint"):
        return int(value)
    if kind != "complex":
        raise ValueError(f"Unsupported dtype kind ({kind})")

    if isinstance(value, complex):
        return value
    if not isinstance(value, (list, tuple)):
        return complex(value)
    if len(value) != 2:
        raise ValueError("Complex values must be complex scalars or pairs")
    real_part, imag_part = value[0], value[1]
    return complex(float(real_part), float(imag_part))
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def pack_values(values: Sequence[Any], dtype: Any) -> bytes:
    """Serialize ``values`` to bytes as elements of ``dtype``.

    With numpy, the values are converted with ``numpy.asarray`` and dumped
    with ``tobytes()``. Without numpy, each value is coerced with
    ``_coerce_scalar`` and packed with ``struct`` using the "=" prefix;
    complex values are written as two floats (real, then imaginary).
    """
    items = list(values)
    name = host_dtype_name(dtype)

    if HAS_NUMPY:
        return _np.asarray(items, dtype=_np.dtype(name)).tobytes()

    host = host_dtype(name)
    chunks = bytearray()

    if host.kind == "complex":
        for item in items:
            as_complex = _coerce_scalar(item, host)
            chunks += struct.pack("=ff", float(as_complex.real), float(as_complex.imag))
        return bytes(chunks)

    element_format = "=" + host.struct_format
    for item in items:
        chunks += struct.pack(element_format, _coerce_scalar(item, host))
    return bytes(chunks)
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def unpack_values(data: bytes, dtype: Any) -> List[Any]:
    """Decode ``data`` into a list of Python scalars of ``dtype``.

    With numpy, ``numpy.frombuffer(...).tolist()`` is used. Without numpy,
    the bytes are decoded with ``struct`` using the "=" prefix; complex
    elements are read as (real, imaginary) float pairs.
    """
    name = host_dtype_name(dtype)

    if HAS_NUMPY:
        return _np.frombuffer(data, dtype=_np.dtype(name)).tolist()

    host = host_dtype(name)

    if host.kind == "complex":
        return [complex(re_part, im_part) for re_part, im_part in struct.iter_unpack("=ff", data)]

    element_format = "=" + host.struct_format
    step = struct.calcsize(element_format)
    return [
        struct.unpack(element_format, data[start: start + step])[0]
        for start in range(0, len(data), step)
    ]
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def float_bits_to_int(value: float) -> int:
    """Reinterpret the 32-bit float encoding of ``value`` as a signed 32-bit integer."""
    if HAS_NUMPY:
        raw = _np.float32(value).tobytes()
        return int(_np.frombuffer(raw, dtype=_np.int32)[0])

    packed = struct.pack("=f", float(value))
    (bits,) = struct.unpack("=i", packed)
    return int(bits)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def float_bits_to_uint(value: float) -> int:
    """Reinterpret the 32-bit float encoding of ``value`` as an unsigned 32-bit integer."""
    if HAS_NUMPY:
        raw = _np.float32(value).tobytes()
        return int(_np.frombuffer(raw, dtype=_np.uint32)[0])

    packed = struct.pack("=f", float(value))
    (bits,) = struct.unpack("=I", packed)
    return int(bits)
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def int_bits_to_float(value: int) -> float:
    """Reinterpret the signed 32-bit integer encoding of ``value`` as a 32-bit float."""
    if HAS_NUMPY:
        raw = _np.int32(value).tobytes()
        return float(_np.frombuffer(raw, dtype=_np.float32)[0])

    packed = struct.pack("=i", int(value))
    (as_float,) = struct.unpack("=f", packed)
    return float(as_float)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def uint_bits_to_float(value: int) -> float:
    """Reinterpret the unsigned 32-bit integer encoding of ``value`` as a 32-bit float."""
    if HAS_NUMPY:
        raw = _np.uint32(value).tobytes()
        return float(_np.frombuffer(raw, dtype=_np.float32)[0])

    packed = struct.pack("=I", int(value))
    (as_float,) = struct.unpack("=f", packed)
    return float(as_float)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__all__ = []
|