gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -27,10 +27,14 @@ LIBXNDDIR = ..\xnd\libxnd
27
27
 
28
28
  OPT = /MT /Ox /GS /EHsc /fp:strict
29
29
  OPT_SHARED = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /fp:strict /Fo.objs^\
30
+ OPT_NOFP = /MT /Ox /GS /EHsc
31
+ OPT_SHARED_NOFP = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /Fo.objs^\
30
32
 
31
33
  COMMON_CFLAGS = /nologo /W4 /wd4200 /wd4201 /wd4204
32
34
  CFLAGS = $(COMMON_CFLAGS) $(OPT)
33
35
  CFLAGS_SHARED = $(COMMON_CFLAGS) $(OPT_SHARED)
36
+ CFLAGS_NOFP = $(COMMON_CFLAGS) $(OPT_NOFP)
37
+ CFLAGS_SHARED_NOFP = $(COMMON_CFLAGS) $(OPT_SHARED_NOFP)
34
38
 
35
39
 
36
40
  default: $(LIBSTATIC) $(LIBSHARED)
@@ -40,11 +44,14 @@ default: $(LIBSTATIC) $(LIBSHARED)
40
44
  copy /y $(LIBSHARED) ..\python\gumath
41
45
 
42
46
 
43
- OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj \
44
- unary.obj binary.obj examples.obj graph.obj pdist.obj
47
+ OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj cpu_host_unary.obj \
48
+ cpu_device_unary.obj cpu_host_binary.obj cpu_device_binary.obj cpu_device_msvc.obj \
49
+ common.obj examples.obj graph.obj pdist.obj
45
50
 
46
51
  SHARED_OBJS = .objs/apply.obj .objs/func.obj .objs/nploops.obj .objs/tbl.obj .objs/xndloops.obj \
47
- .objs/unary.obj .objs/binary.obj .objs/examples.obj .objs/graph.obj .objs/pdist.obj
52
+ .objs/cpu_host_unary.obj .objs/cpu_device_unary.obj .objs/cpu_host_binary.obj \
53
+ .objs/cpu_device_binary.obj .objs/cpu_device_msvc.obj .objs/common.obj \
54
+ .objs/examples.obj .objs/graph.obj .objs/pdist.obj
48
55
 
49
56
 
50
57
  $(LIBSTATIC):\
@@ -101,21 +108,53 @@ Makefile xndloops.c gumath.h
101
108
  Makefile xndloops.c gumath.h
102
109
  $(CC) "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c xndloops.c
103
110
 
104
- unary.obj:\
105
- Makefile kernels\unary.c gumath.h
106
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\unary.c
111
+ cpu_host_unary.obj:\
112
+ Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
113
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_unary.c
107
114
 
108
- .objs\unary.obj:\
109
- Makefile kernels\unary.c gumath.h
110
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\unary.c
115
+ .objs\cpu_host_unary.obj:\
116
+ Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
117
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_unary.c
111
118
 
112
- binary.obj:\
113
- Makefile kernels\binary.c gumath.h
114
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\binary.c
119
+ cpu_device_unary.obj:\
120
+ Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
121
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_unary.cc
115
122
 
116
- .objs\binary.obj:\
117
- Makefile kernels\binary.c gumath.h
118
- $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\binary.c
123
+ .objs\cpu_device_unary.obj:\
124
+ Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
125
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_unary.cc
126
+
127
+ cpu_host_binary.obj:\
128
+ Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
129
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_binary.c
130
+
131
+ .objs\cpu_host_binary.obj:\
132
+ Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
133
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_binary.c
134
+
135
+ cpu_device_binary.obj:\
136
+ Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
137
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_binary.cc
138
+
139
+ .objs\cpu_device_binary.obj:\
140
+ Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
141
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_binary.cc
142
+
143
+ cpu_device_msvc.obj:\
144
+ Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
145
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_NOFP) -c kernels\cpu_device_msvc.cc
146
+
147
+ .objs\cpu_device_msvc.obj:\
148
+ Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
149
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED_NOFP) -c kernels\cpu_device_msvc.cc
150
+
151
+ common.obj:\
152
+ Makefile kernels\common.c kernels\common.h gumath.h
153
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\common.c
154
+
155
+ .objs\common.obj:\
156
+ Makefile kernels\common.c kernels\common.h gumath.h
157
+ $(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\common.c
119
158
 
120
159
  examples.obj:\
121
160
  Makefile extending\examples.c gumath.h
@@ -40,6 +40,29 @@
40
40
  #include "gumath.h"
41
41
 
42
42
 
43
+ /* flags that apply to all arguments */
44
+ #define OPT_Z (NDT_EXT_ZERO|NDT_INNER_C)
45
+ #define OPT_C (NDT_EXT_C|NDT_INNER_C)
46
+ #define OPT_S (NDT_EXT_STRIDED|NDT_INNER_STRIDED)
47
+ #define OPT_SC (NDT_EXT_STRIDED|NDT_INNER_C)
48
+ #define OPT_SF (NDT_EXT_STRIDED|NDT_INNER_F)
49
+
50
+ #define INNER_C (NDT_INNER_C)
51
+ #define INNER_F (NDT_INNER_F)
52
+ #define INNER_S (NDT_INNER_STRIDED)
53
+ #define INNER_X (NDT_INNER_XND)
54
+
55
+ /* kernel requests */
56
+ #define REQ_LOOP_C(flags) ((flags&OPT_C) == OPT_C)
57
+ #define REQ_LOOP_Z(flags) ((flags&OPT_C) == OPT_C || (flags&OPT_Z) == OPT_Z)
58
+ #define REQ_LOOP_S(flags) ((flags&OPT_S) == OPT_S)
59
+
60
+ #define REQ_INNER_C(flags) ((flags&INNER_C) == INNER_C)
61
+ #define REQ_INNER_F(flags) ((flags&INNER_F) == INNER_F)
62
+ #define REQ_INNER_S(flags) ((flags&INNER_S) == INNER_S)
63
+ #define REQ_INNER_X(flags) ((flags&INNER_X) == INNER_X)
64
+
65
+
43
66
  static int
44
67
  sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
45
68
  {
@@ -55,6 +78,18 @@ sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
55
78
  return sum;
56
79
  }
57
80
 
81
+ static inline bool
82
+ opt_safe(int outer, ndt_context_t *ctx)
83
+ {
84
+ if (outer == 0) {
85
+ ndt_err_format(ctx, NDT_RuntimeError,
86
+ "internal error: optimized kernel called with outer_dims==0");
87
+ return false;
88
+ }
89
+
90
+ return true;
91
+ }
92
+
58
93
  int
59
94
  gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
60
95
  ndt_context_t *ctx)
@@ -62,30 +97,43 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
62
97
  const int nargs = (int)kernel->set->sig->Function.nargs;
63
98
 
64
99
  switch (kernel->flag) {
65
- case NDT_ELEMWISE_1D: {
66
- if (outer_dims == 0) {
67
- ndt_err_format(ctx, NDT_RuntimeError,
68
- "gm_xnd_map: optimized elementwise kernel called with "
69
- "outer_dims==0");
100
+ case OPT_C: {
101
+ if (!opt_safe(outer_dims, ctx)) {
70
102
  return -1;
71
103
  }
72
104
 
73
- return gm_xnd_map(kernel->set->Opt, stack, nargs, outer_dims-1, ctx);
105
+ return gm_xnd_map(kernel->set->OptC, stack, nargs, outer_dims-1, ctx);
74
106
  }
75
107
 
76
- case NDT_C: {
108
+ case OPT_Z: {
109
+ if (!opt_safe(outer_dims, ctx)) {
110
+ return -1;
111
+ }
112
+
113
+ return gm_xnd_map(kernel->set->OptZ, stack, nargs, outer_dims-1, ctx);
114
+ }
115
+
116
+ case OPT_S: {
117
+ if (!opt_safe(outer_dims, ctx)) {
118
+ return -1;
119
+ }
120
+
121
+ return gm_xnd_map(kernel->set->OptS, stack, nargs, outer_dims-1, ctx);
122
+ }
123
+
124
+ case INNER_C: {
77
125
  return gm_xnd_map(kernel->set->C, stack, nargs, outer_dims, ctx);
78
126
  }
79
127
 
80
- case NDT_FORTRAN: {
128
+ case INNER_F: {
81
129
  return gm_xnd_map(kernel->set->Fortran, stack, nargs, outer_dims, ctx);
82
130
  }
83
131
 
84
- case NDT_XND: {
132
+ case INNER_X: {
85
133
  return gm_xnd_map(kernel->set->Xnd, stack, nargs, outer_dims, ctx);
86
134
  }
87
135
 
88
- case NDT_STRIDED: {
136
+ case INNER_S: {
89
137
  const int sum_inner = sum_inner_dimensions(stack, nargs, outer_dims);
90
138
  const int dims_size = outer_dims + sum_inner;
91
139
  const int steps_size = nargs * outer_dims + sum_inner;
@@ -105,8 +153,9 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
105
153
  }
106
154
  }
107
155
 
108
- /* NOT REACHED: tags should be exhaustive. */
109
- ndt_internal_error("invalid tag");
156
+ ndt_err_format(ctx, NDT_RuntimeError,
157
+ "internal error: kernel selection failed");
158
+ return -1;
110
159
  }
111
160
 
112
161
  static gm_kernel_t
@@ -117,35 +166,48 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
117
166
 
118
167
  kernel.set = set;
119
168
 
120
- if (set->Opt != NULL && (spec->flags&NDT_ELEMWISE_1D)) {
121
- kernel.flag = NDT_ELEMWISE_1D;
169
+ if (REQ_LOOP_C(spec->flags) && set->OptC != NULL) {
170
+ kernel.flag = OPT_C;
171
+ return kernel;
172
+ }
173
+
174
+ if (REQ_LOOP_Z(spec->flags) && set->OptZ != NULL) {
175
+ kernel.flag = OPT_Z;
176
+ return kernel;
177
+ }
178
+
179
+ if (REQ_LOOP_S(spec->flags) && set->OptS != NULL) {
180
+ kernel.flag = OPT_S;
122
181
  return kernel;
123
182
  }
124
183
 
125
- if (set->C != NULL && (spec->flags&NDT_C)) {
126
- kernel.flag = NDT_C;
184
+ if (REQ_INNER_C(spec->flags) && set->C != NULL) {
185
+ kernel.flag = INNER_C;
127
186
  return kernel;
128
187
  }
129
188
 
130
- if (set->Fortran != NULL && (spec->flags&NDT_FORTRAN)) {
131
- kernel.flag = NDT_FORTRAN;
189
+ if (REQ_INNER_F(spec->flags) && set->Fortran != NULL) {
190
+ kernel.flag = INNER_F;
132
191
  return kernel;
133
192
  }
134
193
 
135
- if (set->Strided != NULL && (spec->flags&NDT_STRIDED)) {
136
- kernel.flag = NDT_STRIDED;
194
+ if (REQ_INNER_S(spec->flags) && set->Strided != NULL) {
195
+ kernel.flag = INNER_S;
137
196
  return kernel;
138
197
  }
139
198
 
140
- if (set->Xnd != NULL && (spec->flags&NDT_XND)) {
141
- kernel.flag = NDT_XND;
199
+ if (REQ_INNER_X(spec->flags) && set->Xnd != NULL) {
200
+ kernel.flag = INNER_X;
142
201
  return kernel;
143
202
  }
144
203
 
145
204
  kernel.set = NULL;
146
205
  ndt_err_format(ctx, NDT_RuntimeError,
147
- "could not find specialized kernel for '%s' input (available: %s, %s, %s, %s)",
206
+ "could not find specialized kernel for '%s' input (available: %s, %s, %s, %s, %s, %s, %s)",
148
207
  ndt_apply_flags_as_string(spec),
208
+ set->OptC ? "OptC" : "_",
209
+ set->OptZ ? "OptZ" : "_",
210
+ set->OptS ? "OptS" : "_",
149
211
  set->C ? "C" : "_",
150
212
  set->Fortran ? "Fortran" : "_",
151
213
  set->Xnd ? "Xnd" : "_",
@@ -157,8 +219,8 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
157
219
  /* Look up a multimethod by name and select a kernel. */
158
220
  gm_kernel_t
159
221
  gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
160
- const ndt_t *in_types[], int nin, const xnd_t args[],
161
- ndt_context_t *ctx)
222
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
223
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx)
162
224
  {
163
225
  gm_kernel_t empty_kernel = {0U, NULL};
164
226
  const gm_func_t *f;
@@ -171,7 +233,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
171
233
  }
172
234
 
173
235
  if (f->typecheck != NULL) {
174
- const gm_kernel_set_t *set = f->typecheck(spec, f, in_types, nin, ctx);
236
+ const gm_kernel_set_t *set = f->typecheck(spec, f, types, li, nin, nout,
237
+ check_broadcast, ctx);
175
238
  if (set == NULL) {
176
239
  return empty_kernel;
177
240
  }
@@ -180,7 +243,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
180
243
 
181
244
  for (i = 0; i < f->nkernels; i++) {
182
245
  const gm_kernel_set_t *set = &f->kernels[i];
183
- if (ndt_typecheck(spec, set->sig, in_types, nin, set->constraint, args,
246
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
247
+ check_broadcast, set->constraint, args,
184
248
  ctx) < 0) {
185
249
  ndt_err_clear(ctx);
186
250
  continue;
@@ -188,7 +252,7 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
188
252
  return select_kernel(spec, set, ctx);
189
253
  }
190
254
 
191
- s = ndt_list_as_string(in_types, nin, ctx);
255
+ s = ndt_list_as_string(types, nin, ctx);
192
256
  if (s == NULL) {
193
257
  return empty_kernel;
194
258
  }
@@ -181,9 +181,10 @@ static xnd_t
181
181
  mk_return_array(int32_t p[], const int64_t N, const int32_t u,
182
182
  ndt_context_t *ctx)
183
183
  {
184
- int32_t *ndim2_offsets = NULL;
185
- int32_t *ndim1_offsets = NULL;
186
- ndt_t *t;
184
+ ndt_offsets_t *ndim2_offsets = NULL;
185
+ ndt_offsets_t *ndim1_offsets = NULL;
186
+ int32_t *ptr;
187
+ const ndt_t *t, *type;
187
188
  int64_t sum;
188
189
  int32_t v;
189
190
 
@@ -191,45 +192,51 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
191
192
  goto offset_overflow;
192
193
  }
193
194
 
194
- ndim2_offsets = ndt_alloc(2, sizeof *ndim2_offsets);
195
+ ndim2_offsets = ndt_offsets_new(2, ctx);
195
196
  if (ndim2_offsets == NULL) {
196
- (void)ndt_memory_error(ctx);
197
197
  return xnd_error;
198
198
  }
199
- ndim2_offsets[0] = 0;
200
- ndim2_offsets[1] = (int32_t)N;
199
+ ptr = (int32_t *)ndim2_offsets->v;
200
+ ptr[0] = 0;
201
+ ptr[1] = (int32_t)N;
201
202
 
202
203
 
203
- ndim1_offsets = ndt_alloc(N+1, sizeof *ndim1_offsets);
204
+ ndim1_offsets = ndt_offsets_new((int32_t)(N+1), ctx);
204
205
  if (ndim1_offsets == NULL) {
205
- (void)ndt_memory_error(ctx);
206
+ ndt_decref_offsets(ndim2_offsets);
206
207
  return xnd_error;
207
208
  }
208
209
 
209
210
  sum = 0;
211
+ ptr = (int32_t *)ndim1_offsets->v;
210
212
  for (v = 0; v < N; v++) {
211
- ndim1_offsets[v] = (int32_t)sum;
213
+ ptr[v] = (int32_t)sum;
212
214
  int64_t n = write_path(NULL, 0, p, N, u, v);
213
215
  sum += n;
214
216
  if (sum > INT32_MAX) {
215
217
  goto offset_overflow;
216
218
  }
217
219
  }
218
- ndim1_offsets[v] = (int32_t)sum;
220
+ ptr[v] = (int32_t)sum;
219
221
 
220
222
 
221
- t = ndt_from_string("node", ctx);
222
- if (t == NULL) {
223
+ type = ndt_from_string("node", ctx);
224
+ if (type == NULL) {
223
225
  goto error;
224
226
  }
225
227
 
226
- t = ndt_var_dim(t, InternalOffsets, (int32_t)(N+1), ndim1_offsets, 0, NULL, ctx);
228
+ t = ndt_var_dim(type, ndim1_offsets, 0, NULL, false, ctx);
229
+ ndt_decref(type);
230
+ ndt_decref_offsets(ndim1_offsets);
227
231
  ndim1_offsets = NULL;
228
232
  if (t == NULL) {
229
233
  goto error;
230
234
  }
235
+ type = t;
231
236
 
232
- t = ndt_var_dim(t, InternalOffsets, 2, ndim2_offsets, 0, NULL, ctx);
237
+ t = ndt_var_dim(type, ndim2_offsets, 0, NULL, false, ctx);
238
+ ndt_decref(type);
239
+ ndt_decref_offsets(ndim2_offsets);
233
240
  ndim2_offsets = NULL;
234
241
  if (t == NULL) {
235
242
  goto error;
@@ -244,16 +251,16 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
244
251
 
245
252
  t = out.type->VarDim.type;
246
253
  for (v = 0; v < N; v++) {
247
- int32_t shape = t->Concrete.VarDim.offsets[v+1]-t->Concrete.VarDim.offsets[v];
248
- char *ptr = out.ptr + t->Concrete.VarDim.offsets[v] * t->Concrete.VarDim.itemsize;
249
- (void)write_path((int32_t *)ptr, shape, p, N, u, v);
254
+ int32_t shape = t->Concrete.VarDim.offsets->v[v+1]-t->Concrete.VarDim.offsets->v[v];
255
+ char *cp = out.ptr + t->Concrete.VarDim.offsets->v[v] * t->Concrete.VarDim.itemsize;
256
+ (void)write_path((int32_t *)cp, shape, p, N, u, v);
250
257
  }
251
258
 
252
259
  return out;
253
260
 
254
261
  error:
255
- ndt_free(ndim2_offsets);
256
- ndt_free(ndim1_offsets);
262
+ ndt_decref_offsets(ndim2_offsets);
263
+ ndt_decref_offsets(ndim1_offsets);
257
264
  return xnd_error;
258
265
 
259
266
  offset_overflow:
@@ -82,7 +82,7 @@ pdist(xnd_t stack[], ndt_context_t *ctx)
82
82
  * Validate N, M and compute unknown output dimension P.
83
83
  *
84
84
  * shape[0] = N
85
- * shape[1] = N
85
+ * shape[1] = M
86
86
  * shape[2] = P
87
87
  *
88
88
  * 'args' is unused here. Other functions may inspect the incoming xnd
@@ -73,7 +73,7 @@ gm_func_del(gm_func_t *f)
73
73
  ndt_free(f->name);
74
74
 
75
75
  for (int i = 0; i < f->nkernels; i++) {
76
- ndt_del(f->kernels[i].sig);
76
+ ndt_decref(f->kernels[i].sig);
77
77
  }
78
78
 
79
79
  ndt_free(f);
@@ -101,7 +101,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
101
101
  {
102
102
  gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
103
103
  gm_kernel_set_t kernel;
104
- ndt_t *t;
104
+ const ndt_t *t;
105
105
 
106
106
  if (f == NULL) {
107
107
  ndt_err_clear(ctx);
@@ -117,7 +117,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
117
117
  }
118
118
 
119
119
  if (f->nkernels == GM_MAX_KERNELS) {
120
- ndt_del(t);
120
+ ndt_decref(t);
121
121
  ndt_err_format(ctx, NDT_RuntimeError,
122
122
  "%s: maximum number of kernels reached for", f->name);
123
123
  return -1;
@@ -125,11 +125,13 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
125
125
 
126
126
  kernel.sig = t;
127
127
  kernel.constraint = k->constraint;
128
- kernel.Opt = k->Opt;
128
+ kernel.OptC = k->OptC;
129
+ kernel.OptZ = k->OptZ;
130
+ kernel.OptS = k->OptS;
129
131
  kernel.C = k->C;
130
132
  kernel.Fortran = k->Fortran;
131
- kernel.Strided = k->Strided;
132
133
  kernel.Xnd = k->Xnd;
134
+ kernel.Strided = k->Strided;
133
135
 
134
136
  f->kernels[f->nkernels++] = kernel;
135
137
  return 0;
@@ -141,7 +143,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
141
143
  {
142
144
  gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
143
145
  gm_kernel_set_t kernel;
144
- ndt_t *t;
146
+ const ndt_t *t;
145
147
 
146
148
  if (f == NULL) {
147
149
  ndt_err_clear(ctx);
@@ -158,7 +160,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
158
160
  }
159
161
 
160
162
  if (f->nkernels == GM_MAX_KERNELS) {
161
- ndt_del(t);
163
+ ndt_decref(t);
162
164
  ndt_err_format(ctx, NDT_RuntimeError,
163
165
  "%s: maximum number of kernels reached for", f->name);
164
166
  return -1;
@@ -166,11 +168,13 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
166
168
 
167
169
  kernel.sig = t;
168
170
  kernel.constraint = k->constraint;
169
- kernel.Opt = k->Opt;
171
+ kernel.OptC = k->OptC;
172
+ kernel.OptZ = k->OptZ;
173
+ kernel.OptS = k->OptS;
170
174
  kernel.C = k->C;
171
175
  kernel.Fortran = k->Fortran;
172
- kernel.Strided = k->Strided;
173
176
  kernel.Xnd = k->Xnd;
177
+ kernel.Strided = k->Strided;
174
178
 
175
179
  f->kernels[f->nkernels++] = kernel;
176
180
  return 0;