gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -27,10 +27,14 @@ LIBXNDDIR = ..\xnd\libxnd
27
27
 
28
28
  OPT = /MT /Ox /GS /EHsc /fp:strict
29
29
  OPT_SHARED = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /fp:strict /Fo.objs^\
30
+ OPT_NOFP = /MT /Ox /GS /EHsc
31
+ OPT_SHARED_NOFP = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /Fo.objs^\
30
32
 
31
33
  COMMON_CFLAGS = /nologo /W4 /wd4200 /wd4201 /wd4204
32
34
  CFLAGS = $(COMMON_CFLAGS) $(OPT)
33
35
  CFLAGS_SHARED = $(COMMON_CFLAGS) $(OPT_SHARED)
36
+ CFLAGS_NOFP = $(COMMON_CFLAGS) $(OPT_NOFP)
37
+ CFLAGS_SHARED_NOFP = $(COMMON_CFLAGS) $(OPT_SHARED_NOFP)
34
38
 
35
39
 
36
40
  default: $(LIBSTATIC) $(LIBSHARED)
@@ -40,11 +44,14 @@ default: $(LIBSTATIC) $(LIBSHARED)
40
44
  copy /y $(LIBSHARED) ..\python\gumath
41
45
 
42
46
 
43
- OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj \
44
- unary.obj binary.obj examples.obj graph.obj pdist.obj
47
+ OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj cpu_host_unary.obj \
48
+ cpu_device_unary.obj cpu_host_binary.obj cpu_device_binary.obj cpu_device_msvc.obj \
49
+ common.obj examples.obj graph.obj pdist.obj
45
50
 
46
51
  SHARED_OBJS = .objs/apply.obj .objs/func.obj .objs/nploops.obj .objs/tbl.obj .objs/xndloops.obj \
47
- .objs/unary.obj .objs/binary.obj .objs/examples.obj .objs/graph.obj .objs/pdist.obj
52
+ .objs/cpu_host_unary.obj .objs/cpu_device_unary.obj .objs/cpu_host_binary.obj \
53
+ .objs/cpu_device_binary.obj .objs/cpu_device_msvc.obj .objs/common.obj \
54
+ .objs/examples.obj .objs/graph.obj .objs/pdist.obj
48
55
 
49
56
 
50
57
  $(LIBSTATIC):\
@@ -101,21 +108,53 @@ Makefile xndloops.c gumath.h
101
108
  Makefile xndloops.c gumath.h
102
109
  $(CC) "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c xndloops.c
103
110
 
104
- unary.obj:\
105
- Makefile kernels\unary.c gumath.h
106
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\unary.c
111
+ cpu_host_unary.obj:\
112
+ Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
113
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_unary.c
107
114
 
108
- .objs\unary.obj:\
109
- Makefile kernels\unary.c gumath.h
110
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\unary.c
115
+ .objs\cpu_host_unary.obj:\
116
+ Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
117
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_unary.c
111
118
 
112
- binary.obj:\
113
- Makefile kernels\binary.c gumath.h
114
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\binary.c
119
+ cpu_device_unary.obj:\
120
+ Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
121
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_unary.cc
115
122
 
116
- .objs\binary.obj:\
117
- Makefile kernels\binary.c gumath.h
118
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\binary.c
123
+ .objs\cpu_device_unary.obj:\
124
+ Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
125
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_unary.cc
126
+
127
+ cpu_host_binary.obj:\
128
+ Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
129
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_binary.c
130
+
131
+ .objs\cpu_host_binary.obj:\
132
+ Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
133
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_binary.c
134
+
135
+ cpu_device_binary.obj:\
136
+ Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
137
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_binary.cc
138
+
139
+ .objs\cpu_device_binary.obj:\
140
+ Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
141
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_binary.cc
142
+
143
+ cpu_device_msvc.obj:\
144
+ Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
145
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_NOFP) -c kernels\cpu_device_msvc.cc
146
+
147
+ .objs\cpu_device_msvc.obj:\
148
+ Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
149
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED_NOFP) -c kernels\cpu_device_msvc.cc
150
+
151
+ common.obj:\
152
+ Makefile kernels\common.c kernels\common.h gumath.h
153
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\common.c
154
+
155
+ .objs\common.obj:\
156
+ Makefile kernels\common.c kernels\common.h gumath.h
157
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\common.c
119
158
 
120
159
  examples.obj:\
121
160
  Makefile extending\examples.c gumath.h
@@ -40,6 +40,29 @@
40
40
  #include "gumath.h"
41
41
 
42
42
 
43
+ /* flags that apply to all arguments */
44
+ #define OPT_Z (NDT_EXT_ZERO|NDT_INNER_C)
45
+ #define OPT_C (NDT_EXT_C|NDT_INNER_C)
46
+ #define OPT_S (NDT_EXT_STRIDED|NDT_INNER_STRIDED)
47
+ #define OPT_SC (NDT_EXT_STRIDED|NDT_INNER_C)
48
+ #define OPT_SF (NDT_EXT_STRIDED|NDT_INNER_F)
49
+
50
+ #define INNER_C (NDT_INNER_C)
51
+ #define INNER_F (NDT_INNER_F)
52
+ #define INNER_S (NDT_INNER_STRIDED)
53
+ #define INNER_X (NDT_INNER_XND)
54
+
55
+ /* kernel requests */
56
+ #define REQ_LOOP_C(flags) ((flags&OPT_C) == OPT_C)
57
+ #define REQ_LOOP_Z(flags) ((flags&OPT_C) == OPT_C || (flags&OPT_Z) == OPT_Z)
58
+ #define REQ_LOOP_S(flags) ((flags&OPT_S) == OPT_S)
59
+
60
+ #define REQ_INNER_C(flags) ((flags&INNER_C) == INNER_C)
61
+ #define REQ_INNER_F(flags) ((flags&INNER_F) == INNER_F)
62
+ #define REQ_INNER_S(flags) ((flags&INNER_S) == INNER_S)
63
+ #define REQ_INNER_X(flags) ((flags&INNER_X) == INNER_X)
64
+
65
+
43
66
  static int
44
67
  sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
45
68
  {
@@ -55,6 +78,18 @@ sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
55
78
  return sum;
56
79
  }
57
80
 
81
+ static inline bool
82
+ opt_safe(int outer, ndt_context_t *ctx)
83
+ {
84
+ if (outer == 0) {
85
+ ndt_err_format(ctx, NDT_RuntimeError,
86
+ "internal error: optimized kernel called with outer_dims==0");
87
+ return false;
88
+ }
89
+
90
+ return true;
91
+ }
92
+
58
93
  int
59
94
  gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
60
95
  ndt_context_t *ctx)
@@ -62,30 +97,43 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
62
97
  const int nargs = (int)kernel->set->sig->Function.nargs;
63
98
 
64
99
  switch (kernel->flag) {
65
- case NDT_ELEMWISE_1D: {
66
- if (outer_dims == 0) {
67
- ndt_err_format(ctx, NDT_RuntimeError,
68
- "gm_xnd_map: optimized elementwise kernel called with "
69
- "outer_dims==0");
100
+ case OPT_C: {
101
+ if (!opt_safe(outer_dims, ctx)) {
70
102
  return -1;
71
103
  }
72
104
 
73
- return gm_xnd_map(kernel->set->Opt, stack, nargs, outer_dims-1, ctx);
105
+ return gm_xnd_map(kernel->set->OptC, stack, nargs, outer_dims-1, ctx);
74
106
  }
75
107
 
76
- case NDT_C: {
108
+ case OPT_Z: {
109
+ if (!opt_safe(outer_dims, ctx)) {
110
+ return -1;
111
+ }
112
+
113
+ return gm_xnd_map(kernel->set->OptZ, stack, nargs, outer_dims-1, ctx);
114
+ }
115
+
116
+ case OPT_S: {
117
+ if (!opt_safe(outer_dims, ctx)) {
118
+ return -1;
119
+ }
120
+
121
+ return gm_xnd_map(kernel->set->OptS, stack, nargs, outer_dims-1, ctx);
122
+ }
123
+
124
+ case INNER_C: {
77
125
  return gm_xnd_map(kernel->set->C, stack, nargs, outer_dims, ctx);
78
126
  }
79
127
 
80
- case NDT_FORTRAN: {
128
+ case INNER_F: {
81
129
  return gm_xnd_map(kernel->set->Fortran, stack, nargs, outer_dims, ctx);
82
130
  }
83
131
 
84
- case NDT_XND: {
132
+ case INNER_X: {
85
133
  return gm_xnd_map(kernel->set->Xnd, stack, nargs, outer_dims, ctx);
86
134
  }
87
135
 
88
- case NDT_STRIDED: {
136
+ case INNER_S: {
89
137
  const int sum_inner = sum_inner_dimensions(stack, nargs, outer_dims);
90
138
  const int dims_size = outer_dims + sum_inner;
91
139
  const int steps_size = nargs * outer_dims + sum_inner;
@@ -105,8 +153,9 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
105
153
  }
106
154
  }
107
155
 
108
- /* NOT REACHED: tags should be exhaustive. */
109
- ndt_internal_error("invalid tag");
156
+ ndt_err_format(ctx, NDT_RuntimeError,
157
+ "internal error: kernel selection failed");
158
+ return -1;
110
159
  }
111
160
 
112
161
  static gm_kernel_t
@@ -117,35 +166,48 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
117
166
 
118
167
  kernel.set = set;
119
168
 
120
- if (set->Opt != NULL && (spec->flags&NDT_ELEMWISE_1D)) {
121
- kernel.flag = NDT_ELEMWISE_1D;
169
+ if (REQ_LOOP_C(spec->flags) && set->OptC != NULL) {
170
+ kernel.flag = OPT_C;
171
+ return kernel;
172
+ }
173
+
174
+ if (REQ_LOOP_Z(spec->flags) && set->OptZ != NULL) {
175
+ kernel.flag = OPT_Z;
176
+ return kernel;
177
+ }
178
+
179
+ if (REQ_LOOP_S(spec->flags) && set->OptS != NULL) {
180
+ kernel.flag = OPT_S;
122
181
  return kernel;
123
182
  }
124
183
 
125
- if (set->C != NULL && (spec->flags&NDT_C)) {
126
- kernel.flag = NDT_C;
184
+ if (REQ_INNER_C(spec->flags) && set->C != NULL) {
185
+ kernel.flag = INNER_C;
127
186
  return kernel;
128
187
  }
129
188
 
130
- if (set->Fortran != NULL && (spec->flags&NDT_FORTRAN)) {
131
- kernel.flag = NDT_FORTRAN;
189
+ if (REQ_INNER_F(spec->flags) && set->Fortran != NULL) {
190
+ kernel.flag = INNER_F;
132
191
  return kernel;
133
192
  }
134
193
 
135
- if (set->Strided != NULL && (spec->flags&NDT_STRIDED)) {
136
- kernel.flag = NDT_STRIDED;
194
+ if (REQ_INNER_S(spec->flags) && set->Strided != NULL) {
195
+ kernel.flag = INNER_S;
137
196
  return kernel;
138
197
  }
139
198
 
140
- if (set->Xnd != NULL && (spec->flags&NDT_XND)) {
141
- kernel.flag = NDT_XND;
199
+ if (REQ_INNER_X(spec->flags) && set->Xnd != NULL) {
200
+ kernel.flag = INNER_X;
142
201
  return kernel;
143
202
  }
144
203
 
145
204
  kernel.set = NULL;
146
205
  ndt_err_format(ctx, NDT_RuntimeError,
147
- "could not find specialized kernel for '%s' input (available: %s, %s, %s, %s)",
206
+ "could not find specialized kernel for '%s' input (available: %s, %s, %s, %s, %s, %s, %s)",
148
207
  ndt_apply_flags_as_string(spec),
208
+ set->OptC ? "OptC" : "_",
209
+ set->OptZ ? "OptZ" : "_",
210
+ set->OptS ? "OptS" : "_",
149
211
  set->C ? "C" : "_",
150
212
  set->Fortran ? "Fortran" : "_",
151
213
  set->Xnd ? "Xnd" : "_",
@@ -157,8 +219,8 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
157
219
  /* Look up a multimethod by name and select a kernel. */
158
220
  gm_kernel_t
159
221
  gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
160
- const ndt_t *in_types[], int nin, const xnd_t args[],
161
- ndt_context_t *ctx)
222
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
223
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx)
162
224
  {
163
225
  gm_kernel_t empty_kernel = {0U, NULL};
164
226
  const gm_func_t *f;
@@ -171,7 +233,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
171
233
  }
172
234
 
173
235
  if (f->typecheck != NULL) {
174
- const gm_kernel_set_t *set = f->typecheck(spec, f, in_types, nin, ctx);
236
+ const gm_kernel_set_t *set = f->typecheck(spec, f, types, li, nin, nout,
237
+ check_broadcast, ctx);
175
238
  if (set == NULL) {
176
239
  return empty_kernel;
177
240
  }
@@ -180,7 +243,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
180
243
 
181
244
  for (i = 0; i < f->nkernels; i++) {
182
245
  const gm_kernel_set_t *set = &f->kernels[i];
183
- if (ndt_typecheck(spec, set->sig, in_types, nin, set->constraint, args,
246
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
247
+ check_broadcast, set->constraint, args,
184
248
  ctx) < 0) {
185
249
  ndt_err_clear(ctx);
186
250
  continue;
@@ -188,7 +252,7 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
188
252
  return select_kernel(spec, set, ctx);
189
253
  }
190
254
 
191
- s = ndt_list_as_string(in_types, nin, ctx);
255
+ s = ndt_list_as_string(types, nin, ctx);
192
256
  if (s == NULL) {
193
257
  return empty_kernel;
194
258
  }
@@ -181,9 +181,10 @@ static xnd_t
181
181
  mk_return_array(int32_t p[], const int64_t N, const int32_t u,
182
182
  ndt_context_t *ctx)
183
183
  {
184
- int32_t *ndim2_offsets = NULL;
185
- int32_t *ndim1_offsets = NULL;
186
- ndt_t *t;
184
+ ndt_offsets_t *ndim2_offsets = NULL;
185
+ ndt_offsets_t *ndim1_offsets = NULL;
186
+ int32_t *ptr;
187
+ const ndt_t *t, *type;
187
188
  int64_t sum;
188
189
  int32_t v;
189
190
 
@@ -191,45 +192,51 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
191
192
  goto offset_overflow;
192
193
  }
193
194
 
194
- ndim2_offsets = ndt_alloc(2, sizeof *ndim2_offsets);
195
+ ndim2_offsets = ndt_offsets_new(2, ctx);
195
196
  if (ndim2_offsets == NULL) {
196
- (void)ndt_memory_error(ctx);
197
197
  return xnd_error;
198
198
  }
199
- ndim2_offsets[0] = 0;
200
- ndim2_offsets[1] = (int32_t)N;
199
+ ptr = (int32_t *)ndim2_offsets->v;
200
+ ptr[0] = 0;
201
+ ptr[1] = (int32_t)N;
201
202
 
202
203
 
203
- ndim1_offsets = ndt_alloc(N+1, sizeof *ndim1_offsets);
204
+ ndim1_offsets = ndt_offsets_new((int32_t)(N+1), ctx);
204
205
  if (ndim1_offsets == NULL) {
205
- (void)ndt_memory_error(ctx);
206
+ ndt_decref_offsets(ndim2_offsets);
206
207
  return xnd_error;
207
208
  }
208
209
 
209
210
  sum = 0;
211
+ ptr = (int32_t *)ndim1_offsets->v;
210
212
  for (v = 0; v < N; v++) {
211
- ndim1_offsets[v] = (int32_t)sum;
213
+ ptr[v] = (int32_t)sum;
212
214
  int64_t n = write_path(NULL, 0, p, N, u, v);
213
215
  sum += n;
214
216
  if (sum > INT32_MAX) {
215
217
  goto offset_overflow;
216
218
  }
217
219
  }
218
- ndim1_offsets[v] = (int32_t)sum;
220
+ ptr[v] = (int32_t)sum;
219
221
 
220
222
 
221
- t = ndt_from_string("node", ctx);
222
- if (t == NULL) {
223
+ type = ndt_from_string("node", ctx);
224
+ if (type == NULL) {
223
225
  goto error;
224
226
  }
225
227
 
226
- t = ndt_var_dim(t, InternalOffsets, (int32_t)(N+1), ndim1_offsets, 0, NULL, ctx);
228
+ t = ndt_var_dim(type, ndim1_offsets, 0, NULL, false, ctx);
229
+ ndt_decref(type);
230
+ ndt_decref_offsets(ndim1_offsets);
227
231
  ndim1_offsets = NULL;
228
232
  if (t == NULL) {
229
233
  goto error;
230
234
  }
235
+ type = t;
231
236
 
232
- t = ndt_var_dim(t, InternalOffsets, 2, ndim2_offsets, 0, NULL, ctx);
237
+ t = ndt_var_dim(type, ndim2_offsets, 0, NULL, false, ctx);
238
+ ndt_decref(type);
239
+ ndt_decref_offsets(ndim2_offsets);
233
240
  ndim2_offsets = NULL;
234
241
  if (t == NULL) {
235
242
  goto error;
@@ -244,16 +251,16 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
244
251
 
245
252
  t = out.type->VarDim.type;
246
253
  for (v = 0; v < N; v++) {
247
- int32_t shape = t->Concrete.VarDim.offsets[v+1]-t->Concrete.VarDim.offsets[v];
248
- char *ptr = out.ptr + t->Concrete.VarDim.offsets[v] * t->Concrete.VarDim.itemsize;
249
- (void)write_path((int32_t *)ptr, shape, p, N, u, v);
254
+ int32_t shape = t->Concrete.VarDim.offsets->v[v+1]-t->Concrete.VarDim.offsets->v[v];
255
+ char *cp = out.ptr + t->Concrete.VarDim.offsets->v[v] * t->Concrete.VarDim.itemsize;
256
+ (void)write_path((int32_t *)cp, shape, p, N, u, v);
250
257
  }
251
258
 
252
259
  return out;
253
260
 
254
261
  error:
255
- ndt_free(ndim2_offsets);
256
- ndt_free(ndim1_offsets);
262
+ ndt_decref_offsets(ndim2_offsets);
263
+ ndt_decref_offsets(ndim1_offsets);
257
264
  return xnd_error;
258
265
 
259
266
  offset_overflow:
@@ -82,7 +82,7 @@ pdist(xnd_t stack[], ndt_context_t *ctx)
82
82
  * Validate N, M and compute unknown output dimension P.
83
83
  *
84
84
  * shape[0] = N
85
- * shape[1] = N
85
+ * shape[1] = M
86
86
  * shape[2] = P
87
87
  *
88
88
  * 'args' is unused here. Other functions may inspect the incoming xnd
@@ -73,7 +73,7 @@ gm_func_del(gm_func_t *f)
73
73
  ndt_free(f->name);
74
74
 
75
75
  for (int i = 0; i < f->nkernels; i++) {
76
- ndt_del(f->kernels[i].sig);
76
+ ndt_decref(f->kernels[i].sig);
77
77
  }
78
78
 
79
79
  ndt_free(f);
@@ -101,7 +101,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
101
101
  {
102
102
  gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
103
103
  gm_kernel_set_t kernel;
104
- ndt_t *t;
104
+ const ndt_t *t;
105
105
 
106
106
  if (f == NULL) {
107
107
  ndt_err_clear(ctx);
@@ -117,7 +117,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
117
117
  }
118
118
 
119
119
  if (f->nkernels == GM_MAX_KERNELS) {
120
- ndt_del(t);
120
+ ndt_decref(t);
121
121
  ndt_err_format(ctx, NDT_RuntimeError,
122
122
  "%s: maximum number of kernels reached for", f->name);
123
123
  return -1;
@@ -125,11 +125,13 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
125
125
 
126
126
  kernel.sig = t;
127
127
  kernel.constraint = k->constraint;
128
- kernel.Opt = k->Opt;
128
+ kernel.OptC = k->OptC;
129
+ kernel.OptZ = k->OptZ;
130
+ kernel.OptS = k->OptS;
129
131
  kernel.C = k->C;
130
132
  kernel.Fortran = k->Fortran;
131
- kernel.Strided = k->Strided;
132
133
  kernel.Xnd = k->Xnd;
134
+ kernel.Strided = k->Strided;
133
135
 
134
136
  f->kernels[f->nkernels++] = kernel;
135
137
  return 0;
@@ -141,7 +143,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
141
143
  {
142
144
  gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
143
145
  gm_kernel_set_t kernel;
144
- ndt_t *t;
146
+ const ndt_t *t;
145
147
 
146
148
  if (f == NULL) {
147
149
  ndt_err_clear(ctx);
@@ -158,7 +160,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
158
160
  }
159
161
 
160
162
  if (f->nkernels == GM_MAX_KERNELS) {
161
- ndt_del(t);
163
+ ndt_decref(t);
162
164
  ndt_err_format(ctx, NDT_RuntimeError,
163
165
  "%s: maximum number of kernels reached for", f->name);
164
166
  return -1;
@@ -166,11 +168,13 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
166
168
 
167
169
  kernel.sig = t;
168
170
  kernel.constraint = k->constraint;
169
- kernel.Opt = k->Opt;
171
+ kernel.OptC = k->OptC;
172
+ kernel.OptZ = k->OptZ;
173
+ kernel.OptS = k->OptS;
170
174
  kernel.C = k->C;
171
175
  kernel.Fortran = k->Fortran;
172
- kernel.Strided = k->Strided;
173
176
  kernel.Xnd = k->Xnd;
177
+ kernel.Strided = k->Strided;
174
178
 
175
179
  f->kernels[f->nkernels++] = kernel;
176
180
  return 0;