numba-cuda 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +0 -8
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
- numba_cuda/numba/cuda/api_util.py +6 -0
- numba_cuda/numba/cuda/cgutils.py +1291 -0
- numba_cuda/numba/cuda/codegen.py +32 -14
- numba_cuda/numba/cuda/compiler.py +113 -10
- numba_cuda/numba/cuda/core/caching.py +741 -0
- numba_cuda/numba/cuda/core/callconv.py +338 -0
- numba_cuda/numba/cuda/core/codegen.py +168 -0
- numba_cuda/numba/cuda/core/compiler.py +205 -0
- numba_cuda/numba/cuda/core/typed_passes.py +139 -0
- numba_cuda/numba/cuda/cudadecl.py +0 -268
- numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +4 -178
- numba_cuda/numba/cuda/debuginfo.py +469 -3
- numba_cuda/numba/cuda/device_init.py +0 -1
- numba_cuda/numba/cuda/dispatcher.py +309 -11
- numba_cuda/numba/cuda/extending.py +2 -1
- numba_cuda/numba/cuda/fp16.py +348 -0
- numba_cuda/numba/cuda/intrinsics.py +1 -1
- numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
- numba_cuda/numba/cuda/lowering.py +1833 -8
- numba_cuda/numba/cuda/mathimpl.py +2 -90
- numba_cuda/numba/cuda/nvvmutils.py +2 -1
- numba_cuda/numba/cuda/printimpl.py +2 -1
- numba_cuda/numba/cuda/serialize.py +264 -0
- numba_cuda/numba/cuda/simulator/__init__.py +2 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
- numba_cuda/numba/cuda/stubs.py +0 -308
- numba_cuda/numba/cuda/target.py +13 -5
- numba_cuda/numba/cuda/testing.py +156 -5
- numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
- numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
- numba_cuda/numba/cuda/utils.py +785 -0
- numba_cuda/numba/cuda/vector_types.py +1 -1
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +61 -48
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
import numba.core.types as types
|
|
2
|
+
from numba.cuda._internal.cuda_fp16 import (
|
|
3
|
+
typing_registry,
|
|
4
|
+
target_registry,
|
|
5
|
+
__half,
|
|
6
|
+
__double2half,
|
|
7
|
+
__float2half,
|
|
8
|
+
__float2half_rd,
|
|
9
|
+
__float2half_rn,
|
|
10
|
+
__float2half_ru,
|
|
11
|
+
__float2half_rz,
|
|
12
|
+
__int2half_rd,
|
|
13
|
+
__int2half_rn,
|
|
14
|
+
__int2half_ru,
|
|
15
|
+
__int2half_rz,
|
|
16
|
+
__ll2half_rd,
|
|
17
|
+
__ll2half_rn,
|
|
18
|
+
__ll2half_ru,
|
|
19
|
+
__ll2half_rz,
|
|
20
|
+
__short2half_rd,
|
|
21
|
+
__short2half_rn,
|
|
22
|
+
__short2half_ru,
|
|
23
|
+
__short2half_rz,
|
|
24
|
+
__uint2half_rd,
|
|
25
|
+
__uint2half_rn,
|
|
26
|
+
__uint2half_ru,
|
|
27
|
+
__uint2half_rz,
|
|
28
|
+
__ull2half_rd,
|
|
29
|
+
__ull2half_rn,
|
|
30
|
+
__ull2half_ru,
|
|
31
|
+
__ull2half_rz,
|
|
32
|
+
__ushort2half_rd,
|
|
33
|
+
__ushort2half_rn,
|
|
34
|
+
__ushort2half_ru,
|
|
35
|
+
__ushort2half_rz,
|
|
36
|
+
__half2char_rz,
|
|
37
|
+
__half2float,
|
|
38
|
+
__half2int_rd,
|
|
39
|
+
__half2int_rn,
|
|
40
|
+
__half2int_ru,
|
|
41
|
+
__half2int_rz,
|
|
42
|
+
__half2ll_rd,
|
|
43
|
+
__half2ll_rn,
|
|
44
|
+
__half2ll_ru,
|
|
45
|
+
__half2ll_rz,
|
|
46
|
+
__half2short_rd,
|
|
47
|
+
__half2short_rn,
|
|
48
|
+
__half2short_ru,
|
|
49
|
+
__half2short_rz,
|
|
50
|
+
__half2uchar_rz,
|
|
51
|
+
__half2uint_rd,
|
|
52
|
+
__half2uint_rn,
|
|
53
|
+
__half2uint_ru,
|
|
54
|
+
__half2uint_rz,
|
|
55
|
+
__half2ull_rd,
|
|
56
|
+
__half2ull_rn,
|
|
57
|
+
__half2ull_ru,
|
|
58
|
+
__half2ull_rz,
|
|
59
|
+
__half2ushort_rd,
|
|
60
|
+
__half2ushort_rn,
|
|
61
|
+
__half2ushort_ru,
|
|
62
|
+
__half2ushort_rz,
|
|
63
|
+
__short_as_half,
|
|
64
|
+
__ushort_as_half,
|
|
65
|
+
__half_as_short,
|
|
66
|
+
__half_as_ushort,
|
|
67
|
+
__habs as habs,
|
|
68
|
+
__habs,
|
|
69
|
+
__hadd as hadd,
|
|
70
|
+
__hadd,
|
|
71
|
+
__hadd_rn,
|
|
72
|
+
__hadd_sat,
|
|
73
|
+
__hcmadd,
|
|
74
|
+
__hdiv as hdiv,
|
|
75
|
+
__hdiv,
|
|
76
|
+
__heq as heq,
|
|
77
|
+
__heq,
|
|
78
|
+
__hequ,
|
|
79
|
+
__hfma as hfma,
|
|
80
|
+
__hfma,
|
|
81
|
+
__hfma_relu,
|
|
82
|
+
__hfma_sat,
|
|
83
|
+
__hge as hge,
|
|
84
|
+
__hge,
|
|
85
|
+
__hgeu,
|
|
86
|
+
__hgt as hgt,
|
|
87
|
+
__hgt,
|
|
88
|
+
__hgtu,
|
|
89
|
+
__hisinf,
|
|
90
|
+
__hisnan,
|
|
91
|
+
__hle as hle,
|
|
92
|
+
__hle,
|
|
93
|
+
__hleu,
|
|
94
|
+
__hlt as hlt,
|
|
95
|
+
__hlt,
|
|
96
|
+
__hltu,
|
|
97
|
+
__hmax as hmax,
|
|
98
|
+
__hmax,
|
|
99
|
+
__hmax_nan,
|
|
100
|
+
__hmin as hmin,
|
|
101
|
+
__hmin,
|
|
102
|
+
__hmin_nan,
|
|
103
|
+
__hmul as hmul,
|
|
104
|
+
__hmul,
|
|
105
|
+
__hmul_rn,
|
|
106
|
+
__hmul_sat,
|
|
107
|
+
__hne as hne,
|
|
108
|
+
__hne,
|
|
109
|
+
__hneg as hneg,
|
|
110
|
+
__hneg,
|
|
111
|
+
__hneu,
|
|
112
|
+
__hsub as hsub,
|
|
113
|
+
__hsub,
|
|
114
|
+
__hsub_rn,
|
|
115
|
+
__hsub_sat,
|
|
116
|
+
atomicAdd,
|
|
117
|
+
hceil,
|
|
118
|
+
hcos,
|
|
119
|
+
hexp,
|
|
120
|
+
hexp10,
|
|
121
|
+
hexp2,
|
|
122
|
+
hfloor,
|
|
123
|
+
hlog,
|
|
124
|
+
hlog10,
|
|
125
|
+
hlog2,
|
|
126
|
+
hrcp,
|
|
127
|
+
hrint,
|
|
128
|
+
hrsqrt,
|
|
129
|
+
hsin,
|
|
130
|
+
hsqrt,
|
|
131
|
+
htanh,
|
|
132
|
+
htanh_approx,
|
|
133
|
+
htrunc,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
from numba.extending import overload
|
|
137
|
+
import math
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _make_unary(a, func):
    """Return a one-argument implementation forwarding to ``func`` for float16.

    ``a`` is the type of the argument being overloaded. When it is a
    ``types.Float`` with a bitwidth of 16, the returned callable simply
    calls ``func`` on its argument. For any other type, ``None`` is
    returned (presumably read by numba's ``overload`` machinery as "no
    implementation for this type" -- confirm against numba docs).
    """
    if not isinstance(a, types.Float) or a.bitwidth != 16:
        return None
    return lambda a: func(a)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# Bind low++ bindings (imported from cuda_fp16 above) to math APIs
@overload(math.trunc, target="cuda")
def trunc_ol(a):
    """Overload ``math.trunc`` for float16 arguments via ``htrunc``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=htrunc)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@overload(math.ceil, target="cuda")
def ceil_ol(a):
    """Overload ``math.ceil`` for float16 arguments via ``hceil``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hceil)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@overload(math.floor, target="cuda")
def floor_ol(a):
    """Overload ``math.floor`` for float16 arguments via ``hfloor``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hfloor)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@overload(math.fabs, target="cuda")
def fabs_ol(a):
    """Overload ``math.fabs`` for float16 arguments via ``habs``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=habs)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@overload(math.sqrt, target="cuda")
def sqrt_ol(a):
    """Overload ``math.sqrt`` for float16 arguments via ``hsqrt``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hsqrt)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@overload(math.log, target="cuda")
def log_ol(a):
    """Overload ``math.log`` for float16 arguments via ``hlog``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hlog)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@overload(math.log2, target="cuda")
def log2_ol(a):
    """Overload ``math.log2`` for float16 arguments via ``hlog2``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hlog2)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@overload(math.log10, target="cuda")
def log10_ol(a):
    """Overload ``math.log10`` for float16 arguments via ``hlog10``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hlog10)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@overload(math.exp, target="cuda")
def exp_ol(a):
    """Overload ``math.exp`` for float16 arguments via ``hexp``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hexp)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@overload(math.tanh, target="cuda")
def tanh_ol(a):
    """Overload ``math.tanh`` for float16 arguments via ``htanh``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=htanh)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@overload(math.cos, target="cuda")
def cos_ol(a):
    """Overload ``math.cos`` for float16 arguments via ``hcos``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hcos)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@overload(math.sin, target="cuda")
def sin_ol(a):
    """Overload ``math.sin`` for float16 arguments via ``hsin``.

    Returns whatever ``_make_unary`` builds: a forwarding implementation
    for 16-bit floats, ``None`` otherwise.
    """
    return _make_unary(a, func=hsin)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# ``math.exp2`` was added in Python 3.11. Register the float16 overload only
# when it exists. Only the import is guarded, so a genuine ImportError raised
# while registering the overload is not silently swallowed.
try:
    from math import exp2
except ImportError:
    # Older Python without math.exp2: there is nothing to overload.
    pass
else:
    @overload(exp2, target="cuda")
    def exp2_ol(a):
        """Overload ``math.exp2`` for float16 arguments via ``hexp2``.

        Returns whatever ``_make_unary`` builds: a forwarding
        implementation for 16-bit floats, ``None`` otherwise.
        """
        return _make_unary(a, hexp2)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Public API of this module. It lists every name imported from
# ``numba.cuda._internal.cuda_fp16`` at the top of the file, including the
# short aliases created there with ``__hXXX as hXXX`` (``habs``, ``hadd``,
# ...). The ``math`` overload functions (``trunc_ol`` etc.) and the
# ``_make_unary`` helper are not listed.
__all__ = [
    # Registries
    "typing_registry",
    "target_registry",
    # The half type
    "__half",
    # Names of the form ``__<type>2half*``
    "__double2half",
    "__float2half",
    "__float2half_rd",
    "__float2half_rn",
    "__float2half_ru",
    "__float2half_rz",
    "__int2half_rd",
    "__int2half_rn",
    "__int2half_ru",
    "__int2half_rz",
    "__ll2half_rd",
    "__ll2half_rn",
    "__ll2half_ru",
    "__ll2half_rz",
    "__short2half_rd",
    "__short2half_rn",
    "__short2half_ru",
    "__short2half_rz",
    "__uint2half_rd",
    "__uint2half_rn",
    "__uint2half_ru",
    "__uint2half_rz",
    "__ull2half_rd",
    "__ull2half_rn",
    "__ull2half_ru",
    "__ull2half_rz",
    "__ushort2half_rd",
    "__ushort2half_rn",
    "__ushort2half_ru",
    "__ushort2half_rz",
    # Names of the form ``__half2<type>*``
    "__half2char_rz",
    "__half2float",
    "__half2int_rd",
    "__half2int_rn",
    "__half2int_ru",
    "__half2int_rz",
    "__half2ll_rd",
    "__half2ll_rn",
    "__half2ll_ru",
    "__half2ll_rz",
    "__half2short_rd",
    "__half2short_rn",
    "__half2short_ru",
    "__half2short_rz",
    "__half2uchar_rz",
    "__half2uint_rd",
    "__half2uint_rn",
    "__half2uint_ru",
    "__half2uint_rz",
    "__half2ull_rd",
    "__half2ull_rn",
    "__half2ull_ru",
    "__half2ull_rz",
    "__half2ushort_rd",
    "__half2ushort_rn",
    "__half2ushort_ru",
    "__half2ushort_rz",
    # Names of the form ``__<type>_as_half`` / ``__half_as_<type>``
    "__short_as_half",
    "__ushort_as_half",
    "__half_as_short",
    "__half_as_ushort",
    # ``__h*`` bindings; un-prefixed entries are the aliases from the import
    "habs",
    "__habs",
    "hadd",
    "__hadd",
    "__hadd_rn",
    "__hadd_sat",
    "__hcmadd",
    "hdiv",
    "__hdiv",
    "heq",
    "__heq",
    "__hequ",
    "hfma",
    "__hfma",
    "__hfma_relu",
    "__hfma_sat",
    "hge",
    "__hge",
    "__hgeu",
    "hgt",
    "__hgt",
    "__hgtu",
    "__hisinf",
    "__hisnan",
    "hle",
    "__hle",
    "__hleu",
    "hlt",
    "__hlt",
    "__hltu",
    "hmax",
    "__hmax",
    "__hmax_nan",
    "hmin",
    "__hmin",
    "__hmin_nan",
    "hmul",
    "__hmul",
    "__hmul_rn",
    "__hmul_sat",
    "hne",
    "__hne",
    "hneg",
    "__hneg",
    "__hneu",
    "hsub",
    "__hsub",
    "__hsub_rn",
    "__hsub_sat",
    "atomicAdd",
    # ``h*`` functions; several are also wired to ``math`` overloads above
    "hceil",
    "hcos",
    "hexp",
    "hexp10",
    "hexp2",
    "hfloor",
    "hlog",
    "hlog10",
    "hlog2",
    "hrcp",
    "hrint",
    "hrsqrt",
    "hsin",
    "hsqrt",
    "htanh",
    "htanh_approx",
    "htrunc",
]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from llvmlite import ir
|
|
2
2
|
|
|
3
3
|
from numba import cuda, types
|
|
4
|
-
from numba.
|
|
4
|
+
from numba.cuda import cgutils
|
|
5
5
|
from numba.core.errors import RequireLiteralValue, TypingError
|
|
6
6
|
from numba.core.typing import signature
|
|
7
7
|
from numba.core.extending import overload_attribute, overload_method
|