gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -34,6 +34,17 @@
34
34
  #ifndef GUMATH_H
35
35
  #define GUMATH_H
36
36
 
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
37
48
  #include "ndtypes.h"
38
49
  #include "xnd.h"
39
50
 
@@ -65,7 +76,8 @@
65
76
  #endif
66
77
 
67
78
 
68
- #define GM_MAX_KERNELS 512
79
+ #define GM_MAX_KERNELS 8192
80
+ #define GM_THREAD_CUTOFF 1000000
69
81
 
70
82
  typedef float float32_t;
71
83
  typedef double float64_t;
@@ -74,15 +86,25 @@ typedef double float64_t;
74
86
  typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
75
87
  typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
76
88
 
77
- /* Collection of specialized kernels for a single function signature. */
89
+ /*
90
+ * Collection of specialized kernels for a single function signature.
91
+ *
92
+ * NOTE: The specialized kernel lookup scheme is transitional and may
93
+ * be replaced by something else.
94
+ *
95
+ * This should be considered as a first version of a kernel request
96
+ * protocol.
97
+ */
78
98
  typedef struct {
79
- ndt_t *sig;
99
+ const ndt_t *sig;
80
100
  const ndt_constraint_t *constraint;
81
101
 
82
102
  /* Xnd signatures */
83
- gm_xnd_kernel_t Opt; /* dispatch ensures elementwise, at least 1D, contiguous in last dimensions */
84
- gm_xnd_kernel_t C; /* dispatch ensures c-contiguous in inner dimensions */
85
- gm_xnd_kernel_t Fortran; /* dispatch ensures f-contiguous in inner dimensions */
103
+ gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
104
+ gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
105
+ gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
106
+ gm_xnd_kernel_t C; /* C in inner dimensions */
107
+ gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
86
108
  gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
87
109
 
88
110
  /* NumPy signature */
@@ -99,11 +121,17 @@ typedef struct {
99
121
  const char *name;
100
122
  const char *sig;
101
123
  const ndt_constraint_t *constraint;
124
+ uint32_t cap;
102
125
 
103
- gm_xnd_kernel_t Opt;
126
+ /* Xnd signatures */
127
+ gm_xnd_kernel_t OptC;
128
+ gm_xnd_kernel_t OptZ;
129
+ gm_xnd_kernel_t OptS;
104
130
  gm_xnd_kernel_t C;
105
131
  gm_xnd_kernel_t Fortran;
106
132
  gm_xnd_kernel_t Xnd;
133
+
134
+ /* NumPy signature */
107
135
  gm_strided_kernel_t Strided;
108
136
  } gm_kernel_init_t;
109
137
 
@@ -115,7 +143,10 @@ typedef struct {
115
143
 
116
144
  /* Multimethod with associated kernels */
117
145
  typedef struct gm_func gm_func_t;
118
- typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *in[], int nin, ndt_context_t *ctx);
146
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
147
+ const ndt_t *in[], const int64_t li[],
148
+ int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
119
150
  struct gm_func {
120
151
  char *name;
121
152
  gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
139
170
  GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
140
171
 
141
172
  GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
142
- const ndt_t *in_types[], int nin, const xnd_t args[],
143
- ndt_context_t *ctx);
173
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
174
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
144
175
  GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
145
- GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, uint32_t flags, const int64_t nthreads, ndt_context_t *ctx);
176
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
146
177
 
147
178
 
148
179
  /******************************************************************************/
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
171
202
  /* Xnd loops */
172
203
  /******************************************************************************/
173
204
 
205
+ GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
174
206
  GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
175
207
  const int outer_dims, ndt_context_t *ctx);
176
208
 
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
191
223
  /******************************************************************************/
192
224
 
193
225
  GM_API void gm_init(void);
194
- GM_API int gm_init_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
195
- GM_API int gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
226
+ GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
227
+ GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
228
+ GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
229
+
230
+ GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
231
+ GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
232
+
196
233
  GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
197
- GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
198
234
  GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
199
235
  GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
200
236
  GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
202
238
  GM_API void gm_finalize(void);
203
239
 
204
240
 
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
205
246
  #endif /* GUMATH_H */
@@ -0,0 +1,513 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdlib.h>
35
+ #include <stdint.h>
36
+ #include <string.h>
37
+ #include <math.h>
38
+ #include <complex.h>
39
+ #include <inttypes.h>
40
+ #include "ndtypes.h"
41
+ #include "xnd.h"
42
+ #include "gumath.h"
43
+ #include "common.h"
44
+
45
+
46
+ /****************************************************************************/
47
+ /* Unary bitmap kernels */
48
+ /****************************************************************************/
49
+
50
+ void
51
+ unary_update_bitmap_1D_S(xnd_t stack[])
52
+ {
53
+ const int64_t N = xnd_fixed_shape(&stack[0]);
54
+ const int64_t li0 = stack[0].index;
55
+ const int64_t li1 = stack[1].index;
56
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
57
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
58
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
59
+ uint8_t *b1 = get_bitmap1D(&stack[1]);
60
+ int64_t i, k0, k1;
61
+
62
+ assert(b0 != NULL);
63
+ assert(b1 != NULL);
64
+
65
+ for (i=0, k0=li0, k1=li1; i<N; i++, k0+=s0, k1+=s1) {
66
+ bool x = is_valid(b0, k0);
67
+ set_bit(b1, k1, x);
68
+ }
69
+ }
70
+
71
+ void
72
+ unary_reduce_bitmap_1D_S(xnd_t stack[])
73
+ {
74
+ const int64_t N = xnd_fixed_shape(&stack[0]);
75
+ const int64_t li0 = stack[0].index;
76
+ const int64_t li1 = stack[1].index;
77
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
78
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
79
+ uint8_t *b1 = get_bitmap(&stack[1]);
80
+ int64_t i, k0;
81
+
82
+ assert(b0 != NULL);
83
+ assert(b1 != NULL);
84
+
85
+ for (i=0, k0=li0; i<N; i++, k0+=s0) {
86
+ bool x = is_valid(b0, k0) && is_valid(b1, li1);
87
+ set_bit(b1, li1, x);
88
+ }
89
+ }
90
+
91
+ void
92
+ unary_update_bitmap_0D(xnd_t stack[])
93
+ {
94
+ const int64_t li0 = stack[0].index;
95
+ const int64_t li1 = stack[1].index;
96
+ const uint8_t *b0 = get_bitmap(&stack[0]);
97
+ uint8_t *b1 = get_bitmap(&stack[1]);
98
+
99
+ assert(b0 != NULL);
100
+ assert(b1 != NULL);
101
+
102
+ bool x = is_valid(b0, li0);
103
+ set_bit(b1, li1, x);
104
+ }
105
+
106
+
107
+ /****************************************************************************/
108
+ /* Binary bitmap kernels */
109
+ /****************************************************************************/
110
+
111
+ void
112
+ binary_update_bitmap_1D_S(xnd_t stack[])
113
+ {
114
+ const int64_t N = xnd_fixed_shape(&stack[0]);
115
+ const int64_t li0 = stack[0].index;
116
+ const int64_t li1 = stack[1].index;
117
+ const int64_t li2 = stack[2].index;
118
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
119
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
120
+ const int64_t s2 = xnd_fixed_step(&stack[2]);
121
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
122
+ const uint8_t *b1 = get_bitmap1D(&stack[1]);
123
+ uint8_t *b2 = get_bitmap1D(&stack[2]);
124
+ int64_t i, k0, k1, k2;
125
+
126
+ if (b0 && b1) {
127
+ for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
128
+ bool x = is_valid(b0, k0) && is_valid(b1, k1);
129
+ set_bit(b2, k2, x);
130
+ }
131
+ }
132
+ else if (b0) {
133
+ for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
134
+ bool x = is_valid(b0, k0);
135
+ set_bit(b2, k2, x);
136
+ }
137
+ }
138
+ else if (b1) {
139
+ for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
140
+ bool x = is_valid(b1, k1);
141
+ set_bit(b2, k2, x);
142
+ }
143
+ }
144
+ }
145
+
146
+ void
147
+ binary_update_bitmap_0D(xnd_t stack[])
148
+ {
149
+ const int64_t li0 = stack[0].index;
150
+ const int64_t li1 = stack[1].index;
151
+ const int64_t li2 = stack[2].index;
152
+ const uint8_t *b0 = get_bitmap(&stack[0]);
153
+ const uint8_t *b1 = get_bitmap(&stack[1]);
154
+ uint8_t *b2 = get_bitmap(&stack[2]);
155
+
156
+ assert(b2 != NULL);
157
+
158
+ if (b0 && b1) {
159
+ bool x = is_valid(b0, li0) && is_valid(b1, li1);
160
+ set_bit(b2, li2, x);
161
+ }
162
+ else if (b0) {
163
+ bool x = is_valid(b0, li0);
164
+ set_bit(b2, li2, x);
165
+ }
166
+ else if (b1) {
167
+ bool x = is_valid(b1, li1);
168
+ set_bit(b2, li2, x);
169
+ }
170
+ }
171
+
172
+ void
173
+ binary_update_bitmap_1D_S_bool(xnd_t stack[])
174
+ {
175
+ const int64_t N = xnd_fixed_shape(&stack[0]);
176
+ const int64_t li0 = stack[0].index;
177
+ const int64_t li1 = stack[1].index;
178
+ const int64_t li2 = stack[2].index;
179
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
180
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
181
+ const int64_t s2 = xnd_fixed_step(&stack[2]);
182
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
183
+ const uint8_t *b1 = get_bitmap1D(&stack[1]);
184
+ bool *x2 = (bool *)apply_index(&stack[2]);
185
+ int64_t i, k0, k1, k2;
186
+
187
+ assert(!ndt_is_optional(stack[2].type));
188
+
189
+ if (b0 && b1) {
190
+ for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
191
+ bool x = is_valid(b0, k0);
192
+ bool y = is_valid(b1, k1);
193
+ bool z = x2[k2];
194
+ z = x && y ? z : !x && !y;
195
+ x2[k2] = z;
196
+ }
197
+ }
198
+ else if (b0) {
199
+ for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
200
+ bool x = is_valid(b0, k0);
201
+ bool z = x2[k2];
202
+ z = x ? z : x;
203
+ x2[k2] = z;
204
+ }
205
+ }
206
+ else if (b1) {
207
+ for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
208
+ bool x = is_valid(b1, k1);
209
+ bool z = x2[k2];
210
+ z = x ? z : x;
211
+ x2[k2] = z;
212
+ }
213
+ }
214
+ }
215
+
216
+ void
217
+ binary_update_bitmap_0D_bool(xnd_t stack[])
218
+ {
219
+ const int64_t li0 = stack[0].index;
220
+ const int64_t li1 = stack[1].index;
221
+ const int64_t li2 = stack[2].index;
222
+ const uint8_t *b0 = get_bitmap(&stack[0]);
223
+ const uint8_t *b1 = get_bitmap(&stack[1]);
224
+ bool *x2 = (bool *)stack[2].ptr;
225
+
226
+ assert(!ndt_is_optional(stack[2].type));
227
+
228
+ if (b0 && b1) {
229
+ bool x = is_valid(b0, li0);
230
+ bool y = is_valid(b1, li1);
231
+ bool z = x2[li2];
232
+ z = x && y ? z : !x && !y;
233
+ x2[li2] = z;
234
+ }
235
+ else if (b0) {
236
+ bool x = is_valid(b0, li0);
237
+ bool z = x2[li2];
238
+ z = x ? z : x;
239
+ x2[li2] = z;
240
+ }
241
+ else if (b1) {
242
+ bool x = is_valid(b1, li1);
243
+ bool z = x2[li2];
244
+ z = x ? z : x;
245
+ x2[li2] = z;
246
+ }
247
+ }
248
+
249
+
250
+ /****************************************************************************/
251
+ /* Optimized unary typecheck */
252
+ /****************************************************************************/
253
+
254
+ const gm_kernel_set_t *
255
+ cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
256
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
257
+ const int64_t li[], int nin, int nout, bool check_broadcast,
258
+ ndt_context_t *ctx)
259
+ {
260
+ const ndt_t *t;
261
+ const ndt_t *u;
262
+ int n;
263
+
264
+ assert(spec->flags == 0);
265
+ assert(spec->outer_dims == 0);
266
+ assert(spec->nin == 0);
267
+ assert(spec->nout == 0);
268
+ assert(spec->nargs == 0);
269
+
270
+ if (nin != 1) {
271
+ ndt_err_format(ctx, NDT_ValueError,
272
+ "invalid number of arguments for %s(x): expected 1, got %d",
273
+ f->name, nin);
274
+ return NULL;
275
+
276
+ }
277
+
278
+ t = types[0];
279
+
280
+ if (nout) {
281
+ if (nout != 1) {
282
+ ndt_err_format(ctx, NDT_ValueError,
283
+ "%s(x) expects at most one 'out' argument, got %d",
284
+ f->name, nout);
285
+ return NULL;
286
+ }
287
+ u = types[1];
288
+ }
289
+ else {
290
+ u = types[0];
291
+ }
292
+
293
+ assert(ndt_is_concrete(t));
294
+ assert(ndt_is_concrete(u));
295
+
296
+ n = kernel_location(t, u, ctx);
297
+ if (n < 0) {
298
+ return NULL;
299
+ }
300
+ if (ndt_is_optional(ndt_dtype(t))) {
301
+ n++;
302
+ }
303
+
304
+ if (t->tag == VarDim || t->tag == VarDimElem) {
305
+ const gm_kernel_set_t *set = &f->kernels[n+2];
306
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
307
+ check_broadcast, NULL, NULL, ctx) < 0) {
308
+ return NULL;
309
+ }
310
+ return set;
311
+ }
312
+
313
+ if (t->tag == Array) {
314
+ const gm_kernel_set_t *set = &f->kernels[n+4];
315
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
316
+ check_broadcast, NULL, NULL, ctx) < 0) {
317
+ return NULL;
318
+ }
319
+ return set;
320
+ }
321
+
322
+ const gm_kernel_set_t *set = &f->kernels[n];
323
+
324
+ if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
325
+ check_broadcast, ctx) < 0) {
326
+ return NULL;
327
+ }
328
+
329
+ return set;
330
+ }
331
+
332
+ const gm_kernel_set_t *
333
+ cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
334
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
335
+ const int64_t li[], int nin, int nout, bool check_broadcast,
336
+ ndt_context_t *ctx)
337
+ {
338
+ const ndt_t *t;
339
+ const ndt_t *u;
340
+ int n;
341
+ (void)li;
342
+
343
+ assert(spec->flags == 0);
344
+ assert(spec->outer_dims == 0);
345
+ assert(spec->nin == 0);
346
+ assert(spec->nout == 0);
347
+ assert(spec->nargs == 0);
348
+
349
+ if (nin != 1) {
350
+ ndt_err_format(ctx, NDT_ValueError,
351
+ "invalid number of arguments for %s(x): expected 1, got %d",
352
+ f->name, nin);
353
+ return NULL;
354
+ }
355
+
356
+ t = types[0];
357
+
358
+ if (nout) {
359
+ if (nout != 1) {
360
+ ndt_err_format(ctx, NDT_ValueError,
361
+ "%s(x) expects at most one 'out' argument, got %d",
362
+ f->name, nout);
363
+ return NULL;
364
+ }
365
+ u = types[1];
366
+ }
367
+ else {
368
+ u = types[0];
369
+ }
370
+
371
+ assert(ndt_is_concrete(t));
372
+ assert(ndt_is_concrete(u));
373
+
374
+ n = kernel_location(t, u, ctx);
375
+ if (n < 0) {
376
+ return NULL;
377
+ }
378
+ if (ndt_is_optional(ndt_dtype(t))) {
379
+ n++;
380
+ }
381
+
382
+ const gm_kernel_set_t *set = &f->kernels[n];
383
+
384
+ if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
385
+ check_broadcast, ctx) < 0) {
386
+ return NULL;
387
+ }
388
+
389
+ return set;
390
+ }
391
+
392
+
393
+ /****************************************************************************/
394
+ /* Optimized binary typecheck */
395
+ /****************************************************************************/
396
+
397
+ const gm_kernel_set_t *
398
+ cpu_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
399
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
400
+ const int64_t li[], int nin, int nout, bool check_broadcast,
401
+ ndt_context_t *ctx)
402
+ {
403
+ const ndt_t *t0;
404
+ const ndt_t *t1;
405
+ int n;
406
+
407
+ assert(spec->flags == 0);
408
+ assert(spec->outer_dims == 0);
409
+ assert(spec->nin == 0);
410
+ assert(spec->nout == 0);
411
+ assert(spec->nargs == 0);
412
+
413
+ if (nin != 2) {
414
+ ndt_err_format(ctx, NDT_ValueError,
415
+ "invalid number of arguments for %s(x, y): expected 2, got %d",
416
+ f->name, nin);
417
+ return NULL;
418
+ }
419
+
420
+ t0 = types[0];
421
+ t1 = types[1];
422
+ assert(ndt_is_concrete(t0));
423
+ assert(ndt_is_concrete(t1));
424
+
425
+ n = kernel_location(t0, t1, ctx);
426
+ if (n < 0) {
427
+ return NULL;
428
+ }
429
+ if (ndt_is_optional(ndt_dtype(t0))) {
430
+ n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
431
+ }
432
+ else if (ndt_is_optional(ndt_dtype(t1))) {
433
+ n = n+2;
434
+ }
435
+
436
+ if (t0->tag == VarDim || t0->tag == VarDimElem ||
437
+ t1->tag == VarDim || t1->tag == VarDimElem) {
438
+ const gm_kernel_set_t *set = &f->kernels[n+4];
439
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
440
+ check_broadcast, NULL, NULL, ctx) < 0) {
441
+ return NULL;
442
+ }
443
+ return set;
444
+ }
445
+
446
+ if (t0->tag == Array || t1->tag == Array) {
447
+ const gm_kernel_set_t *set = &f->kernels[n+8];
448
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
449
+ check_broadcast, NULL, NULL, ctx) < 0) {
450
+ return NULL;
451
+ }
452
+ return set;
453
+ }
454
+
455
+ const gm_kernel_set_t *set = &f->kernels[n];
456
+
457
+ if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
458
+ check_broadcast, ctx) < 0) {
459
+ return NULL;
460
+ }
461
+
462
+ return set;
463
+ }
464
+
465
+ const gm_kernel_set_t *
466
+ cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
467
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
468
+ const int64_t li[], int nin, int nout, bool check_broadcast,
469
+ ndt_context_t *ctx)
470
+ {
471
+ const ndt_t *t0;
472
+ const ndt_t *t1;
473
+ int n;
474
+ (void)li;
475
+
476
+ assert(spec->flags == 0);
477
+ assert(spec->outer_dims == 0);
478
+ assert(spec->nin == 0);
479
+ assert(spec->nout == 0);
480
+ assert(spec->nargs == 0);
481
+
482
+ if (nin != 2) {
483
+ ndt_err_format(ctx, NDT_ValueError,
484
+ "invalid number of arguments for %s(x, y): expected 2, got %d",
485
+ f->name, nin);
486
+ return NULL;
487
+ }
488
+
489
+ t0 = types[0];
490
+ t1 = types[1];
491
+ assert(ndt_is_concrete(t0));
492
+ assert(ndt_is_concrete(t1));
493
+
494
+ n = kernel_location(t0, t1, ctx);
495
+ if (n < 0) {
496
+ return NULL;
497
+ }
498
+ if (ndt_is_optional(ndt_dtype(t0))) {
499
+ n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
500
+ }
501
+ else if (ndt_is_optional(ndt_dtype(t1))) {
502
+ n = n+2;
503
+ }
504
+
505
+ const gm_kernel_set_t *set = &f->kernels[n];
506
+
507
+ if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
508
+ check_broadcast, ctx) < 0) {
509
+ return NULL;
510
+ }
511
+
512
+ return set;
513
+ }