vkdispatch-core 0.0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. vkdispatch/__init__.py +49 -0
  2. vkdispatch/__main__.py +4 -0
  3. vkdispatch/_compat/__init__.py +2 -0
  4. vkdispatch/_compat/numpy_compat.py +583 -0
  5. vkdispatch/backends/__init__.py +1 -0
  6. vkdispatch/backends/dummy_native.py +819 -0
  7. vkdispatch/backends/pycuda_native.py +1371 -0
  8. vkdispatch/base/__init__.py +0 -0
  9. vkdispatch/base/backend.py +103 -0
  10. vkdispatch/base/brython_utils.py +4 -0
  11. vkdispatch/base/buffer.py +300 -0
  12. vkdispatch/base/command_list.py +114 -0
  13. vkdispatch/base/compute_plan.py +45 -0
  14. vkdispatch/base/context.py +603 -0
  15. vkdispatch/base/descriptor_set.py +56 -0
  16. vkdispatch/base/dtype.py +407 -0
  17. vkdispatch/base/errors.py +45 -0
  18. vkdispatch/base/image.py +429 -0
  19. vkdispatch/base/init.py +676 -0
  20. vkdispatch/cli.py +25 -0
  21. vkdispatch/codegen/__init__.py +74 -0
  22. vkdispatch/codegen/abreviations.py +26 -0
  23. vkdispatch/codegen/arguments.py +43 -0
  24. vkdispatch/codegen/backends/__init__.py +3 -0
  25. vkdispatch/codegen/backends/base.py +204 -0
  26. vkdispatch/codegen/backends/cuda.py +1603 -0
  27. vkdispatch/codegen/backends/glsl.py +168 -0
  28. vkdispatch/codegen/builder.py +371 -0
  29. vkdispatch/codegen/functions/__init__.py +0 -0
  30. vkdispatch/codegen/functions/atomic_memory.py +20 -0
  31. vkdispatch/codegen/functions/base_functions/__init__.py +0 -0
  32. vkdispatch/codegen/functions/base_functions/arithmetic.py +340 -0
  33. vkdispatch/codegen/functions/base_functions/arithmetic_comparisons.py +47 -0
  34. vkdispatch/codegen/functions/base_functions/base_utils.py +123 -0
  35. vkdispatch/codegen/functions/base_functions/bitwise.py +185 -0
  36. vkdispatch/codegen/functions/block_synchonization.py +27 -0
  37. vkdispatch/codegen/functions/builtin_constants.py +98 -0
  38. vkdispatch/codegen/functions/common_builtins.py +430 -0
  39. vkdispatch/codegen/functions/complex_numbers.py +38 -0
  40. vkdispatch/codegen/functions/control_flow.py +91 -0
  41. vkdispatch/codegen/functions/exponential.py +113 -0
  42. vkdispatch/codegen/functions/geometric.py +83 -0
  43. vkdispatch/codegen/functions/index_raveling.py +83 -0
  44. vkdispatch/codegen/functions/matrix.py +83 -0
  45. vkdispatch/codegen/functions/printing.py +29 -0
  46. vkdispatch/codegen/functions/registers.py +83 -0
  47. vkdispatch/codegen/functions/subgroups.py +31 -0
  48. vkdispatch/codegen/functions/trigonometry.py +191 -0
  49. vkdispatch/codegen/functions/type_casting.py +80 -0
  50. vkdispatch/codegen/functions/utils.py +34 -0
  51. vkdispatch/codegen/global_builder.py +89 -0
  52. vkdispatch/codegen/shader_writer.py +93 -0
  53. vkdispatch/codegen/struct_builder.py +48 -0
  54. vkdispatch/codegen/variables/__init__.py +0 -0
  55. vkdispatch/codegen/variables/base_variable.py +82 -0
  56. vkdispatch/codegen/variables/bound_variables.py +133 -0
  57. vkdispatch/codegen/variables/variables.py +394 -0
  58. vkdispatch/execution_pipeline/__init__.py +0 -0
  59. vkdispatch/execution_pipeline/buffer_builder.py +278 -0
  60. vkdispatch/execution_pipeline/command_graph.py +282 -0
  61. vkdispatch/fft/__init__.py +36 -0
  62. vkdispatch/fft/config.py +172 -0
  63. vkdispatch/fft/context.py +185 -0
  64. vkdispatch/fft/cooley_tukey.py +174 -0
  65. vkdispatch/fft/functions.py +247 -0
  66. vkdispatch/fft/global_memory_iterators.py +323 -0
  67. vkdispatch/fft/grid_manager.py +259 -0
  68. vkdispatch/fft/io_manager.py +166 -0
  69. vkdispatch/fft/io_proxy.py +51 -0
  70. vkdispatch/fft/memory_iterators.py +90 -0
  71. vkdispatch/fft/prime_utils.py +66 -0
  72. vkdispatch/fft/registers.py +112 -0
  73. vkdispatch/fft/resources.py +147 -0
  74. vkdispatch/fft/sdata_manager.py +104 -0
  75. vkdispatch/fft/shader_factories.py +169 -0
  76. vkdispatch/fft/src_functions.py +342 -0
  77. vkdispatch/reduce/__init__.py +8 -0
  78. vkdispatch/reduce/decorator.py +64 -0
  79. vkdispatch/reduce/operations.py +64 -0
  80. vkdispatch/reduce/reduce_function.py +163 -0
  81. vkdispatch/reduce/stage.py +165 -0
  82. vkdispatch/shader/__init__.py +0 -0
  83. vkdispatch/shader/context.py +46 -0
  84. vkdispatch/shader/decorator.py +54 -0
  85. vkdispatch/shader/map.py +71 -0
  86. vkdispatch/shader/shader_function.py +397 -0
  87. vkdispatch/shader/signature.py +166 -0
  88. vkdispatch/vkfft/__init__.py +9 -0
  89. vkdispatch/vkfft/vkfft_dispatcher.py +399 -0
  90. vkdispatch/vkfft/vkfft_plan.py +112 -0
  91. vkdispatch_core-0.0.32.dist-info/METADATA +101 -0
  92. vkdispatch_core-0.0.32.dist-info/RECORD +95 -0
  93. vkdispatch_core-0.0.32.dist-info/WHEEL +5 -0
  94. vkdispatch_core-0.0.32.dist-info/licenses/LICENSE +201 -0
  95. vkdispatch_core-0.0.32.dist-info/top_level.txt +1 -0
vkdispatch/__init__.py ADDED
@@ -0,0 +1,49 @@
1
+ from .base.init import DeviceInfo
2
+ from .base.init import LogLevel
3
+ from .base.init import get_devices
4
+ from .base.init import get_backend
5
+ from .base.init import initialize
6
+ from .base.init import is_initialized
7
+ from .base.init import log, log_error, log_warning, log_info, log_verbose, set_log_level
8
+
9
+ from .base.dtype import dtype
10
+ from .base.dtype import float32, int32, uint32, complex64
11
+ from .base.dtype import vec2, vec3, vec4, ivec2, ivec3, ivec4, uvec2, uvec3, uvec4
12
+ from .base.dtype import mat2, mat3, mat4
13
+
14
+ from .base.context import get_context, queue_wait_idle, Signal
15
+ from .base.context import get_context_handle
16
+ from .base.context import make_context, select_queue_families, set_dummy_context_params
17
+ from .base.context import is_context_initialized
18
+
19
+ from .base.buffer import asbuffer
20
+ from .base.buffer import Buffer, buffer_u32, buffer_i32, buffer_f32, buffer_c64
21
+ from .base.buffer import asrfftbuffer
22
+ from .base.buffer import RFFTBuffer
23
+
24
+ from .base.image import image_format
25
+ from .base.image import image_type
26
+ from .base.image import image_view_type
27
+ from .base.image import Image
28
+ from .base.image import Image1D
29
+ from .base.image import Image2D
30
+ from .base.image import Image2DArray
31
+ from .base.image import Image3D
32
+ from .base.image import Sampler
33
+ from .base.image import Filter
34
+ from .base.image import AddressMode
35
+ from .base.image import BorderColor
36
+
37
+ from .execution_pipeline.command_graph import CommandGraph, BufferBindInfo, ImageBindInfo
38
+ from .execution_pipeline.command_graph import global_graph, set_global_graph, default_graph
39
+
40
+ from .shader.shader_function import ShaderFunction, ShaderSource
41
+ from .shader.context import ShaderContext, shader_context
42
+ from .shader.map import map, MappingFunction
43
+ from .shader.decorator import shader
44
+
45
+ import vkdispatch.vkfft as vkfft
46
+ import vkdispatch.fft as fft
47
+ import vkdispatch.reduce as reduce
48
+
49
# Package version string; matches the wheel metadata directory
# (vkdispatch_core-0.0.32.dist-info).
__version__ = "0.0.32"
vkdispatch/__main__.py ADDED
@@ -0,0 +1,4 @@
1
"""Entry point for ``python -m vkdispatch``: delegates to the package CLI."""

from . import cli


def _main() -> None:
    """Run the vkdispatch command-line interface."""
    cli.cli_entrypoint()


if __name__ == "__main__":
    _main()
vkdispatch/_compat/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """Compatibility helpers for optional runtime dependencies."""
2
+
vkdispatch/_compat/numpy_compat.py ADDED
@@ -0,0 +1,583 @@
1
+ from __future__ import annotations
2
+
3
+ import builtins
4
+ import cmath
5
+ import math
6
+ import struct
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any, Iterable, List, Sequence, Tuple
10
+
11
# numpy is optional: every numpy-backed helper below has a pure-Python
# fallback that is used when this import fails.
try:
    import numpy as _np
except Exception:  # pragma: no cover - intentionally broad for optional dependency import
    # NOTE(review): broad on purpose, presumably because a broken numpy install
    # can fail with errors other than ImportError — confirm.
    _np = None

# True when numpy was imported successfully; helpers in this module branch
# on this flag to choose the numpy implementation or the fallback.
HAS_NUMPY = _np is not None
# Module-level alias of math.pi (presumably mirroring ``numpy.pi`` for callers).
pi = math.pi
18
+
19
+
20
def require_numpy(feature_name: str) -> None:
    """Raise ``RuntimeError`` unless numpy was importable.

    Args:
        feature_name: Human-readable name of the feature that needs numpy;
            it is interpolated into the error message.

    Raises:
        RuntimeError: If numpy could not be imported when this module loaded.
    """
    if not HAS_NUMPY:
        message = (
            f"{feature_name} requires numpy, but numpy is not available. "
            "Install numpy or use the bytes-based API."
        )
        raise RuntimeError(message)


def numpy_module():
    """Return the imported numpy module, or ``None`` when numpy is unavailable."""
    return _np
32
+
33
+
34
def prod(values: Iterable[int]) -> int:
    """Return the product of *values* as an exact Python ``int``.

    Each element is converted with ``int()`` before multiplying. An empty
    iterable yields ``1``.

    This deliberately does not delegate to ``numpy.prod``. numpy reduces
    Python ints in a fixed-width platform integer accumulator and silently
    wraps on overflow, so large shape/size products could come back wrong.
    Python ints are arbitrary precision, so this result is always exact and
    identical whether or not numpy is installed.
    """
    result = 1
    for value in values:
        result *= int(value)
    return result
44
+
45
+
46
def ceil(value: float) -> float:
    """Round *value* up to the nearest integer, returned as a ``float``."""
    rounded = _np.ceil(value) if HAS_NUMPY else math.ceil(value)
    return float(rounded)


def floor(value: float) -> float:
    """Round *value* down to the nearest integer, returned as a ``float``."""
    rounded = _np.floor(value) if HAS_NUMPY else math.floor(value)
    return float(rounded)


def trunc(value: float) -> float:
    """Drop the fractional part of *value* (round toward zero) as a ``float``."""
    rounded = _np.trunc(value) if HAS_NUMPY else math.trunc(value)
    return float(rounded)


def round(value: float) -> float:
    """Round *value* to the nearest integer, returned as a ``float``.

    This name intentionally shadows the builtin within this module, so the
    fallback reaches the real builtin through ``builtins.round``.
    """
    rounded = _np.round(value) if HAS_NUMPY else builtins.round(value)
    return float(rounded)
68
+
69
+
70
def sign(value: float) -> float:
    """Return the sign of *value*: ``1.0``, ``-1.0`` or ``0.0``.

    NaN input yields NaN on both code paths, matching ``numpy.sign``.
    """
    if HAS_NUMPY:
        return float(_np.sign(value))

    if value > 0:
        return 1.0
    if value < 0:
        return -1.0
    # Every comparison with NaN is false, so NaN would otherwise fall through
    # to 0.0; NaN is the only value that compares unequal to itself.
    if value != value:
        return math.nan
    return 0.0
79
+
80
+
81
def abs_value(value: Any) -> float:
    """Return the magnitude of *value* as a float (the modulus for complex input)."""
    magnitude = _np.abs(value) if HAS_NUMPY else abs(value)
    return float(magnitude)
85
+
86
+
87
def minimum(x: float, y: float) -> float:
    """Return the smaller of *x* and *y* as a float.

    Like ``numpy.minimum``, NaN propagates: if either argument is NaN, the
    result is NaN regardless of argument order.
    """
    if HAS_NUMPY:
        return float(_np.minimum(x, y))
    # NaN is the only value unequal to itself; check both sides explicitly
    # because the plain comparison below is order-dependent for NaN.
    if x != x:
        return float(x)
    if y != y:
        return float(y)
    return float(x if x <= y else y)


def maximum(x: float, y: float) -> float:
    """Return the larger of *x* and *y* as a float.

    Like ``numpy.maximum``, NaN propagates: if either argument is NaN, the
    result is NaN regardless of argument order.
    """
    if HAS_NUMPY:
        return float(_np.maximum(x, y))
    if x != x:
        return float(x)
    if y != y:
        return float(y)
    return float(x if x >= y else y)
97
+
98
+
99
def clip(x: float, min_value: float, max_value: float) -> float:
    """Clamp *x* into ``[min_value, max_value]`` and return it as a float."""
    if not HAS_NUMPY:
        return float(min(max(x, min_value), max_value))
    return float(_np.clip(x, min_value, max_value))


def mod(x: float, y: float) -> float:
    """Return ``x`` modulo ``y`` as a float (Python ``%`` semantics)."""
    if not HAS_NUMPY:
        return float(x % y)
    return float(_np.mod(x, y))


def modf(x: float, _unused: Any = None) -> Tuple[float, float]:
    """Split *x* into ``(fractional, integral)`` parts, both as floats.

    The second parameter is accepted and ignored.
    NOTE(review): presumably it mirrors numpy's optional ``out`` argument — confirm.
    """
    splitter = _np.modf if HAS_NUMPY else math.modf
    fractional, integral = splitter(x)
    return float(fractional), float(integral)
118
+
119
+
120
def interp(x: float, xp: Sequence[float], fp: Sequence[float]) -> float:
    """Linearly interpolate the single point *x* through ``(xp, fp)``.

    Inputs outside ``[xp[0], xp[-1]]`` clamp to the end values of *fp*. The
    fallback assumes *xp* is ascending and scans it linearly.

    Raises:
        ValueError: On the fallback path, if *xp* and *fp* differ in length
            or are empty.
    """
    if HAS_NUMPY:
        return float(_np.interp(x, xp, fp))

    count = len(xp)
    if count != len(fp):
        raise ValueError("xp and fp must have the same length")
    if not count:
        raise ValueError("xp and fp must be non-empty")
    if count == 1:
        return float(fp[0])

    # Clamp to the end points before searching for the bracketing segment.
    if x <= xp[0]:
        return float(fp[0])
    if x >= xp[-1]:
        return float(fp[-1])

    for upper in range(1, count):
        if x <= xp[upper]:
            lower = upper - 1
            x_lo, x_hi = xp[lower], xp[upper]
            y_lo, y_hi = fp[lower], fp[upper]
            # A zero-width segment cannot be interpolated; use its left value.
            if x_hi == x_lo:
                return float(y_lo)
            fraction = (x - x_lo) / (x_hi - x_lo)
            return float(y_lo + fraction * (y_hi - y_lo))

    return float(fp[-1])
150
+
151
+
152
def isnan(value: float) -> bool:
    """Return ``True`` if *value* is NaN."""
    checker = _np.isnan if HAS_NUMPY else math.isnan
    return bool(checker(value))


def isinf(value: float) -> bool:
    """Return ``True`` if *value* is positive or negative infinity."""
    checker = _np.isinf if HAS_NUMPY else math.isinf
    return bool(checker(value))
162
+
163
+
164
def power(x: float, y: float) -> float:
    """Return *x* raised to the power *y* as a float.

    Without numpy this uses ``math.pow``, which raises ``ValueError`` on
    domain errors such as a negative base with a fractional exponent.
    """
    pow_fn = _np.power if HAS_NUMPY else math.pow
    return float(pow_fn(x, y))


def exp(value: float) -> float:
    """Return ``e ** value`` as a float."""
    exp_fn = _np.exp if HAS_NUMPY else math.exp
    return float(exp_fn(value))


def exp2(value: float) -> float:
    """Return ``2 ** value`` as a float.

    Without numpy, ``math.exp2`` is used where the interpreter provides it,
    otherwise ``math.pow(2.0, value)``.
    """
    if HAS_NUMPY:
        return float(_np.exp2(value))
    native_exp2 = getattr(math, "exp2", None)
    if native_exp2 is not None:
        return float(native_exp2(value))
    return float(math.pow(2.0, value))


def log(value: float) -> float:
    """Return the natural logarithm of *value* as a float."""
    log_fn = _np.log if HAS_NUMPY else math.log
    return float(log_fn(value))


def log2(value: float) -> float:
    """Return the base-2 logarithm of *value* as a float."""
    log_fn = _np.log2 if HAS_NUMPY else math.log2
    return float(log_fn(value))


def sqrt(value: float) -> float:
    """Return the square root of *value* as a float."""
    root_fn = _np.sqrt if HAS_NUMPY else math.sqrt
    return float(root_fn(value))
200
+
201
+
202
def sin(value: float) -> float:
    """Return the sine of *value* (radians) as a float."""
    impl = _np.sin if HAS_NUMPY else math.sin
    return float(impl(value))


def cos(value: float) -> float:
    """Return the cosine of *value* (radians) as a float."""
    impl = _np.cos if HAS_NUMPY else math.cos
    return float(impl(value))


def tan(value: float) -> float:
    """Return the tangent of *value* (radians) as a float."""
    impl = _np.tan if HAS_NUMPY else math.tan
    return float(impl(value))


def arcsin(value: float) -> float:
    """Return the arc sine of *value*, in radians."""
    impl = _np.arcsin if HAS_NUMPY else math.asin
    return float(impl(value))


def arccos(value: float) -> float:
    """Return the arc cosine of *value*, in radians."""
    impl = _np.arccos if HAS_NUMPY else math.acos
    return float(impl(value))


def arctan(value: float) -> float:
    """Return the arc tangent of *value*, in radians."""
    impl = _np.arctan if HAS_NUMPY else math.atan
    return float(impl(value))


def arctan2(y: float, x: float) -> float:
    """Return the quadrant-aware arc tangent of ``y / x``, in radians."""
    impl = _np.arctan2 if HAS_NUMPY else math.atan2
    return float(impl(y, x))
242
+
243
+
244
def sinh(value: float) -> float:
    """Return the hyperbolic sine of *value* as a float."""
    impl = _np.sinh if HAS_NUMPY else math.sinh
    return float(impl(value))


def cosh(value: float) -> float:
    """Return the hyperbolic cosine of *value* as a float."""
    impl = _np.cosh if HAS_NUMPY else math.cosh
    return float(impl(value))


def tanh(value: float) -> float:
    """Return the hyperbolic tangent of *value* as a float."""
    impl = _np.tanh if HAS_NUMPY else math.tanh
    return float(impl(value))


def arcsinh(value: float) -> float:
    """Return the inverse hyperbolic sine of *value* as a float."""
    impl = _np.arcsinh if HAS_NUMPY else math.asinh
    return float(impl(value))


def arccosh(value: float) -> float:
    """Return the inverse hyperbolic cosine of *value* as a float."""
    impl = _np.arccosh if HAS_NUMPY else math.acosh
    return float(impl(value))


def arctanh(value: float) -> float:
    """Return the inverse hyperbolic tangent of *value* as a float."""
    impl = _np.arctanh if HAS_NUMPY else math.atanh
    return float(impl(value))
278
+
279
+
280
def dot(x: Any, y: Any) -> float:
    """Return the dot product of *x* and *y* as a float.

    Two scalars are simply multiplied. Otherwise the inputs are treated as
    sequences; on the fallback path ``zip`` truncates to the shorter one.
    """
    if HAS_NUMPY:
        return float(_np.dot(x, y))

    scalar_types = (int, float, complex)
    if isinstance(x, scalar_types) and isinstance(y, scalar_types):
        return float(x * y)

    total = sum(a * b for a, b in zip(x, y))
    return float(total)


def angle(value: complex) -> float:
    """Return the phase (argument) of *value*, in radians."""
    phase_fn = _np.angle if HAS_NUMPY else cmath.phase
    return float(phase_fn(value))


def exp_complex(value: complex) -> complex:
    """Return the complex exponential ``e ** value``."""
    if not HAS_NUMPY:
        return cmath.exp(value)
    return complex(_np.exp(value))
300
+
301
+
302
def is_numpy_integer_scalar(value: Any) -> bool:
    """Return ``True`` if numpy is available and ``type(value)`` is an integer dtype.

    Always ``False`` without numpy.

    NOTE(review): ``numpy.issubdtype`` maps the builtin ``int`` type to numpy's
    default integer, so plain Python ints appear to satisfy this check too
    when numpy is present — confirm callers expect that.
    """
    if not HAS_NUMPY:
        return False
    return bool(_np.issubdtype(type(value), _np.integer))


def is_integer_scalar(value: Any) -> bool:
    """Return ``True`` for Python ``int`` (including ``bool``) or numpy integer scalars."""
    if isinstance(value, int):
        return True
    return is_numpy_integer_scalar(value)


def is_numpy_floating_instance(value: Any) -> bool:
    """Return ``True`` if *value* is a numpy floating-point scalar."""
    if not HAS_NUMPY:
        return False
    return isinstance(value, _np.floating)
312
+
313
+
314
@dataclass(frozen=True)
class HostDType:
    """Immutable, numpy-free description of a supported element type.

    Fields are the dtype name, the size of one element in bytes, the
    :mod:`struct` format code(s) for one element, and a coarse kind label
    (``"int"``, ``"uint"``, ``"float"`` or ``"complex"``).
    """

    name: str
    itemsize: int
    struct_format: str
    kind: str


INT32 = HostDType(name="int32", itemsize=4, struct_format="i", kind="int")
UINT32 = HostDType(name="uint32", itemsize=4, struct_format="I", kind="uint")
FLOAT32 = HostDType(name="float32", itemsize=4, struct_format="f", kind="float")
# A complex64 element is two float32 values: real part, then imaginary part.
COMPLEX64 = HostDType(name="complex64", itemsize=8, struct_format="ff", kind="complex")

# Lookup table from dtype name to its descriptor.
_HOST_DTYPES = {spec.name: spec for spec in (INT32, UINT32, FLOAT32, COMPLEX64)}
333
+
334
+
335
def host_dtype(name: str) -> HostDType:
    """Return the :class:`HostDType` registered under *name*.

    Raises:
        ValueError: If *name* is not one of the supported dtype names.
    """
    spec = _HOST_DTYPES.get(name)
    if spec is None:
        raise ValueError(f"Unsupported dtype ({name})!")
    return spec


def is_host_dtype(value: Any) -> bool:
    """Return ``True`` if *value* is a :class:`HostDType` instance."""
    return isinstance(value, HostDType)


def host_dtype_name(dtype: Any) -> str:
    """Return the name of *dtype*.

    Accepts a :class:`HostDType`, a plain string (returned unchanged and
    unvalidated) or, when numpy is present, anything ``numpy.dtype`` accepts.

    Raises:
        ValueError: If *dtype* is neither a HostDType nor a string and numpy
            is unavailable.
    """
    if isinstance(dtype, HostDType):
        return dtype.name
    if isinstance(dtype, str):
        return dtype
    if not HAS_NUMPY:
        raise ValueError(f"Unsupported dtype ({dtype})!")
    return str(_np.dtype(dtype).name)
356
+
357
+
358
def dtype_itemsize(dtype: Any) -> int:
    """Return the size in bytes of one element of *dtype*."""
    if isinstance(dtype, HostDType):
        return dtype.itemsize
    if not HAS_NUMPY:
        return host_dtype(host_dtype_name(dtype)).itemsize
    return int(_np.dtype(dtype).itemsize)


def dtype_kind(dtype: Any) -> str:
    """Classify *dtype* as ``"complex"``, ``"uint"``, ``"int"`` or ``"float"``.

    With numpy, the dtype's numpy type hierarchy is consulted first. Any
    other dtype (or any dtype without numpy) is resolved through the
    host-dtype table.

    Raises:
        ValueError: If the dtype cannot be resolved to a supported host dtype.
    """
    if isinstance(dtype, HostDType):
        return dtype.kind

    if HAS_NUMPY:
        resolved = _np.dtype(dtype)
        # Unsigned must be tested before the generic integer category,
        # because unsigned integers are also numpy integers.
        for numpy_category, label in (
            (_np.complexfloating, "complex"),
            (_np.unsignedinteger, "uint"),
            (_np.integer, "int"),
            (_np.floating, "float"),
        ):
            if _np.issubdtype(resolved, numpy_category):
                return label

    return host_dtype(host_dtype_name(dtype)).kind


def dtype_struct_format(dtype: Any) -> str:
    """Return the :mod:`struct` format code(s) for one element of *dtype*."""
    if not isinstance(dtype, HostDType):
        dtype = host_dtype(host_dtype_name(dtype))
    return dtype.struct_format
390
+
391
+
392
class CompatArray:
    """Numpy-free stand-in for the arrays produced by ``from_buffer``.

    Holds an immutable copy of the raw element bytes together with a
    :class:`HostDType` and a shape. Only reshaping and conversion back to
    bytes are supported.
    """

    def __init__(self, buffer: bytes, dtype: HostDType, shape: Tuple[int, ...]):
        """Wrap *buffer* as elements of *dtype* laid out with *shape*.

        *shape* may contain one ``-1`` entry, which is inferred from the
        buffer length, as numpy's ``reshape`` does.

        Raises:
            ValueError: If the buffer length is not a multiple of
                ``dtype.itemsize``, or *shape* does not describe exactly the
                number of elements in the buffer. The numpy path of
                ``from_buffer`` (``frombuffer(...).reshape(shape)``) rejects
                the same inputs.
        """
        self._buffer = bytes(buffer)
        self.dtype = dtype

        if len(self._buffer) % dtype.itemsize != 0:
            raise ValueError("buffer size must be a multiple of element size")
        element_count = len(self._buffer) // dtype.itemsize

        self.shape = self._resolve_shape(tuple(shape), element_count)
        self.size = prod(self.shape)

    @staticmethod
    def _resolve_shape(dims: Tuple[int, ...], element_count: int) -> Tuple[int, ...]:
        """Return *dims* as ints, with a single ``-1`` entry filled in.

        Raises:
            ValueError: On more than one ``-1``, any other negative extent,
                or when the shape does not hold exactly *element_count*
                elements.
        """
        dims = tuple(int(extent) for extent in dims)

        unknown_axes = [axis for axis, extent in enumerate(dims) if extent == -1]
        if len(unknown_axes) > 1:
            raise ValueError("can only specify one unknown dimension")
        # Reject other negatives up front: without this, e.g. (-2, -3) would
        # pass the element-count check for a 6-element array.
        if any(extent < -1 for extent in dims):
            raise ValueError(f"negative dimensions not allowed ({dims})")

        if unknown_axes:
            known = prod(extent for extent in dims if extent != -1)
            if known == 0 or element_count % known != 0:
                raise ValueError("Cannot reshape array with mismatched element count")
            axis = unknown_axes[0]
            dims = dims[:axis] + (element_count // known,) + dims[axis + 1:]

        if prod(dims) != element_count:
            raise ValueError("Cannot reshape array with mismatched element count")
        return dims

    def reshape(self, shape: Any, *more_dims: int) -> "CompatArray":
        """Return a new CompatArray over the same bytes with a different shape.

        Accepts a shape tuple, a single int, or the dimensions as separate
        positional arguments (``a.reshape(2, 3)``). One dimension may be
        ``-1``.

        Raises:
            ValueError: If the new shape does not describe ``self.size``
                elements.
        """
        if more_dims:
            dims = (shape,) + tuple(more_dims)
        elif isinstance(shape, int):
            dims = (shape,)
        else:
            dims = tuple(shape)
        return CompatArray(self._buffer, self.dtype, self._resolve_shape(dims, self.size))

    def tobytes(self) -> bytes:
        """Return the raw element bytes."""
        return bytes(self._buffer)

    @property
    def nbytes(self) -> int:
        """Total size of the element data in bytes."""
        return len(self._buffer)

    def __repr__(self) -> str:
        return f"CompatArray(shape={self.shape}, dtype={self.dtype.name}, nbytes={len(self._buffer)})"
414
+
415
+
416
def is_array_like(value: Any) -> bool:
    """Return ``True`` for :class:`CompatArray` or, with numpy, ``numpy.ndarray``."""
    if isinstance(value, CompatArray):
        return True
    return bool(HAS_NUMPY and isinstance(value, _np.ndarray))


def array_shape(value: Any) -> Tuple[int, ...]:
    """Return the shape of an array-like *value* as a tuple.

    Raises:
        TypeError: If *value* is not array-like.
    """
    if not is_array_like(value):
        raise TypeError(f"Unsupported array-like value ({type(value)})")
    return tuple(value.shape)


def array_dtype(value: Any) -> Any:
    """Return the dtype of an array-like *value* (numpy dtype or HostDType).

    Raises:
        TypeError: If *value* is not array-like.
    """
    if not is_array_like(value):
        raise TypeError(f"Unsupported array-like value ({type(value)})")
    return value.dtype


def array_nbytes(value: Any) -> int:
    """Return the size of an array-like *value*'s element data in bytes.

    Raises:
        TypeError: If *value* is not array-like.
    """
    if isinstance(value, CompatArray):
        return value.nbytes
    if HAS_NUMPY and isinstance(value, _np.ndarray):
        return int(value.size * value.dtype.itemsize)
    raise TypeError(f"Unsupported array-like value ({type(value)})")


def as_contiguous_bytes(value: Any) -> bytes:
    """Return the element data of an array-like *value* as C-ordered bytes.

    Raises:
        TypeError: If *value* is not array-like.
    """
    if isinstance(value, CompatArray):
        return value.tobytes()
    if HAS_NUMPY and isinstance(value, _np.ndarray):
        return _np.ascontiguousarray(value).tobytes()
    raise TypeError(f"Unsupported array-like value ({type(value)})")
452
+
453
+
454
def from_buffer(buffer: bytes, dtype: Any, shape: Tuple[int, ...]):
    """Interpret *buffer* as an array of *dtype* elements with *shape*.

    Returns a numpy array built with ``numpy.frombuffer`` when numpy is
    available, otherwise a :class:`CompatArray`.
    """
    name = host_dtype_name(dtype)
    if not HAS_NUMPY:
        return CompatArray(buffer, host_dtype(name), tuple(shape))
    flat = _np.frombuffer(buffer, dtype=_np.dtype(name))
    return flat.reshape(shape)
461
+
462
+
463
def ensure_bytes(value: Any) -> bytes:
    """Return *value* as an immutable ``bytes`` object.

    ``bytes`` input is returned unchanged (the same object). ``bytearray`` and
    ``memoryview`` contents are copied into a new ``bytes``.

    Raises:
        TypeError: For any other type.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, memoryview):
        return value.tobytes()
    if isinstance(value, bytearray):
        return bytes(value)
    raise TypeError(f"Unsupported bytes-like object ({type(value)})")


def is_bytes_like(value: Any) -> bool:
    """Return ``True`` if *value* is ``bytes``, ``bytearray`` or ``memoryview``."""
    return any(isinstance(value, kind) for kind in (bytes, bytearray, memoryview))
475
+
476
+
477
def flatten(value: Any) -> List[Any]:
    """Flatten *value* into a flat Python list.

    :class:`CompatArray` and numpy arrays are unpacked element by element.
    Lists and tuples are flattened recursively. Any other value becomes a
    one-element list.
    """
    if isinstance(value, CompatArray):
        return unpack_values(value.tobytes(), value.dtype)
    if HAS_NUMPY and isinstance(value, _np.ndarray):
        return value.reshape(-1).tolist()
    if not isinstance(value, (list, tuple)):
        return [value]
    return [leaf for element in value for leaf in flatten(element)]
491
+
492
+
493
def _coerce_scalar(value: Any, dtype: Any):
    """Convert *value* to the Python scalar type matching *dtype*'s kind.

    Complex kinds also accept a two-element ``(real, imag)`` list or tuple.

    Raises:
        ValueError: For a complex pair of the wrong length, or an unknown
            dtype kind.
    """
    kind = dtype_kind(dtype)

    if kind in ("int", "uint"):
        return int(value)
    if kind == "float":
        return float(value)
    if kind != "complex":
        raise ValueError(f"Unsupported dtype kind ({kind})")

    if isinstance(value, complex):
        return value
    if not isinstance(value, (list, tuple)):
        return complex(value)
    if len(value) != 2:
        raise ValueError("Complex values must be complex scalars or pairs")
    real_part, imag_part = value
    return complex(float(real_part), float(imag_part))
512
+
513
+
514
def pack_values(values: Sequence[Any], dtype: Any) -> bytes:
    """Serialize *values* into binary element data of *dtype*.

    With numpy the values go through ``numpy.asarray``. Otherwise each value
    is coerced and packed with :mod:`struct` in native byte order with
    standard sizes (``"="``).

    NOTE(review): the two paths can disagree for complex data given as
    ``(real, imag)`` pairs. The fallback coerces each pair into one complex
    number, whereas ``numpy.asarray`` would treat each pair as two separate
    elements — confirm which behavior callers rely on.
    """
    items = list(values)
    name = host_dtype_name(dtype)

    if HAS_NUMPY:
        return _np.asarray(items, dtype=_np.dtype(name)).tobytes()

    spec = host_dtype(name)
    if spec.kind == "complex":
        chunks = []
        for item in items:
            number = _coerce_scalar(item, spec)
            chunks.append(struct.pack("=ff", float(number.real), float(number.imag)))
    else:
        element_format = "=" + spec.struct_format
        chunks = [struct.pack(element_format, _coerce_scalar(item, spec)) for item in items]
    return b"".join(chunks)
536
+
537
+
538
def unpack_values(data: bytes, dtype: Any) -> List[Any]:
    """Decode binary element data into a list of Python scalars of *dtype*.

    Complex data becomes ``complex`` values built from consecutive
    ``(real, imag)`` float32 pairs. The fallback decodes with :mod:`struct`
    in native byte order with standard sizes (``"="``).

    Raises:
        struct.error: On the fallback path, if ``len(data)`` is not a
            multiple of the element size.
    """
    dtype_name = host_dtype_name(dtype)

    if HAS_NUMPY:
        return _np.frombuffer(data, dtype=_np.dtype(dtype_name)).tolist()

    host = host_dtype(dtype_name)
    element_format = "=" + host.struct_format

    # struct.iter_unpack decodes in one pass without slicing a copy of every
    # element out of *data*; the same call serves both branches.
    records = struct.iter_unpack(element_format, data)
    if host.kind == "complex":
        return [complex(real, imag) for real, imag in records]
    return [fields[0] for fields in records]
560
+
561
+
562
def float_bits_to_int(value: float) -> int:
    """Reinterpret the float32 bit pattern of *value* as a signed 32-bit int."""
    if not HAS_NUMPY:
        packed = struct.pack("=f", float(value))
        (bits,) = struct.unpack("=i", packed)
        return int(bits)
    raw = _np.float32(value).tobytes()
    return int(_np.frombuffer(raw, dtype=_np.int32)[0])


def float_bits_to_uint(value: float) -> int:
    """Reinterpret the float32 bit pattern of *value* as an unsigned 32-bit int."""
    if not HAS_NUMPY:
        packed = struct.pack("=f", float(value))
        (bits,) = struct.unpack("=I", packed)
        return int(bits)
    raw = _np.float32(value).tobytes()
    return int(_np.frombuffer(raw, dtype=_np.uint32)[0])


def int_bits_to_float(value: int) -> float:
    """Reinterpret a signed 32-bit int bit pattern as a float32 value."""
    if not HAS_NUMPY:
        packed = struct.pack("=i", int(value))
        (result,) = struct.unpack("=f", packed)
        return float(result)
    raw = _np.int32(value).tobytes()
    return float(_np.frombuffer(raw, dtype=_np.float32)[0])


def uint_bits_to_float(value: int) -> float:
    """Reinterpret an unsigned 32-bit int bit pattern as a float32 value."""
    if not HAS_NUMPY:
        packed = struct.pack("=I", int(value))
        (result,) = struct.unpack("=f", packed)
        return float(result)
    raw = _np.uint32(value).tobytes()
    return float(_np.frombuffer(raw, dtype=_np.float32)[0])
vkdispatch/backends/__init__.py ADDED
@@ -0,0 +1 @@
1
# The backends package exports no names via ``from ... import *``.
__all__ = []