numba-cuda 0.21.1__cp313-cp313-win_amd64.whl → 0.23.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/api.py +4 -1
  3. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  4. numba_cuda/numba/cuda/cext/_dispatcher.cpp +0 -38
  5. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  6. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  7. numba_cuda/numba/cuda/cext/_typeof.cpp +0 -111
  8. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  9. numba_cuda/numba/cuda/codegen.py +42 -10
  10. numba_cuda/numba/cuda/compiler.py +10 -4
  11. numba_cuda/numba/cuda/core/analysis.py +29 -21
  12. numba_cuda/numba/cuda/core/annotations/type_annotations.py +4 -4
  13. numba_cuda/numba/cuda/core/base.py +6 -1
  14. numba_cuda/numba/cuda/core/consts.py +1 -1
  15. numba_cuda/numba/cuda/core/cuda_errors.py +917 -0
  16. numba_cuda/numba/cuda/core/errors.py +4 -912
  17. numba_cuda/numba/cuda/core/inline_closurecall.py +71 -57
  18. numba_cuda/numba/cuda/core/interpreter.py +79 -64
  19. numba_cuda/numba/cuda/core/ir.py +191 -119
  20. numba_cuda/numba/cuda/core/ir_utils.py +142 -112
  21. numba_cuda/numba/cuda/core/postproc.py +8 -8
  22. numba_cuda/numba/cuda/core/rewrites/ir_print.py +6 -3
  23. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +5 -5
  24. numba_cuda/numba/cuda/core/rewrites/static_raise.py +3 -3
  25. numba_cuda/numba/cuda/core/ssa.py +3 -3
  26. numba_cuda/numba/cuda/core/transforms.py +25 -10
  27. numba_cuda/numba/cuda/core/typed_passes.py +9 -9
  28. numba_cuda/numba/cuda/core/typeinfer.py +39 -24
  29. numba_cuda/numba/cuda/core/untyped_passes.py +71 -55
  30. numba_cuda/numba/cuda/cudadecl.py +0 -13
  31. numba_cuda/numba/cuda/cudadrv/devicearray.py +6 -5
  32. numba_cuda/numba/cuda/cudadrv/driver.py +132 -511
  33. numba_cuda/numba/cuda/cudadrv/dummyarray.py +4 -0
  34. numba_cuda/numba/cuda/cudadrv/nvrtc.py +16 -0
  35. numba_cuda/numba/cuda/cudaimpl.py +0 -12
  36. numba_cuda/numba/cuda/debuginfo.py +104 -10
  37. numba_cuda/numba/cuda/descriptor.py +1 -1
  38. numba_cuda/numba/cuda/device_init.py +4 -7
  39. numba_cuda/numba/cuda/dispatcher.py +36 -32
  40. numba_cuda/numba/cuda/intrinsics.py +150 -1
  41. numba_cuda/numba/cuda/lowering.py +64 -29
  42. numba_cuda/numba/cuda/memory_management/nrt.py +10 -14
  43. numba_cuda/numba/cuda/np/arrayobj.py +54 -0
  44. numba_cuda/numba/cuda/np/numpy_support.py +26 -0
  45. numba_cuda/numba/cuda/printimpl.py +20 -0
  46. numba_cuda/numba/cuda/serialize.py +10 -0
  47. numba_cuda/numba/cuda/stubs.py +0 -11
  48. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +21 -4
  49. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +1 -2
  50. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +130 -48
  51. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +6 -2
  52. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +3 -1
  53. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +5 -6
  54. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +11 -12
  55. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +27 -19
  56. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +47 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +10 -0
  58. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +89 -0
  59. numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py +243 -0
  60. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +3 -3
  61. numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py +35 -0
  62. numba_cuda/numba/cuda/tests/cudapy/test_print.py +51 -0
  63. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +116 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_globals.py +111 -0
  65. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +61 -0
  66. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +31 -0
  67. numba_cuda/numba/cuda/typing/context.py +3 -1
  68. numba_cuda/numba/cuda/typing/typeof.py +56 -0
  69. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/METADATA +1 -1
  70. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/RECORD +74 -74
  71. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  72. numba_cuda/numba/cuda/cext/_devicearray.cpp +0 -159
  73. numba_cuda/numba/cuda/cext/_devicearray.h +0 -29
  74. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -41
  75. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/WHEEL +0 -0
  76. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/licenses/LICENSE +0 -0
  77. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/licenses/LICENSE.numba +0 -0
  78. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/top_level.txt +0 -0
numba_cuda/VERSION CHANGED
@@ -1 +1 @@
1
- 0.21.1
1
+ 0.23.0
@@ -21,6 +21,7 @@ current_context = devices.get_context
21
21
  gpus = devices.gpus
22
22
 
23
23
 
24
+ @require_context
24
25
  def from_cuda_array_interface(desc, owner=None, sync=True):
25
26
  """Create a DeviceNDArray from a cuda-array-interface description.
26
27
  The ``owner`` is the owner of the underlying memory.
@@ -47,7 +48,9 @@ def from_cuda_array_interface(desc, owner=None, sync=True):
47
48
 
48
49
  cudevptr_class = driver.binding.CUdeviceptr
49
50
  devptr = cudevptr_class(desc["data"][0])
50
- data = driver.MemoryPointer(devptr, size=size, owner=owner)
51
+ data = driver.MemoryPointer(
52
+ current_context(), devptr, size=size, owner=owner
53
+ )
51
54
  stream_ptr = desc.get("stream", None)
52
55
  if stream_ptr is not None:
53
56
  stream = external_stream(stream_ptr)
@@ -12,7 +12,6 @@
12
12
  #include "frameobject.h"
13
13
  #include "traceback.h"
14
14
  #include "typeconv.hpp"
15
- #include "_devicearray.h"
16
15
 
17
16
  /*
18
17
  * Notes on the C_TRACE macro:
@@ -940,37 +939,6 @@ CLEANUP:
940
939
  return retval;
941
940
  }
942
941
 
943
- static int
944
- import_devicearray(void)
945
- {
946
- PyObject *devicearray = PyImport_ImportModule(NUMBA_DEVICEARRAY_IMPORT_NAME);
947
- if (devicearray == NULL) {
948
- return -1;
949
- }
950
-
951
- PyObject *d = PyModule_GetDict(devicearray);
952
- if (d == NULL) {
953
- Py_DECREF(devicearray);
954
- return -1;
955
- }
956
-
957
- PyObject *key = PyUnicode_FromString("_DEVICEARRAY_API");
958
- PyObject *c_api = PyDict_GetItemWithError(d, key);
959
- int retcode = 0;
960
- if (PyCapsule_IsValid(c_api, NUMBA_DEVICEARRAY_IMPORT_NAME "._DEVICEARRAY_API")) {
961
- DeviceArray_API = (void**)PyCapsule_GetPointer(c_api, NUMBA_DEVICEARRAY_IMPORT_NAME "._DEVICEARRAY_API");
962
- if (DeviceArray_API == NULL) {
963
- retcode = -1;
964
- }
965
- } else {
966
- retcode = -1;
967
- }
968
-
969
- Py_DECREF(key);
970
- Py_DECREF(devicearray);
971
- return retcode;
972
- }
973
-
974
942
  static PyMethodDef Dispatcher_methods[] = {
975
943
  { "_clear", (PyCFunction)Dispatcher_clear, METH_NOARGS, NULL },
976
944
  { "_insert", (PyCFunction)Dispatcher_Insert, METH_VARARGS | METH_KEYWORDS,
@@ -1076,12 +1044,6 @@ static PyMethodDef ext_methods[] = {
1076
1044
 
1077
1045
 
1078
1046
  MOD_INIT(_dispatcher) {
1079
- if (import_devicearray() < 0) {
1080
- PyErr_Print();
1081
- PyErr_SetString(PyExc_ImportError, NUMBA_DEVICEARRAY_IMPORT_NAME " failed to import");
1082
- return MOD_ERROR_VAL;
1083
- }
1084
-
1085
1047
  PyObject *m;
1086
1048
  MOD_DEF(m, "_dispatcher", "No docs", ext_methods)
1087
1049
  if (m == NULL)
@@ -9,7 +9,6 @@
9
9
 
10
10
  #include "_typeof.h"
11
11
  #include "_hashtable.h"
12
- #include "_devicearray.h"
13
12
  #include "pyerrors.h"
14
13
 
15
14
  #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -56,9 +55,6 @@ static PyObject *str_typeof_pyval = NULL;
56
55
  static PyObject *str_value = NULL;
57
56
  static PyObject *str_numba_type = NULL;
58
57
 
59
- /* CUDA device array API */
60
- void **DeviceArray_API;
61
-
62
58
  /*
63
59
  * Type fingerprint computation.
64
60
  */
@@ -857,109 +853,6 @@ int typecode_arrayscalar(PyObject *dispatcher, PyObject* aryscalar) {
857
853
  return BASIC_TYPECODES[typecode];
858
854
  }
859
855
 
860
- static
861
- int typecode_devicendarray(PyObject *dispatcher, PyObject *ary)
862
- {
863
- int typecode;
864
- int dtype;
865
- int ndim;
866
- int layout = 0;
867
- PyObject *ndim_obj = nullptr;
868
- PyObject *num_obj = nullptr;
869
- PyObject *dtype_obj = nullptr;
870
- int dtype_num = 0;
871
-
872
- PyObject* flags = PyObject_GetAttrString(ary, "flags");
873
- if (flags == NULL)
874
- {
875
- PyErr_Clear();
876
- goto FALLBACK;
877
- }
878
-
879
- if (PyDict_GetItemString(flags, "C_CONTIGUOUS") == Py_True) {
880
- layout = 1;
881
- } else if (PyDict_GetItemString(flags, "F_CONTIGUOUS") == Py_True) {
882
- layout = 2;
883
- }
884
-
885
- Py_DECREF(flags);
886
-
887
- ndim_obj = PyObject_GetAttrString(ary, "ndim");
888
- if (ndim_obj == NULL) {
889
- /* If there's no ndim, try to proceed by clearing the error and using the
890
- * fallback. */
891
- PyErr_Clear();
892
- goto FALLBACK;
893
- }
894
-
895
- ndim = PyLong_AsLong(ndim_obj);
896
- Py_DECREF(ndim_obj);
897
-
898
- if (PyErr_Occurred()) {
899
- /* ndim wasn't an integer for some reason - unlikely to happen, but try
900
- * the fallback. */
901
- PyErr_Clear();
902
- goto FALLBACK;
903
- }
904
-
905
- if (ndim <= 0 || ndim > N_NDIM)
906
- goto FALLBACK;
907
-
908
- dtype_obj = PyObject_GetAttrString(ary, "dtype");
909
- if (dtype_obj == NULL) {
910
- /* No dtype: try the fallback. */
911
- PyErr_Clear();
912
- goto FALLBACK;
913
- }
914
-
915
- num_obj = PyObject_GetAttrString(dtype_obj, "num");
916
- Py_DECREF(dtype_obj);
917
-
918
- if (num_obj == NULL) {
919
- /* This strange dtype has no num - try the fallback. */
920
- PyErr_Clear();
921
- goto FALLBACK;
922
- }
923
-
924
- dtype_num = PyLong_AsLong(num_obj);
925
- Py_DECREF(num_obj);
926
-
927
- if (PyErr_Occurred()) {
928
- /* num wasn't an integer for some reason - unlikely to happen, but try
929
- * the fallback. */
930
- PyErr_Clear();
931
- goto FALLBACK;
932
- }
933
-
934
- dtype = dtype_num_to_typecode(dtype_num);
935
- if (dtype == -1) {
936
- /* Not a dtype we have in the global lookup table. */
937
- goto FALLBACK;
938
- }
939
-
940
- /* Fast path, using direct table lookup */
941
- assert(layout < N_LAYOUT);
942
- assert(ndim <= N_NDIM);
943
- assert(dtype < N_DTYPES);
944
- typecode = cached_arycode[ndim - 1][layout][dtype];
945
-
946
- if (typecode == -1) {
947
- /* First use of this table entry, so it requires populating */
948
- typecode = typecode_fallback_keep_ref(dispatcher, (PyObject*)ary);
949
- cached_arycode[ndim - 1][layout][dtype] = typecode;
950
- }
951
-
952
- return typecode;
953
-
954
- FALLBACK:
955
- /* Slower path, for non-trivial array types. At present this always uses
956
- the fingerprinting to get the typecode. Future optimization might
957
- implement a cache, but this would require some fast equivalent of
958
- PyArray_DESCR for a device array. */
959
-
960
- return typecode_using_fingerprint(dispatcher, (PyObject *) ary);
961
- }
962
-
963
856
  extern "C" int
964
857
  typeof_typecode(PyObject *dispatcher, PyObject *val)
965
858
  {
@@ -994,10 +887,6 @@ typeof_typecode(PyObject *dispatcher, PyObject *val)
994
887
  else if (tyobj == &PyArray_Type) {
995
888
  return typecode_ndarray(dispatcher, (PyArrayObject*)val);
996
889
  }
997
- /* Subtype of CUDA device array */
998
- else if (PyType_IsSubtype(tyobj, &DeviceArrayType)) {
999
- return typecode_devicendarray(dispatcher, val);
1000
- }
1001
890
  /* Subtypes of Array handling */
1002
891
  else if (PyType_IsSubtype(tyobj, &PyArray_Type)) {
1003
892
  /* By default, Numba will treat all numpy.ndarray subtypes as if they
@@ -12,6 +12,7 @@ from numba.cuda.cudadrv.linkable_code import LinkableCode
12
12
  from numba.cuda.memory_management.nrt import NRT_LIBRARY
13
13
 
14
14
  import os
15
+ import pickle
15
16
  import subprocess
16
17
  import tempfile
17
18
 
@@ -189,6 +190,11 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
189
190
 
190
191
  self.use_cooperative = False
191
192
 
193
+ # Objects that need to be kept alive for the lifetime of the
194
+ # kernels or device functions generated by this code library,
195
+ # e.g., device arrays captured from global scope.
196
+ self.referenced_objects = {}
197
+
192
198
  @property
193
199
  def llvm_strs(self):
194
200
  if self._llvm_strs is None:
@@ -206,6 +212,9 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
206
212
  return device.compute_capability
207
213
 
208
214
  def get_asm_str(self, cc=None):
215
+ return "\n".join(self.get_asm_strs(cc=cc))
216
+
217
+ def get_asm_strs(self, cc=None):
209
218
  cc = self._ensure_cc(cc)
210
219
 
211
220
  ptxes = self._ptx_cache.get(cc, None)
@@ -218,21 +227,25 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
218
227
 
219
228
  irs = self.llvm_strs
220
229
 
221
- ptx = nvvm.compile_ir(irs, **options)
230
+ if "g" in options:
231
+ ptxes = [nvvm.compile_ir(ir, **options) for ir in irs]
232
+ else:
233
+ ptxes = [nvvm.compile_ir(irs, **options)]
222
234
 
223
235
  # Sometimes the result from NVVM contains trailing whitespace and
224
236
  # nulls, which we strip so that the assembly dump looks a little
225
237
  # tidier.
226
- ptx = ptx.decode().strip("\x00").strip()
238
+ ptxes = [ptx.decode().strip("\x00").strip() for ptx in ptxes]
227
239
 
228
240
  if config.DUMP_ASSEMBLY:
229
241
  print(("ASSEMBLY %s" % self._name).center(80, "-"))
230
- print(ptx)
242
+ for ptx in ptxes:
243
+ print(ptx)
231
244
  print("=" * 80)
232
245
 
233
- self._ptx_cache[cc] = ptx
246
+ self._ptx_cache[cc] = ptxes
234
247
 
235
- return ptx
248
+ return ptxes
236
249
 
237
250
  def get_lto_ptx(self, cc=None):
238
251
  """
@@ -247,7 +260,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
247
260
 
248
261
  cc = self._ensure_cc(cc)
249
262
 
250
- linker = driver._Linker.new(
263
+ linker = driver._Linker(
251
264
  max_registers=self._max_registers,
252
265
  cc=cc,
253
266
  additional_flags=["-ptx"],
@@ -284,8 +297,9 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
284
297
  ltoir = self.get_ltoir(cc=cc)
285
298
  linker.add_ltoir(ltoir)
286
299
  else:
287
- ptx = self.get_asm_str(cc=cc)
288
- linker.add_ptx(ptx.encode())
300
+ ptxes = self.get_asm_strs(cc=cc)
301
+ for ptx in ptxes:
302
+ linker.add_ptx(ptx.encode())
289
303
 
290
304
  for path in self._linking_files:
291
305
  linker.add_file_guess_ext(path, ignore_nonlto)
@@ -308,7 +322,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
308
322
  print(ptx)
309
323
  print("=" * 80)
310
324
 
311
- linker = driver._Linker.new(
325
+ linker = driver._Linker(
312
326
  max_registers=self._max_registers, cc=cc, lto=self._lto
313
327
  )
314
328
  self._link_all(linker, cc, ignore_nonlto=False)
@@ -377,6 +391,9 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
377
391
  self._setup_functions.extend(library._setup_functions)
378
392
  self._teardown_functions.extend(library._teardown_functions)
379
393
  self.use_cooperative |= library.use_cooperative
394
+ self.referenced_objects.update(
395
+ getattr(library, "referenced_objects", {})
396
+ )
380
397
 
381
398
  def add_linking_file(self, path_or_obj):
382
399
  if isinstance(path_or_obj, LinkableCode):
@@ -432,7 +449,10 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
432
449
  for mod in library.modules:
433
450
  for fn in mod.functions:
434
451
  if not fn.is_declaration:
435
- fn.linkage = "linkonce_odr"
452
+ if "g" in self._nvvm_options:
453
+ fn.linkage = "weak_odr"
454
+ else:
455
+ fn.linkage = "linkonce_odr"
436
456
 
437
457
  self._finalized = True
438
458
 
@@ -442,6 +462,18 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
442
462
  but loaded functions are discarded. They are recreated when needed
443
463
  after deserialization.
444
464
  """
465
+ # Check for captured device arrays that cannot be safely cached.
466
+ if self.referenced_objects:
467
+ if any(
468
+ getattr(obj, "__cuda_array_interface__", None) is not None
469
+ for obj in self.referenced_objects.values()
470
+ ):
471
+ raise pickle.PicklingError(
472
+ "Cannot serialize kernels or device functions referencing "
473
+ "global device arrays. Pass the array(s) as arguments "
474
+ "to the kernel instead."
475
+ )
476
+
445
477
  nrt = False
446
478
  if self._linking_files:
447
479
  if (
@@ -1023,10 +1023,9 @@ def compile_all(
1023
1023
  )
1024
1024
 
1025
1025
  if lto:
1026
- code = lib.get_ltoir(cc=cc)
1026
+ codes = [lib.get_ltoir(cc=cc)]
1027
1027
  else:
1028
- code = lib.get_asm_str(cc=cc)
1029
- codes = [code]
1028
+ codes = lib.get_asm_strs(cc=cc)
1030
1029
 
1031
1030
  # linking_files
1032
1031
  is_ltoir = output == "ltoir"
@@ -1241,7 +1240,14 @@ def compile(
1241
1240
  if lto:
1242
1241
  code = lib.get_ltoir(cc=cc)
1243
1242
  else:
1244
- code = lib.get_asm_str(cc=cc)
1243
+ codes = lib.get_asm_strs(cc=cc)
1244
+ if len(codes) == 1:
1245
+ code = codes[0]
1246
+ else:
1247
+ raise RuntimeError(
1248
+ "Compiling this function results in multiple "
1249
+ "PTX files. Use compile_all() instead"
1250
+ )
1245
1251
  return code, resty
1246
1252
 
1247
1253
 
@@ -38,13 +38,16 @@ def compute_use_defs(blocks):
38
38
  func = ir_extension_usedefs[type(stmt)]
39
39
  func(stmt, use_set, def_set)
40
40
  continue
41
- if isinstance(stmt, ir.Assign):
42
- if isinstance(stmt.value, ir.Inst):
41
+ if isinstance(stmt, ir.assign_types):
42
+ if isinstance(stmt.value, ir.inst_types):
43
43
  rhs_set = set(var.name for var in stmt.value.list_vars())
44
- elif isinstance(stmt.value, ir.Var):
44
+ elif isinstance(stmt.value, ir.var_types):
45
45
  rhs_set = set([stmt.value.name])
46
- elif isinstance(
47
- stmt.value, (ir.Arg, ir.Const, ir.Global, ir.FreeVar)
46
+ elif (
47
+ isinstance(stmt.value, ir.arg_types)
48
+ or isinstance(stmt.value, ir.const_types)
49
+ or isinstance(stmt.value, ir.global_types)
50
+ or isinstance(stmt.value, ir.freevar_types)
48
51
  ):
49
52
  rhs_set = ()
50
53
  else:
@@ -326,7 +329,7 @@ def rewrite_semantic_constants(func_ir, called_args):
326
329
  if getattr(val, "op", None) == "getattr":
327
330
  if val.attr == "ndim":
328
331
  arg_def = guard(get_definition, func_ir, val.value)
329
- if isinstance(arg_def, ir.Arg):
332
+ if isinstance(arg_def, ir.arg_types):
330
333
  argty = called_args[arg_def.index]
331
334
  if isinstance(argty, types.Array):
332
335
  rewrite_statement(func_ir, stmt, argty.ndim)
@@ -337,17 +340,17 @@ def rewrite_semantic_constants(func_ir, called_args):
337
340
  func = guard(get_definition, func_ir, val.func)
338
341
  if (
339
342
  func is not None
340
- and isinstance(func, ir.Global)
343
+ and isinstance(func, ir.global_types)
341
344
  and getattr(func, "value", None) is len
342
345
  ):
343
346
  (arg,) = val.args
344
347
  arg_def = guard(get_definition, func_ir, arg)
345
- if isinstance(arg_def, ir.Arg):
348
+ if isinstance(arg_def, ir.arg_types):
346
349
  argty = called_args[arg_def.index]
347
350
  if isinstance(argty, types.BaseTuple):
348
351
  rewrite_statement(func_ir, stmt, argty.count)
349
352
  elif (
350
- isinstance(arg_def, ir.Expr)
353
+ isinstance(arg_def, ir.expr_types)
351
354
  and arg_def.op == "typed_getitem"
352
355
  ):
353
356
  argty = arg_def.dtype
@@ -358,9 +361,9 @@ def rewrite_semantic_constants(func_ir, called_args):
358
361
 
359
362
  for blk in func_ir.blocks.values():
360
363
  for stmt in blk.body:
361
- if isinstance(stmt, ir.Assign):
364
+ if isinstance(stmt, ir.assign_types):
362
365
  val = stmt.value
363
- if isinstance(val, ir.Expr):
366
+ if isinstance(val, ir.expr_types):
364
367
  rewrite_array_ndim(val, func_ir, called_args)
365
368
  rewrite_tuple_len(val, func_ir, called_args)
366
369
 
@@ -391,7 +394,7 @@ def find_literally_calls(func_ir, argtypes):
391
394
  for blk in func_ir.blocks.values():
392
395
  for assign in blk.find_exprs(op="call"):
393
396
  var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func)
394
- if isinstance(var, (ir.Global, ir.FreeVar)):
397
+ if isinstance(var, ir.global_types + ir.freevar_types):
395
398
  fnobj = var.value
396
399
  else:
397
400
  fnobj = ir_utils.guard(
@@ -401,7 +404,7 @@ def find_literally_calls(func_ir, argtypes):
401
404
  # Found
402
405
  [arg] = assign.args
403
406
  defarg = func_ir.get_definition(arg)
404
- if isinstance(defarg, ir.Arg):
407
+ if isinstance(defarg, ir.arg_types):
405
408
  argindex = defarg.index
406
409
  marked_args.add(argindex)
407
410
  first_loc.setdefault(argindex, assign.loc)
@@ -473,14 +476,14 @@ def dead_branch_prune(func_ir, called_args):
473
476
  branches = []
474
477
  for blk in func_ir.blocks.values():
475
478
  branch_or_jump = blk.body[-1]
476
- if isinstance(branch_or_jump, ir.Branch):
479
+ if isinstance(branch_or_jump, ir.branch_types):
477
480
  branch = branch_or_jump
478
481
  pred = guard(get_definition, func_ir, branch.cond.name)
479
482
  if pred is not None and getattr(pred, "op", None) == "call":
480
483
  function = guard(get_definition, func_ir, pred.func)
481
484
  if (
482
485
  function is not None
483
- and isinstance(function, ir.Global)
486
+ and isinstance(function, ir.global_types)
484
487
  and function.value is bool
485
488
  ):
486
489
  condition = guard(get_definition, func_ir, pred.args[0])
@@ -539,7 +542,9 @@ def dead_branch_prune(func_ir, called_args):
539
542
  try:
540
543
  # Just to prevent accidents, whilst already guarded, ensure this
541
544
  # is an ir.Const
542
- if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
545
+ if not isinstance(
546
+ pred, ir.const_types + ir.freevar_types + ir.global_types
547
+ ):
543
548
  raise TypeError("Expected constant Numba IR node")
544
549
  take_truebr = bool(pred.value)
545
550
  except TypeError:
@@ -584,8 +589,11 @@ def dead_branch_prune(func_ir, called_args):
584
589
  phi2asgn = dict()
585
590
  for lbl, blk in func_ir.blocks.items():
586
591
  for stmt in blk.body:
587
- if isinstance(stmt, ir.Assign):
588
- if isinstance(stmt.value, ir.Expr) and stmt.value.op == "phi":
592
+ if isinstance(stmt, ir.assign_types):
593
+ if (
594
+ isinstance(stmt.value, ir.expr_types)
595
+ and stmt.value.op == "phi"
596
+ ):
589
597
  phi2lbl[stmt.value] = lbl
590
598
  phi2asgn[stmt.value] = stmt
591
599
 
@@ -599,12 +607,12 @@ def dead_branch_prune(func_ir, called_args):
599
607
 
600
608
  for branch, condition, blk in branch_info:
601
609
  const_conds = []
602
- if isinstance(condition, ir.Expr) and condition.op == "binop":
610
+ if isinstance(condition, ir.expr_types) and condition.op == "binop":
603
611
  prune = prune_by_value
604
612
  for arg in [condition.lhs, condition.rhs]:
605
613
  resolved_const = Unknown()
606
614
  arg_def = guard(get_definition, func_ir, arg)
607
- if isinstance(arg_def, ir.Arg):
615
+ if isinstance(arg_def, ir.arg_types):
608
616
  # it's an e.g. literal argument to the function
609
617
  resolved_const = resolve_input_arg_const(arg_def.index)
610
618
  prune = prune_by_type
@@ -668,7 +676,7 @@ def dead_branch_prune(func_ir, called_args):
668
676
  for _, cond, blk in branch_info:
669
677
  if cond in deadcond:
670
678
  for x in blk.body:
671
- if isinstance(x, ir.Assign) and x.value is cond:
679
+ if isinstance(x, ir.assign_types) and x.value is cond:
672
680
  # rewrite the condition as a true/false bit
673
681
  nullified_info = nullified_conditions[deadcond.index(cond)]
674
682
  # only do a rewrite of conditions, predicates need to retain
@@ -94,16 +94,16 @@ class TypeAnnotation(object):
94
94
  for inst in blk.body:
95
95
  lineno = inst.loc.line
96
96
 
97
- if isinstance(inst, ir.Assign):
97
+ if isinstance(inst, ir.assign_types):
98
98
  if found_lifted_loop:
99
99
  atype = "XXX Lifted Loop XXX"
100
100
  found_lifted_loop = False
101
101
  elif (
102
- isinstance(inst.value, ir.Expr)
102
+ isinstance(inst.value, ir.expr_types)
103
103
  and inst.value.op == "call"
104
104
  ):
105
105
  atype = self.calltypes[inst.value]
106
- elif isinstance(inst.value, ir.Const) and isinstance(
106
+ elif isinstance(inst.value, ir.const_types) and isinstance(
107
107
  inst.value.value, LiftedLoop
108
108
  ):
109
109
  atype = "XXX Lifted Loop XXX"
@@ -113,7 +113,7 @@ class TypeAnnotation(object):
113
113
  atype = self.typemap.get(inst.target.name, "<missing>")
114
114
 
115
115
  aline = "%s = %s :: %s" % (inst.target, inst.value, atype)
116
- elif isinstance(inst, ir.SetItem):
116
+ elif isinstance(inst, ir.setitem_types):
117
117
  atype = self.calltypes[inst]
118
118
  aline = "%s :: %s" % (inst, atype)
119
119
  else:
@@ -933,7 +933,12 @@ class BaseContext(object):
933
933
  If *caching* evaluates True, the function keeps the compiled function
934
934
  for reuse in *.cached_internal_func*.
935
935
  """
936
- cache_key = (impl.__code__, sig, type(self.error_model))
936
+ cache_key = (
937
+ impl.__code__,
938
+ sig,
939
+ type(self.error_model),
940
+ self.enable_nrt,
941
+ )
937
942
  if not caching:
938
943
  cached = None
939
944
  else:
@@ -68,7 +68,7 @@ class ConstantInference(object):
68
68
  try:
69
69
  const = defn.infer_constant()
70
70
  except ConstantInferenceError:
71
- if isinstance(defn, ir.Expr):
71
+ if isinstance(defn, ir.expr_types):
72
72
  return self._infer_expr(defn)
73
73
  self._fail(defn)
74
74
  return const