gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -34,6 +34,17 @@
34
34
  #ifndef GUMATH_H
35
35
  #define GUMATH_H
36
36
 
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
37
48
  #include "ndtypes.h"
38
49
  #include "xnd.h"
39
50
 
@@ -65,7 +76,8 @@
65
76
  #endif
66
77
 
67
78
 
68
- #define GM_MAX_KERNELS 512
79
+ #define GM_MAX_KERNELS 8192
80
+ #define GM_THREAD_CUTOFF 1000000
69
81
 
70
82
  typedef float float32_t;
71
83
  typedef double float64_t;
@@ -74,15 +86,25 @@ typedef double float64_t;
74
86
  typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
75
87
  typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
76
88
 
77
- /* Collection of specialized kernels for a single function signature. */
89
+ /*
90
+ * Collection of specialized kernels for a single function signature.
91
+ *
92
+ * NOTE: The specialized kernel lookup scheme is transitional and may
93
+ * be replaced by something else.
94
+ *
95
+ * This should be considered as a first version of a kernel request
96
+ * protocol.
97
+ */
78
98
  typedef struct {
79
- ndt_t *sig;
99
+ const ndt_t *sig;
80
100
  const ndt_constraint_t *constraint;
81
101
 
82
102
  /* Xnd signatures */
83
- gm_xnd_kernel_t Opt; /* dispatch ensures elementwise, at least 1D, contiguous in last dimensions */
84
- gm_xnd_kernel_t C; /* dispatch ensures c-contiguous in inner dimensions */
85
- gm_xnd_kernel_t Fortran; /* dispatch ensures f-contiguous in inner dimensions */
103
+ gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
104
+ gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
105
+ gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
106
+ gm_xnd_kernel_t C; /* C in inner dimensions */
107
+ gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
86
108
  gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
87
109
 
88
110
  /* NumPy signature */
@@ -99,11 +121,17 @@ typedef struct {
99
121
  const char *name;
100
122
  const char *sig;
101
123
  const ndt_constraint_t *constraint;
124
+ uint32_t cap;
102
125
 
103
- gm_xnd_kernel_t Opt;
126
+ /* Xnd signatures */
127
+ gm_xnd_kernel_t OptC;
128
+ gm_xnd_kernel_t OptZ;
129
+ gm_xnd_kernel_t OptS;
104
130
  gm_xnd_kernel_t C;
105
131
  gm_xnd_kernel_t Fortran;
106
132
  gm_xnd_kernel_t Xnd;
133
+
134
+ /* NumPy signature */
107
135
  gm_strided_kernel_t Strided;
108
136
  } gm_kernel_init_t;
109
137
 
@@ -115,7 +143,10 @@ typedef struct {
115
143
 
116
144
  /* Multimethod with associated kernels */
117
145
  typedef struct gm_func gm_func_t;
118
- typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *in[], int nin, ndt_context_t *ctx);
146
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
147
+ const ndt_t *in[], const int64_t li[],
148
+ int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
119
150
  struct gm_func {
120
151
  char *name;
121
152
  gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
139
170
  GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
140
171
 
141
172
  GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
142
- const ndt_t *in_types[], int nin, const xnd_t args[],
143
- ndt_context_t *ctx);
173
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
174
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
144
175
  GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
145
- GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, uint32_t flags, const int64_t nthreads, ndt_context_t *ctx);
176
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
146
177
 
147
178
 
148
179
  /******************************************************************************/
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
171
202
  /* Xnd loops */
172
203
  /******************************************************************************/
173
204
 
205
+ GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
174
206
  GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
175
207
  const int outer_dims, ndt_context_t *ctx);
176
208
 
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
191
223
  /******************************************************************************/
192
224
 
193
225
  GM_API void gm_init(void);
194
- GM_API int gm_init_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
195
- GM_API int gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
226
+ GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
227
+ GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
228
+ GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
229
+
230
+ GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
231
+ GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
232
+
196
233
  GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
197
- GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
198
234
  GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
199
235
  GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
200
236
  GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
202
238
  GM_API void gm_finalize(void);
203
239
 
204
240
 
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
205
246
  #endif /* GUMATH_H */
@@ -0,0 +1,513 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdlib.h>
35
+ #include <stdint.h>
36
+ #include <string.h>
37
+ #include <math.h>
38
+ #include <complex.h>
39
+ #include <inttypes.h>
40
+ #include "ndtypes.h"
41
+ #include "xnd.h"
42
+ #include "gumath.h"
43
+ #include "common.h"
44
+
45
+
46
+ /****************************************************************************/
47
+ /* Unary bitmap kernels */
48
+ /****************************************************************************/
49
+
50
+ void
51
+ unary_update_bitmap_1D_S(xnd_t stack[])
52
+ {
53
+ const int64_t N = xnd_fixed_shape(&stack[0]);
54
+ const int64_t li0 = stack[0].index;
55
+ const int64_t li1 = stack[1].index;
56
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
57
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
58
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
59
+ uint8_t *b1 = get_bitmap1D(&stack[1]);
60
+ int64_t i, k0, k1;
61
+
62
+ assert(b0 != NULL);
63
+ assert(b1 != NULL);
64
+
65
+ for (i=0, k0=li0, k1=li1; i<N; i++, k0+=s0, k1+=s1) {
66
+ bool x = is_valid(b0, k0);
67
+ set_bit(b1, k1, x);
68
+ }
69
+ }
70
+
71
+ void
72
+ unary_reduce_bitmap_1D_S(xnd_t stack[])
73
+ {
74
+ const int64_t N = xnd_fixed_shape(&stack[0]);
75
+ const int64_t li0 = stack[0].index;
76
+ const int64_t li1 = stack[1].index;
77
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
78
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
79
+ uint8_t *b1 = get_bitmap(&stack[1]);
80
+ int64_t i, k0;
81
+
82
+ assert(b0 != NULL);
83
+ assert(b1 != NULL);
84
+
85
+ for (i=0, k0=li0; i<N; i++, k0+=s0) {
86
+ bool x = is_valid(b0, k0) && is_valid(b1, li1);
87
+ set_bit(b1, li1, x);
88
+ }
89
+ }
90
+
91
+ void
92
+ unary_update_bitmap_0D(xnd_t stack[])
93
+ {
94
+ const int64_t li0 = stack[0].index;
95
+ const int64_t li1 = stack[1].index;
96
+ const uint8_t *b0 = get_bitmap(&stack[0]);
97
+ uint8_t *b1 = get_bitmap(&stack[1]);
98
+
99
+ assert(b0 != NULL);
100
+ assert(b1 != NULL);
101
+
102
+ bool x = is_valid(b0, li0);
103
+ set_bit(b1, li1, x);
104
+ }
105
+
106
+
107
+ /****************************************************************************/
108
+ /* Binary bitmap kernels */
109
+ /****************************************************************************/
110
+
111
+ void
112
+ binary_update_bitmap_1D_S(xnd_t stack[])
113
+ {
114
+ const int64_t N = xnd_fixed_shape(&stack[0]);
115
+ const int64_t li0 = stack[0].index;
116
+ const int64_t li1 = stack[1].index;
117
+ const int64_t li2 = stack[2].index;
118
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
119
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
120
+ const int64_t s2 = xnd_fixed_step(&stack[2]);
121
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
122
+ const uint8_t *b1 = get_bitmap1D(&stack[1]);
123
+ uint8_t *b2 = get_bitmap1D(&stack[2]);
124
+ int64_t i, k0, k1, k2;
125
+
126
+ if (b0 && b1) {
127
+ for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
128
+ bool x = is_valid(b0, k0) && is_valid(b1, k1);
129
+ set_bit(b2, k2, x);
130
+ }
131
+ }
132
+ else if (b0) {
133
+ for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
134
+ bool x = is_valid(b0, k0);
135
+ set_bit(b2, k2, x);
136
+ }
137
+ }
138
+ else if (b1) {
139
+ for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
140
+ bool x = is_valid(b1, k1);
141
+ set_bit(b2, k2, x);
142
+ }
143
+ }
144
+ }
145
+
146
+ void
147
+ binary_update_bitmap_0D(xnd_t stack[])
148
+ {
149
+ const int64_t li0 = stack[0].index;
150
+ const int64_t li1 = stack[1].index;
151
+ const int64_t li2 = stack[2].index;
152
+ const uint8_t *b0 = get_bitmap(&stack[0]);
153
+ const uint8_t *b1 = get_bitmap(&stack[1]);
154
+ uint8_t *b2 = get_bitmap(&stack[2]);
155
+
156
+ assert(b2 != NULL);
157
+
158
+ if (b0 && b1) {
159
+ bool x = is_valid(b0, li0) && is_valid(b1, li1);
160
+ set_bit(b2, li2, x);
161
+ }
162
+ else if (b0) {
163
+ bool x = is_valid(b0, li0);
164
+ set_bit(b2, li2, x);
165
+ }
166
+ else if (b1) {
167
+ bool x = is_valid(b1, li1);
168
+ set_bit(b2, li2, x);
169
+ }
170
+ }
171
+
172
+ void
173
+ binary_update_bitmap_1D_S_bool(xnd_t stack[])
174
+ {
175
+ const int64_t N = xnd_fixed_shape(&stack[0]);
176
+ const int64_t li0 = stack[0].index;
177
+ const int64_t li1 = stack[1].index;
178
+ const int64_t li2 = stack[2].index;
179
+ const int64_t s0 = xnd_fixed_step(&stack[0]);
180
+ const int64_t s1 = xnd_fixed_step(&stack[1]);
181
+ const int64_t s2 = xnd_fixed_step(&stack[2]);
182
+ const uint8_t *b0 = get_bitmap1D(&stack[0]);
183
+ const uint8_t *b1 = get_bitmap1D(&stack[1]);
184
+ bool *x2 = (bool *)apply_index(&stack[2]);
185
+ int64_t i, k0, k1, k2;
186
+
187
+ assert(!ndt_is_optional(stack[2].type));
188
+
189
+ if (b0 && b1) {
190
+ for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
191
+ bool x = is_valid(b0, k0);
192
+ bool y = is_valid(b1, k1);
193
+ bool z = x2[k2];
194
+ z = x && y ? z : !x && !y;
195
+ x2[k2] = z;
196
+ }
197
+ }
198
+ else if (b0) {
199
+ for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
200
+ bool x = is_valid(b0, k0);
201
+ bool z = x2[k2];
202
+ z = x ? z : x;
203
+ x2[k2] = z;
204
+ }
205
+ }
206
+ else if (b1) {
207
+ for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
208
+ bool x = is_valid(b1, k1);
209
+ bool z = x2[k2];
210
+ z = x ? z : x;
211
+ x2[k2] = z;
212
+ }
213
+ }
214
+ }
215
+
216
+ void
217
+ binary_update_bitmap_0D_bool(xnd_t stack[])
218
+ {
219
+ const int64_t li0 = stack[0].index;
220
+ const int64_t li1 = stack[1].index;
221
+ const int64_t li2 = stack[2].index;
222
+ const uint8_t *b0 = get_bitmap(&stack[0]);
223
+ const uint8_t *b1 = get_bitmap(&stack[1]);
224
+ bool *x2 = (bool *)stack[2].ptr;
225
+
226
+ assert(!ndt_is_optional(stack[2].type));
227
+
228
+ if (b0 && b1) {
229
+ bool x = is_valid(b0, li0);
230
+ bool y = is_valid(b1, li1);
231
+ bool z = x2[li2];
232
+ z = x && y ? z : !x && !y;
233
+ x2[li2] = z;
234
+ }
235
+ else if (b0) {
236
+ bool x = is_valid(b0, li0);
237
+ bool z = x2[li2];
238
+ z = x ? z : x;
239
+ x2[li2] = z;
240
+ }
241
+ else if (b1) {
242
+ bool x = is_valid(b1, li1);
243
+ bool z = x2[li2];
244
+ z = x ? z : x;
245
+ x2[li2] = z;
246
+ }
247
+ }
248
+
249
+
250
+ /****************************************************************************/
251
+ /* Optimized unary typecheck */
252
+ /****************************************************************************/
253
+
254
+ const gm_kernel_set_t *
255
+ cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
256
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
257
+ const int64_t li[], int nin, int nout, bool check_broadcast,
258
+ ndt_context_t *ctx)
259
+ {
260
+ const ndt_t *t;
261
+ const ndt_t *u;
262
+ int n;
263
+
264
+ assert(spec->flags == 0);
265
+ assert(spec->outer_dims == 0);
266
+ assert(spec->nin == 0);
267
+ assert(spec->nout == 0);
268
+ assert(spec->nargs == 0);
269
+
270
+ if (nin != 1) {
271
+ ndt_err_format(ctx, NDT_ValueError,
272
+ "invalid number of arguments for %s(x): expected 1, got %d",
273
+ f->name, nin);
274
+ return NULL;
275
+
276
+ }
277
+
278
+ t = types[0];
279
+
280
+ if (nout) {
281
+ if (nout != 1) {
282
+ ndt_err_format(ctx, NDT_ValueError,
283
+ "%s(x) expects at most one 'out' argument, got %d",
284
+ f->name, nout);
285
+ return NULL;
286
+ }
287
+ u = types[1];
288
+ }
289
+ else {
290
+ u = types[0];
291
+ }
292
+
293
+ assert(ndt_is_concrete(t));
294
+ assert(ndt_is_concrete(u));
295
+
296
+ n = kernel_location(t, u, ctx);
297
+ if (n < 0) {
298
+ return NULL;
299
+ }
300
+ if (ndt_is_optional(ndt_dtype(t))) {
301
+ n++;
302
+ }
303
+
304
+ if (t->tag == VarDim || t->tag == VarDimElem) {
305
+ const gm_kernel_set_t *set = &f->kernels[n+2];
306
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
307
+ check_broadcast, NULL, NULL, ctx) < 0) {
308
+ return NULL;
309
+ }
310
+ return set;
311
+ }
312
+
313
+ if (t->tag == Array) {
314
+ const gm_kernel_set_t *set = &f->kernels[n+4];
315
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
316
+ check_broadcast, NULL, NULL, ctx) < 0) {
317
+ return NULL;
318
+ }
319
+ return set;
320
+ }
321
+
322
+ const gm_kernel_set_t *set = &f->kernels[n];
323
+
324
+ if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
325
+ check_broadcast, ctx) < 0) {
326
+ return NULL;
327
+ }
328
+
329
+ return set;
330
+ }
331
+
332
+ const gm_kernel_set_t *
333
+ cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
334
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
335
+ const int64_t li[], int nin, int nout, bool check_broadcast,
336
+ ndt_context_t *ctx)
337
+ {
338
+ const ndt_t *t;
339
+ const ndt_t *u;
340
+ int n;
341
+ (void)li;
342
+
343
+ assert(spec->flags == 0);
344
+ assert(spec->outer_dims == 0);
345
+ assert(spec->nin == 0);
346
+ assert(spec->nout == 0);
347
+ assert(spec->nargs == 0);
348
+
349
+ if (nin != 1) {
350
+ ndt_err_format(ctx, NDT_ValueError,
351
+ "invalid number of arguments for %s(x): expected 1, got %d",
352
+ f->name, nin);
353
+ return NULL;
354
+ }
355
+
356
+ t = types[0];
357
+
358
+ if (nout) {
359
+ if (nout != 1) {
360
+ ndt_err_format(ctx, NDT_ValueError,
361
+ "%s(x) expects at most one 'out' argument, got %d",
362
+ f->name, nout);
363
+ return NULL;
364
+ }
365
+ u = types[1];
366
+ }
367
+ else {
368
+ u = types[0];
369
+ }
370
+
371
+ assert(ndt_is_concrete(t));
372
+ assert(ndt_is_concrete(u));
373
+
374
+ n = kernel_location(t, u, ctx);
375
+ if (n < 0) {
376
+ return NULL;
377
+ }
378
+ if (ndt_is_optional(ndt_dtype(t))) {
379
+ n++;
380
+ }
381
+
382
+ const gm_kernel_set_t *set = &f->kernels[n];
383
+
384
+ if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
385
+ check_broadcast, ctx) < 0) {
386
+ return NULL;
387
+ }
388
+
389
+ return set;
390
+ }
391
+
392
+
393
+ /****************************************************************************/
394
+ /* Optimized binary typecheck */
395
+ /****************************************************************************/
396
+
397
+ const gm_kernel_set_t *
398
+ cpu_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
399
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
400
+ const int64_t li[], int nin, int nout, bool check_broadcast,
401
+ ndt_context_t *ctx)
402
+ {
403
+ const ndt_t *t0;
404
+ const ndt_t *t1;
405
+ int n;
406
+
407
+ assert(spec->flags == 0);
408
+ assert(spec->outer_dims == 0);
409
+ assert(spec->nin == 0);
410
+ assert(spec->nout == 0);
411
+ assert(spec->nargs == 0);
412
+
413
+ if (nin != 2) {
414
+ ndt_err_format(ctx, NDT_ValueError,
415
+ "invalid number of arguments for %s(x, y): expected 2, got %d",
416
+ f->name, nin);
417
+ return NULL;
418
+ }
419
+
420
+ t0 = types[0];
421
+ t1 = types[1];
422
+ assert(ndt_is_concrete(t0));
423
+ assert(ndt_is_concrete(t1));
424
+
425
+ n = kernel_location(t0, t1, ctx);
426
+ if (n < 0) {
427
+ return NULL;
428
+ }
429
+ if (ndt_is_optional(ndt_dtype(t0))) {
430
+ n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
431
+ }
432
+ else if (ndt_is_optional(ndt_dtype(t1))) {
433
+ n = n+2;
434
+ }
435
+
436
+ if (t0->tag == VarDim || t0->tag == VarDimElem ||
437
+ t1->tag == VarDim || t1->tag == VarDimElem) {
438
+ const gm_kernel_set_t *set = &f->kernels[n+4];
439
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
440
+ check_broadcast, NULL, NULL, ctx) < 0) {
441
+ return NULL;
442
+ }
443
+ return set;
444
+ }
445
+
446
+ if (t0->tag == Array || t1->tag == Array) {
447
+ const gm_kernel_set_t *set = &f->kernels[n+8];
448
+ if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
449
+ check_broadcast, NULL, NULL, ctx) < 0) {
450
+ return NULL;
451
+ }
452
+ return set;
453
+ }
454
+
455
+ const gm_kernel_set_t *set = &f->kernels[n];
456
+
457
+ if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
458
+ check_broadcast, ctx) < 0) {
459
+ return NULL;
460
+ }
461
+
462
+ return set;
463
+ }
464
+
465
+ const gm_kernel_set_t *
466
+ cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
467
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
468
+ const int64_t li[], int nin, int nout, bool check_broadcast,
469
+ ndt_context_t *ctx)
470
+ {
471
+ const ndt_t *t0;
472
+ const ndt_t *t1;
473
+ int n;
474
+ (void)li;
475
+
476
+ assert(spec->flags == 0);
477
+ assert(spec->outer_dims == 0);
478
+ assert(spec->nin == 0);
479
+ assert(spec->nout == 0);
480
+ assert(spec->nargs == 0);
481
+
482
+ if (nin != 2) {
483
+ ndt_err_format(ctx, NDT_ValueError,
484
+ "invalid number of arguments for %s(x, y): expected 2, got %d",
485
+ f->name, nin);
486
+ return NULL;
487
+ }
488
+
489
+ t0 = types[0];
490
+ t1 = types[1];
491
+ assert(ndt_is_concrete(t0));
492
+ assert(ndt_is_concrete(t1));
493
+
494
+ n = kernel_location(t0, t1, ctx);
495
+ if (n < 0) {
496
+ return NULL;
497
+ }
498
+ if (ndt_is_optional(ndt_dtype(t0))) {
499
+ n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
500
+ }
501
+ else if (ndt_is_optional(ndt_dtype(t1))) {
502
+ n = n+2;
503
+ }
504
+
505
+ const gm_kernel_set_t *set = &f->kernels[n];
506
+
507
+ if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
508
+ check_broadcast, ctx) < 0) {
509
+ return NULL;
510
+ }
511
+
512
+ return set;
513
+ }