numba-cuda 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cudadecl.py +0 -268
  14. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  15. numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
  16. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  17. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  18. numba_cuda/numba/cuda/debuginfo.py +469 -3
  19. numba_cuda/numba/cuda/device_init.py +0 -1
  20. numba_cuda/numba/cuda/dispatcher.py +309 -11
  21. numba_cuda/numba/cuda/extending.py +2 -1
  22. numba_cuda/numba/cuda/fp16.py +348 -0
  23. numba_cuda/numba/cuda/intrinsics.py +1 -1
  24. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  25. numba_cuda/numba/cuda/lowering.py +1833 -8
  26. numba_cuda/numba/cuda/mathimpl.py +2 -90
  27. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  28. numba_cuda/numba/cuda/printimpl.py +2 -1
  29. numba_cuda/numba/cuda/serialize.py +264 -0
  30. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  31. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  32. numba_cuda/numba/cuda/stubs.py +0 -308
  33. numba_cuda/numba/cuda/target.py +13 -5
  34. numba_cuda/numba/cuda/testing.py +156 -5
  35. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  36. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  37. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  38. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  39. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  40. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  41. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  42. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  43. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  44. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  45. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  46. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  47. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
  48. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  49. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  50. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  51. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  52. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  53. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  54. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  55. numba_cuda/numba/cuda/utils.py +785 -0
  56. numba_cuda/numba/cuda/vector_types.py +1 -1
  57. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
  58. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +61 -48
  59. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  60. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
  61. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
  62. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
1
+ import numba.core.types as types
2
+ from numba.cuda._internal.cuda_fp16 import (
3
+ typing_registry,
4
+ target_registry,
5
+ __half,
6
+ __double2half,
7
+ __float2half,
8
+ __float2half_rd,
9
+ __float2half_rn,
10
+ __float2half_ru,
11
+ __float2half_rz,
12
+ __int2half_rd,
13
+ __int2half_rn,
14
+ __int2half_ru,
15
+ __int2half_rz,
16
+ __ll2half_rd,
17
+ __ll2half_rn,
18
+ __ll2half_ru,
19
+ __ll2half_rz,
20
+ __short2half_rd,
21
+ __short2half_rn,
22
+ __short2half_ru,
23
+ __short2half_rz,
24
+ __uint2half_rd,
25
+ __uint2half_rn,
26
+ __uint2half_ru,
27
+ __uint2half_rz,
28
+ __ull2half_rd,
29
+ __ull2half_rn,
30
+ __ull2half_ru,
31
+ __ull2half_rz,
32
+ __ushort2half_rd,
33
+ __ushort2half_rn,
34
+ __ushort2half_ru,
35
+ __ushort2half_rz,
36
+ __half2char_rz,
37
+ __half2float,
38
+ __half2int_rd,
39
+ __half2int_rn,
40
+ __half2int_ru,
41
+ __half2int_rz,
42
+ __half2ll_rd,
43
+ __half2ll_rn,
44
+ __half2ll_ru,
45
+ __half2ll_rz,
46
+ __half2short_rd,
47
+ __half2short_rn,
48
+ __half2short_ru,
49
+ __half2short_rz,
50
+ __half2uchar_rz,
51
+ __half2uint_rd,
52
+ __half2uint_rn,
53
+ __half2uint_ru,
54
+ __half2uint_rz,
55
+ __half2ull_rd,
56
+ __half2ull_rn,
57
+ __half2ull_ru,
58
+ __half2ull_rz,
59
+ __half2ushort_rd,
60
+ __half2ushort_rn,
61
+ __half2ushort_ru,
62
+ __half2ushort_rz,
63
+ __short_as_half,
64
+ __ushort_as_half,
65
+ __half_as_short,
66
+ __half_as_ushort,
67
+ __habs as habs,
68
+ __habs,
69
+ __hadd as hadd,
70
+ __hadd,
71
+ __hadd_rn,
72
+ __hadd_sat,
73
+ __hcmadd,
74
+ __hdiv as hdiv,
75
+ __hdiv,
76
+ __heq as heq,
77
+ __heq,
78
+ __hequ,
79
+ __hfma as hfma,
80
+ __hfma,
81
+ __hfma_relu,
82
+ __hfma_sat,
83
+ __hge as hge,
84
+ __hge,
85
+ __hgeu,
86
+ __hgt as hgt,
87
+ __hgt,
88
+ __hgtu,
89
+ __hisinf,
90
+ __hisnan,
91
+ __hle as hle,
92
+ __hle,
93
+ __hleu,
94
+ __hlt as hlt,
95
+ __hlt,
96
+ __hltu,
97
+ __hmax as hmax,
98
+ __hmax,
99
+ __hmax_nan,
100
+ __hmin as hmin,
101
+ __hmin,
102
+ __hmin_nan,
103
+ __hmul as hmul,
104
+ __hmul,
105
+ __hmul_rn,
106
+ __hmul_sat,
107
+ __hne as hne,
108
+ __hne,
109
+ __hneg as hneg,
110
+ __hneg,
111
+ __hneu,
112
+ __hsub as hsub,
113
+ __hsub,
114
+ __hsub_rn,
115
+ __hsub_sat,
116
+ atomicAdd,
117
+ hceil,
118
+ hcos,
119
+ hexp,
120
+ hexp10,
121
+ hexp2,
122
+ hfloor,
123
+ hlog,
124
+ hlog10,
125
+ hlog2,
126
+ hrcp,
127
+ hrint,
128
+ hrsqrt,
129
+ hsin,
130
+ hsqrt,
131
+ htanh,
132
+ htanh_approx,
133
+ htrunc,
134
+ )
135
+
136
+ from numba.extending import overload
137
+ import math
138
+
139
+
140
def _make_unary(a, func):
    """Build a float16 implementation of a unary math function, if applicable.

    Shared helper for the ``@overload`` typing functions in this module.

    Parameters
    ----------
    a : numba type
        The type of the single argument passed to the overloaded function.
    func : callable
        The half-precision binding to forward the argument to.

    Returns
    -------
    callable or None
        A one-argument implementation calling ``func`` when ``a`` is a
        16-bit float type. Otherwise ``None``, i.e. no implementation for
        this argument type.
    """
    is_half = isinstance(a, types.Float) and a.bitwidth == 16
    if not is_half:
        return None
    # Keep the parameter named ``a`` so it mirrors the overload signature
    # (numba's @overload checks that the two signatures agree).
    return lambda a: func(a)
143
+
144
+
145
# Bind the low-level half-precision bindings imported above to the corresponding math module APIs.
146
@overload(math.trunc, target="cuda")
def trunc_ol(a):
    """Overload ``math.trunc`` on the CUDA target.

    float16 arguments dispatch to ``htrunc``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, htrunc)
    return impl
149
+
150
+
151
@overload(math.ceil, target="cuda")
def ceil_ol(a):
    """Overload ``math.ceil`` on the CUDA target.

    float16 arguments dispatch to ``hceil``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hceil)
    return impl
154
+
155
+
156
@overload(math.floor, target="cuda")
def floor_ol(a):
    """Overload ``math.floor`` on the CUDA target.

    float16 arguments dispatch to ``hfloor``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hfloor)
    return impl
159
+
160
+
161
@overload(math.fabs, target="cuda")
def fabs_ol(a):
    """Overload ``math.fabs`` on the CUDA target.

    float16 arguments dispatch to ``habs`` (an alias of ``__habs``); other
    argument types get no implementation from this overload.
    """
    impl = _make_unary(a, habs)
    return impl
164
+
165
+
166
@overload(math.sqrt, target="cuda")
def sqrt_ol(a):
    """Overload ``math.sqrt`` on the CUDA target.

    float16 arguments dispatch to ``hsqrt``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hsqrt)
    return impl
169
+
170
+
171
@overload(math.log, target="cuda")
def log_ol(a):
    """Overload ``math.log`` on the CUDA target.

    float16 arguments dispatch to ``hlog``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hlog)
    return impl
174
+
175
+
176
@overload(math.log2, target="cuda")
def log2_ol(a):
    """Overload ``math.log2`` on the CUDA target.

    float16 arguments dispatch to ``hlog2``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hlog2)
    return impl
179
+
180
+
181
@overload(math.log10, target="cuda")
def log10_ol(a):
    """Overload ``math.log10`` on the CUDA target.

    float16 arguments dispatch to ``hlog10``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hlog10)
    return impl
184
+
185
+
186
@overload(math.exp, target="cuda")
def exp_ol(a):
    """Overload ``math.exp`` on the CUDA target.

    float16 arguments dispatch to ``hexp``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hexp)
    return impl
189
+
190
+
191
@overload(math.tanh, target="cuda")
def tanh_ol(a):
    """Overload ``math.tanh`` on the CUDA target.

    float16 arguments dispatch to ``htanh`` (not ``htanh_approx``); other
    argument types get no implementation from this overload.
    """
    impl = _make_unary(a, htanh)
    return impl
194
+
195
+
196
@overload(math.cos, target="cuda")
def cos_ol(a):
    """Overload ``math.cos`` on the CUDA target.

    float16 arguments dispatch to ``hcos``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hcos)
    return impl
199
+
200
+
201
@overload(math.sin, target="cuda")
def sin_ol(a):
    """Overload ``math.sin`` on the CUDA target.

    float16 arguments dispatch to ``hsin``; other argument types get no
    implementation from this overload.
    """
    impl = _make_unary(a, hsin)
    return impl
204
+
205
+
206
try:
    # math.exp2 only exists on Python 3.11+; skip the binding when it is absent.
    from math import exp2
except ImportError:
    pass
else:
    # Register outside the ``try`` body so that an ImportError raised while
    # registering the overload is not mistaken for a missing math.exp2.

    @overload(exp2, target="cuda")
    def exp2_ol(a):
        """Overload ``math.exp2`` on the CUDA target.

        float16 arguments dispatch to ``hexp2``; other argument types get
        no implementation from this overload.
        """
        return _make_unary(a, hexp2)
214
+
215
+
216
# Public API of this module: every name imported from the generated
# ``cuda_fp16`` bindings, grouped by kind. Order matches the import list.
__all__ = (
    # Registries and the half-precision type itself.
    ["typing_registry", "target_registry", "__half"]
    # Conversions *to* half precision.
    + [
        "__double2half",
        "__float2half",
        "__float2half_rd",
        "__float2half_rn",
        "__float2half_ru",
        "__float2half_rz",
        "__int2half_rd",
        "__int2half_rn",
        "__int2half_ru",
        "__int2half_rz",
        "__ll2half_rd",
        "__ll2half_rn",
        "__ll2half_ru",
        "__ll2half_rz",
        "__short2half_rd",
        "__short2half_rn",
        "__short2half_ru",
        "__short2half_rz",
        "__uint2half_rd",
        "__uint2half_rn",
        "__uint2half_ru",
        "__uint2half_rz",
        "__ull2half_rd",
        "__ull2half_rn",
        "__ull2half_ru",
        "__ull2half_rz",
        "__ushort2half_rd",
        "__ushort2half_rn",
        "__ushort2half_ru",
        "__ushort2half_rz",
    ]
    # Conversions *from* half precision.
    + [
        "__half2char_rz",
        "__half2float",
        "__half2int_rd",
        "__half2int_rn",
        "__half2int_ru",
        "__half2int_rz",
        "__half2ll_rd",
        "__half2ll_rn",
        "__half2ll_ru",
        "__half2ll_rz",
        "__half2short_rd",
        "__half2short_rn",
        "__half2short_ru",
        "__half2short_rz",
        "__half2uchar_rz",
        "__half2uint_rd",
        "__half2uint_rn",
        "__half2uint_ru",
        "__half2uint_rz",
        "__half2ull_rd",
        "__half2ull_rn",
        "__half2ull_ru",
        "__half2ull_rz",
        "__half2ushort_rd",
        "__half2ushort_rn",
        "__half2ushort_ru",
        "__half2ushort_rz",
    ]
    # Bit-pattern reinterpretation between half and 16-bit integers.
    + [
        "__short_as_half",
        "__ushort_as_half",
        "__half_as_short",
        "__half_as_ushort",
    ]
    # Arithmetic and comparison operations (short aliases next to the
    # double-underscore names they were imported as).
    + [
        "habs",
        "__habs",
        "hadd",
        "__hadd",
        "__hadd_rn",
        "__hadd_sat",
        "__hcmadd",
        "hdiv",
        "__hdiv",
        "heq",
        "__heq",
        "__hequ",
        "hfma",
        "__hfma",
        "__hfma_relu",
        "__hfma_sat",
        "hge",
        "__hge",
        "__hgeu",
        "hgt",
        "__hgt",
        "__hgtu",
        "__hisinf",
        "__hisnan",
        "hle",
        "__hle",
        "__hleu",
        "hlt",
        "__hlt",
        "__hltu",
        "hmax",
        "__hmax",
        "__hmax_nan",
        "hmin",
        "__hmin",
        "__hmin_nan",
        "hmul",
        "__hmul",
        "__hmul_rn",
        "__hmul_sat",
        "hne",
        "__hne",
        "hneg",
        "__hneg",
        "__hneu",
        "hsub",
        "__hsub",
        "__hsub_rn",
        "__hsub_sat",
    ]
    # Atomics.
    + ["atomicAdd"]
    # Elementary math functions.
    + [
        "hceil",
        "hcos",
        "hexp",
        "hexp10",
        "hexp2",
        "hfloor",
        "hlog",
        "hlog10",
        "hlog2",
        "hrcp",
        "hrint",
        "hrsqrt",
        "hsin",
        "hsqrt",
        "htanh",
        "htanh_approx",
        "htrunc",
    ]
)
@@ -1,7 +1,7 @@
1
1
  from llvmlite import ir
2
2
 
3
3
  from numba import cuda, types
4
- from numba.core import cgutils
4
+ from numba.cuda import cgutils
5
5
  from numba.core.errors import RequireLiteralValue, TypingError
6
6
  from numba.core.typing import signature
7
7
  from numba.core.extending import overload_attribute, overload_method
@@ -1,5 +1,6 @@
1
1
  from llvmlite import ir
2
- from numba.core import cgutils, types
2
+ from numba.core import types
3
+ from numba.cuda import cgutils
3
4
  from numba.core.imputils import Registry
4
5
  from numba.cuda import libdevice, libdevicefuncs
5
6