gumath 0.2.0dev5 → 0.2.0dev8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
Binary file
|
Binary file
|
@@ -34,6 +34,17 @@
|
|
34
34
|
#ifndef GUMATH_H
|
35
35
|
#define GUMATH_H
|
36
36
|
|
37
|
+
|
38
|
+
#ifdef __cplusplus
|
39
|
+
extern "C" {
|
40
|
+
#endif
|
41
|
+
|
42
|
+
#ifdef __cplusplus
|
43
|
+
#include <cstdint>
|
44
|
+
#else
|
45
|
+
#include <stdint.h>
|
46
|
+
#endif
|
47
|
+
|
37
48
|
#include "ndtypes.h"
|
38
49
|
#include "xnd.h"
|
39
50
|
|
@@ -65,7 +76,8 @@
|
|
65
76
|
#endif
|
66
77
|
|
67
78
|
|
68
|
-
#define GM_MAX_KERNELS
|
79
|
+
#define GM_MAX_KERNELS 8192
|
80
|
+
#define GM_THREAD_CUTOFF 1000000
|
69
81
|
|
70
82
|
typedef float float32_t;
|
71
83
|
typedef double float64_t;
|
@@ -74,15 +86,25 @@ typedef double float64_t;
|
|
74
86
|
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
|
75
87
|
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
|
76
88
|
|
77
|
-
/*
|
89
|
+
/*
|
90
|
+
* Collection of specialized kernels for a single function signature.
|
91
|
+
*
|
92
|
+
* NOTE: The specialized kernel lookup scheme is transitional and may
|
93
|
+
* be replaced by something else.
|
94
|
+
*
|
95
|
+
* This should be considered as a first version of a kernel request
|
96
|
+
* protocol.
|
97
|
+
*/
|
78
98
|
typedef struct {
|
79
|
-
ndt_t *sig;
|
99
|
+
const ndt_t *sig;
|
80
100
|
const ndt_constraint_t *constraint;
|
81
101
|
|
82
102
|
/* Xnd signatures */
|
83
|
-
gm_xnd_kernel_t
|
84
|
-
gm_xnd_kernel_t
|
85
|
-
gm_xnd_kernel_t
|
103
|
+
gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
|
104
|
+
gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
|
105
|
+
gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
|
106
|
+
gm_xnd_kernel_t C; /* C in inner dimensions */
|
107
|
+
gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
|
86
108
|
gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
|
87
109
|
|
88
110
|
/* NumPy signature */
|
@@ -99,11 +121,17 @@ typedef struct {
|
|
99
121
|
const char *name;
|
100
122
|
const char *sig;
|
101
123
|
const ndt_constraint_t *constraint;
|
124
|
+
uint32_t cap;
|
102
125
|
|
103
|
-
|
126
|
+
/* Xnd signatures */
|
127
|
+
gm_xnd_kernel_t OptC;
|
128
|
+
gm_xnd_kernel_t OptZ;
|
129
|
+
gm_xnd_kernel_t OptS;
|
104
130
|
gm_xnd_kernel_t C;
|
105
131
|
gm_xnd_kernel_t Fortran;
|
106
132
|
gm_xnd_kernel_t Xnd;
|
133
|
+
|
134
|
+
/* NumPy signature */
|
107
135
|
gm_strided_kernel_t Strided;
|
108
136
|
} gm_kernel_init_t;
|
109
137
|
|
@@ -115,7 +143,10 @@ typedef struct {
|
|
115
143
|
|
116
144
|
/* Multimethod with associated kernels */
|
117
145
|
typedef struct gm_func gm_func_t;
|
118
|
-
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
146
|
+
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
147
|
+
const ndt_t *in[], const int64_t li[],
|
148
|
+
int nin, int nout, bool check_broadcast,
|
149
|
+
ndt_context_t *ctx);
|
119
150
|
struct gm_func {
|
120
151
|
char *name;
|
121
152
|
gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
|
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
|
|
139
170
|
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
|
140
171
|
|
141
172
|
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
142
|
-
const ndt_t *
|
143
|
-
ndt_context_t *ctx);
|
173
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
174
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
|
144
175
|
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
|
145
|
-
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
176
|
+
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
|
146
177
|
|
147
178
|
|
148
179
|
/******************************************************************************/
|
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
|
|
171
202
|
/* Xnd loops */
|
172
203
|
/******************************************************************************/
|
173
204
|
|
205
|
+
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
|
174
206
|
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
175
207
|
const int outer_dims, ndt_context_t *ctx);
|
176
208
|
|
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
|
|
191
223
|
/******************************************************************************/
|
192
224
|
|
193
225
|
GM_API void gm_init(void);
|
194
|
-
GM_API int
|
195
|
-
GM_API int
|
226
|
+
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
227
|
+
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
228
|
+
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
229
|
+
|
230
|
+
GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
231
|
+
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
232
|
+
|
196
233
|
GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
197
|
-
GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
198
234
|
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
199
235
|
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
200
236
|
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
202
238
|
GM_API void gm_finalize(void);
|
203
239
|
|
204
240
|
|
241
|
+
#ifdef __cplusplus
|
242
|
+
} /* END extern "C" */
|
243
|
+
#endif
|
244
|
+
|
245
|
+
|
205
246
|
#endif /* GUMATH_H */
|
@@ -0,0 +1,513 @@
|
|
1
|
+
/*
|
2
|
+
* BSD 3-Clause License
|
3
|
+
*
|
4
|
+
* Copyright (c) 2017-2018, plures
|
5
|
+
* All rights reserved.
|
6
|
+
*
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
9
|
+
*
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
* this list of conditions and the following disclaimer.
|
12
|
+
*
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
15
|
+
* and/or other materials provided with the distribution.
|
16
|
+
*
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
19
|
+
* this software without specific prior written permission.
|
20
|
+
*
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
*/
|
32
|
+
|
33
|
+
|
34
|
+
#include <stdlib.h>
|
35
|
+
#include <stdint.h>
|
36
|
+
#include <string.h>
|
37
|
+
#include <math.h>
|
38
|
+
#include <complex.h>
|
39
|
+
#include <inttypes.h>
|
40
|
+
#include "ndtypes.h"
|
41
|
+
#include "xnd.h"
|
42
|
+
#include "gumath.h"
|
43
|
+
#include "common.h"
|
44
|
+
|
45
|
+
|
46
|
+
/****************************************************************************/
|
47
|
+
/* Unary bitmap kernels */
|
48
|
+
/****************************************************************************/
|
49
|
+
|
50
|
+
void
|
51
|
+
unary_update_bitmap_1D_S(xnd_t stack[])
|
52
|
+
{
|
53
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
54
|
+
const int64_t li0 = stack[0].index;
|
55
|
+
const int64_t li1 = stack[1].index;
|
56
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
57
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
58
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
59
|
+
uint8_t *b1 = get_bitmap1D(&stack[1]);
|
60
|
+
int64_t i, k0, k1;
|
61
|
+
|
62
|
+
assert(b0 != NULL);
|
63
|
+
assert(b1 != NULL);
|
64
|
+
|
65
|
+
for (i=0, k0=li0, k1=li1; i<N; i++, k0+=s0, k1+=s1) {
|
66
|
+
bool x = is_valid(b0, k0);
|
67
|
+
set_bit(b1, k1, x);
|
68
|
+
}
|
69
|
+
}
|
70
|
+
|
71
|
+
void
|
72
|
+
unary_reduce_bitmap_1D_S(xnd_t stack[])
|
73
|
+
{
|
74
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
75
|
+
const int64_t li0 = stack[0].index;
|
76
|
+
const int64_t li1 = stack[1].index;
|
77
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
78
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
79
|
+
uint8_t *b1 = get_bitmap(&stack[1]);
|
80
|
+
int64_t i, k0;
|
81
|
+
|
82
|
+
assert(b0 != NULL);
|
83
|
+
assert(b1 != NULL);
|
84
|
+
|
85
|
+
for (i=0, k0=li0; i<N; i++, k0+=s0) {
|
86
|
+
bool x = is_valid(b0, k0) && is_valid(b1, li1);
|
87
|
+
set_bit(b1, li1, x);
|
88
|
+
}
|
89
|
+
}
|
90
|
+
|
91
|
+
void
|
92
|
+
unary_update_bitmap_0D(xnd_t stack[])
|
93
|
+
{
|
94
|
+
const int64_t li0 = stack[0].index;
|
95
|
+
const int64_t li1 = stack[1].index;
|
96
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
97
|
+
uint8_t *b1 = get_bitmap(&stack[1]);
|
98
|
+
|
99
|
+
assert(b0 != NULL);
|
100
|
+
assert(b1 != NULL);
|
101
|
+
|
102
|
+
bool x = is_valid(b0, li0);
|
103
|
+
set_bit(b1, li1, x);
|
104
|
+
}
|
105
|
+
|
106
|
+
|
107
|
+
/****************************************************************************/
|
108
|
+
/* Binary bitmap kernels */
|
109
|
+
/****************************************************************************/
|
110
|
+
|
111
|
+
void
|
112
|
+
binary_update_bitmap_1D_S(xnd_t stack[])
|
113
|
+
{
|
114
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
115
|
+
const int64_t li0 = stack[0].index;
|
116
|
+
const int64_t li1 = stack[1].index;
|
117
|
+
const int64_t li2 = stack[2].index;
|
118
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
119
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
120
|
+
const int64_t s2 = xnd_fixed_step(&stack[2]);
|
121
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
122
|
+
const uint8_t *b1 = get_bitmap1D(&stack[1]);
|
123
|
+
uint8_t *b2 = get_bitmap1D(&stack[2]);
|
124
|
+
int64_t i, k0, k1, k2;
|
125
|
+
|
126
|
+
if (b0 && b1) {
|
127
|
+
for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
|
128
|
+
bool x = is_valid(b0, k0) && is_valid(b1, k1);
|
129
|
+
set_bit(b2, k2, x);
|
130
|
+
}
|
131
|
+
}
|
132
|
+
else if (b0) {
|
133
|
+
for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
|
134
|
+
bool x = is_valid(b0, k0);
|
135
|
+
set_bit(b2, k2, x);
|
136
|
+
}
|
137
|
+
}
|
138
|
+
else if (b1) {
|
139
|
+
for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
|
140
|
+
bool x = is_valid(b1, k1);
|
141
|
+
set_bit(b2, k2, x);
|
142
|
+
}
|
143
|
+
}
|
144
|
+
}
|
145
|
+
|
146
|
+
void
|
147
|
+
binary_update_bitmap_0D(xnd_t stack[])
|
148
|
+
{
|
149
|
+
const int64_t li0 = stack[0].index;
|
150
|
+
const int64_t li1 = stack[1].index;
|
151
|
+
const int64_t li2 = stack[2].index;
|
152
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
153
|
+
const uint8_t *b1 = get_bitmap(&stack[1]);
|
154
|
+
uint8_t *b2 = get_bitmap(&stack[2]);
|
155
|
+
|
156
|
+
assert(b2 != NULL);
|
157
|
+
|
158
|
+
if (b0 && b1) {
|
159
|
+
bool x = is_valid(b0, li0) && is_valid(b1, li1);
|
160
|
+
set_bit(b2, li2, x);
|
161
|
+
}
|
162
|
+
else if (b0) {
|
163
|
+
bool x = is_valid(b0, li0);
|
164
|
+
set_bit(b2, li2, x);
|
165
|
+
}
|
166
|
+
else if (b1) {
|
167
|
+
bool x = is_valid(b1, li1);
|
168
|
+
set_bit(b2, li2, x);
|
169
|
+
}
|
170
|
+
}
|
171
|
+
|
172
|
+
void
|
173
|
+
binary_update_bitmap_1D_S_bool(xnd_t stack[])
|
174
|
+
{
|
175
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
176
|
+
const int64_t li0 = stack[0].index;
|
177
|
+
const int64_t li1 = stack[1].index;
|
178
|
+
const int64_t li2 = stack[2].index;
|
179
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
180
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
181
|
+
const int64_t s2 = xnd_fixed_step(&stack[2]);
|
182
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
183
|
+
const uint8_t *b1 = get_bitmap1D(&stack[1]);
|
184
|
+
bool *x2 = (bool *)apply_index(&stack[2]);
|
185
|
+
int64_t i, k0, k1, k2;
|
186
|
+
|
187
|
+
assert(!ndt_is_optional(stack[2].type));
|
188
|
+
|
189
|
+
if (b0 && b1) {
|
190
|
+
for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
|
191
|
+
bool x = is_valid(b0, k0);
|
192
|
+
bool y = is_valid(b1, k1);
|
193
|
+
bool z = x2[k2];
|
194
|
+
z = x && y ? z : !x && !y;
|
195
|
+
x2[k2] = z;
|
196
|
+
}
|
197
|
+
}
|
198
|
+
else if (b0) {
|
199
|
+
for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
|
200
|
+
bool x = is_valid(b0, k0);
|
201
|
+
bool z = x2[k2];
|
202
|
+
z = x ? z : x;
|
203
|
+
x2[k2] = z;
|
204
|
+
}
|
205
|
+
}
|
206
|
+
else if (b1) {
|
207
|
+
for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
|
208
|
+
bool x = is_valid(b1, k1);
|
209
|
+
bool z = x2[k2];
|
210
|
+
z = x ? z : x;
|
211
|
+
x2[k2] = z;
|
212
|
+
}
|
213
|
+
}
|
214
|
+
}
|
215
|
+
|
216
|
+
void
|
217
|
+
binary_update_bitmap_0D_bool(xnd_t stack[])
|
218
|
+
{
|
219
|
+
const int64_t li0 = stack[0].index;
|
220
|
+
const int64_t li1 = stack[1].index;
|
221
|
+
const int64_t li2 = stack[2].index;
|
222
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
223
|
+
const uint8_t *b1 = get_bitmap(&stack[1]);
|
224
|
+
bool *x2 = (bool *)stack[2].ptr;
|
225
|
+
|
226
|
+
assert(!ndt_is_optional(stack[2].type));
|
227
|
+
|
228
|
+
if (b0 && b1) {
|
229
|
+
bool x = is_valid(b0, li0);
|
230
|
+
bool y = is_valid(b1, li1);
|
231
|
+
bool z = x2[li2];
|
232
|
+
z = x && y ? z : !x && !y;
|
233
|
+
x2[li2] = z;
|
234
|
+
}
|
235
|
+
else if (b0) {
|
236
|
+
bool x = is_valid(b0, li0);
|
237
|
+
bool z = x2[li2];
|
238
|
+
z = x ? z : x;
|
239
|
+
x2[li2] = z;
|
240
|
+
}
|
241
|
+
else if (b1) {
|
242
|
+
bool x = is_valid(b1, li1);
|
243
|
+
bool z = x2[li2];
|
244
|
+
z = x ? z : x;
|
245
|
+
x2[li2] = z;
|
246
|
+
}
|
247
|
+
}
|
248
|
+
|
249
|
+
|
250
|
+
/****************************************************************************/
|
251
|
+
/* Optimized unary typecheck */
|
252
|
+
/****************************************************************************/
|
253
|
+
|
254
|
+
const gm_kernel_set_t *
|
255
|
+
cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
256
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
257
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
258
|
+
ndt_context_t *ctx)
|
259
|
+
{
|
260
|
+
const ndt_t *t;
|
261
|
+
const ndt_t *u;
|
262
|
+
int n;
|
263
|
+
|
264
|
+
assert(spec->flags == 0);
|
265
|
+
assert(spec->outer_dims == 0);
|
266
|
+
assert(spec->nin == 0);
|
267
|
+
assert(spec->nout == 0);
|
268
|
+
assert(spec->nargs == 0);
|
269
|
+
|
270
|
+
if (nin != 1) {
|
271
|
+
ndt_err_format(ctx, NDT_ValueError,
|
272
|
+
"invalid number of arguments for %s(x): expected 1, got %d",
|
273
|
+
f->name, nin);
|
274
|
+
return NULL;
|
275
|
+
|
276
|
+
}
|
277
|
+
|
278
|
+
t = types[0];
|
279
|
+
|
280
|
+
if (nout) {
|
281
|
+
if (nout != 1) {
|
282
|
+
ndt_err_format(ctx, NDT_ValueError,
|
283
|
+
"%s(x) expects at most one 'out' argument, got %d",
|
284
|
+
f->name, nout);
|
285
|
+
return NULL;
|
286
|
+
}
|
287
|
+
u = types[1];
|
288
|
+
}
|
289
|
+
else {
|
290
|
+
u = types[0];
|
291
|
+
}
|
292
|
+
|
293
|
+
assert(ndt_is_concrete(t));
|
294
|
+
assert(ndt_is_concrete(u));
|
295
|
+
|
296
|
+
n = kernel_location(t, u, ctx);
|
297
|
+
if (n < 0) {
|
298
|
+
return NULL;
|
299
|
+
}
|
300
|
+
if (ndt_is_optional(ndt_dtype(t))) {
|
301
|
+
n++;
|
302
|
+
}
|
303
|
+
|
304
|
+
if (t->tag == VarDim || t->tag == VarDimElem) {
|
305
|
+
const gm_kernel_set_t *set = &f->kernels[n+2];
|
306
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
307
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
308
|
+
return NULL;
|
309
|
+
}
|
310
|
+
return set;
|
311
|
+
}
|
312
|
+
|
313
|
+
if (t->tag == Array) {
|
314
|
+
const gm_kernel_set_t *set = &f->kernels[n+4];
|
315
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
316
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
317
|
+
return NULL;
|
318
|
+
}
|
319
|
+
return set;
|
320
|
+
}
|
321
|
+
|
322
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
323
|
+
|
324
|
+
if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
325
|
+
check_broadcast, ctx) < 0) {
|
326
|
+
return NULL;
|
327
|
+
}
|
328
|
+
|
329
|
+
return set;
|
330
|
+
}
|
331
|
+
|
332
|
+
const gm_kernel_set_t *
|
333
|
+
cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
334
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
335
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
336
|
+
ndt_context_t *ctx)
|
337
|
+
{
|
338
|
+
const ndt_t *t;
|
339
|
+
const ndt_t *u;
|
340
|
+
int n;
|
341
|
+
(void)li;
|
342
|
+
|
343
|
+
assert(spec->flags == 0);
|
344
|
+
assert(spec->outer_dims == 0);
|
345
|
+
assert(spec->nin == 0);
|
346
|
+
assert(spec->nout == 0);
|
347
|
+
assert(spec->nargs == 0);
|
348
|
+
|
349
|
+
if (nin != 1) {
|
350
|
+
ndt_err_format(ctx, NDT_ValueError,
|
351
|
+
"invalid number of arguments for %s(x): expected 1, got %d",
|
352
|
+
f->name, nin);
|
353
|
+
return NULL;
|
354
|
+
}
|
355
|
+
|
356
|
+
t = types[0];
|
357
|
+
|
358
|
+
if (nout) {
|
359
|
+
if (nout != 1) {
|
360
|
+
ndt_err_format(ctx, NDT_ValueError,
|
361
|
+
"%s(x) expects at most one 'out' argument, got %d",
|
362
|
+
f->name, nout);
|
363
|
+
return NULL;
|
364
|
+
}
|
365
|
+
u = types[1];
|
366
|
+
}
|
367
|
+
else {
|
368
|
+
u = types[0];
|
369
|
+
}
|
370
|
+
|
371
|
+
assert(ndt_is_concrete(t));
|
372
|
+
assert(ndt_is_concrete(u));
|
373
|
+
|
374
|
+
n = kernel_location(t, u, ctx);
|
375
|
+
if (n < 0) {
|
376
|
+
return NULL;
|
377
|
+
}
|
378
|
+
if (ndt_is_optional(ndt_dtype(t))) {
|
379
|
+
n++;
|
380
|
+
}
|
381
|
+
|
382
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
383
|
+
|
384
|
+
if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
385
|
+
check_broadcast, ctx) < 0) {
|
386
|
+
return NULL;
|
387
|
+
}
|
388
|
+
|
389
|
+
return set;
|
390
|
+
}
|
391
|
+
|
392
|
+
|
393
|
+
/****************************************************************************/
|
394
|
+
/* Optimized binary typecheck */
|
395
|
+
/****************************************************************************/
|
396
|
+
|
397
|
+
const gm_kernel_set_t *
|
398
|
+
cpu_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
399
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
400
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
401
|
+
ndt_context_t *ctx)
|
402
|
+
{
|
403
|
+
const ndt_t *t0;
|
404
|
+
const ndt_t *t1;
|
405
|
+
int n;
|
406
|
+
|
407
|
+
assert(spec->flags == 0);
|
408
|
+
assert(spec->outer_dims == 0);
|
409
|
+
assert(spec->nin == 0);
|
410
|
+
assert(spec->nout == 0);
|
411
|
+
assert(spec->nargs == 0);
|
412
|
+
|
413
|
+
if (nin != 2) {
|
414
|
+
ndt_err_format(ctx, NDT_ValueError,
|
415
|
+
"invalid number of arguments for %s(x, y): expected 2, got %d",
|
416
|
+
f->name, nin);
|
417
|
+
return NULL;
|
418
|
+
}
|
419
|
+
|
420
|
+
t0 = types[0];
|
421
|
+
t1 = types[1];
|
422
|
+
assert(ndt_is_concrete(t0));
|
423
|
+
assert(ndt_is_concrete(t1));
|
424
|
+
|
425
|
+
n = kernel_location(t0, t1, ctx);
|
426
|
+
if (n < 0) {
|
427
|
+
return NULL;
|
428
|
+
}
|
429
|
+
if (ndt_is_optional(ndt_dtype(t0))) {
|
430
|
+
n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
|
431
|
+
}
|
432
|
+
else if (ndt_is_optional(ndt_dtype(t1))) {
|
433
|
+
n = n+2;
|
434
|
+
}
|
435
|
+
|
436
|
+
if (t0->tag == VarDim || t0->tag == VarDimElem ||
|
437
|
+
t1->tag == VarDim || t1->tag == VarDimElem) {
|
438
|
+
const gm_kernel_set_t *set = &f->kernels[n+4];
|
439
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
440
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
441
|
+
return NULL;
|
442
|
+
}
|
443
|
+
return set;
|
444
|
+
}
|
445
|
+
|
446
|
+
if (t0->tag == Array || t1->tag == Array) {
|
447
|
+
const gm_kernel_set_t *set = &f->kernels[n+8];
|
448
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
449
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
450
|
+
return NULL;
|
451
|
+
}
|
452
|
+
return set;
|
453
|
+
}
|
454
|
+
|
455
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
456
|
+
|
457
|
+
if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
458
|
+
check_broadcast, ctx) < 0) {
|
459
|
+
return NULL;
|
460
|
+
}
|
461
|
+
|
462
|
+
return set;
|
463
|
+
}
|
464
|
+
|
465
|
+
const gm_kernel_set_t *
|
466
|
+
cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
467
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
468
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
469
|
+
ndt_context_t *ctx)
|
470
|
+
{
|
471
|
+
const ndt_t *t0;
|
472
|
+
const ndt_t *t1;
|
473
|
+
int n;
|
474
|
+
(void)li;
|
475
|
+
|
476
|
+
assert(spec->flags == 0);
|
477
|
+
assert(spec->outer_dims == 0);
|
478
|
+
assert(spec->nin == 0);
|
479
|
+
assert(spec->nout == 0);
|
480
|
+
assert(spec->nargs == 0);
|
481
|
+
|
482
|
+
if (nin != 2) {
|
483
|
+
ndt_err_format(ctx, NDT_ValueError,
|
484
|
+
"invalid number of arguments for %s(x, y): expected 2, got %d",
|
485
|
+
f->name, nin);
|
486
|
+
return NULL;
|
487
|
+
}
|
488
|
+
|
489
|
+
t0 = types[0];
|
490
|
+
t1 = types[1];
|
491
|
+
assert(ndt_is_concrete(t0));
|
492
|
+
assert(ndt_is_concrete(t1));
|
493
|
+
|
494
|
+
n = kernel_location(t0, t1, ctx);
|
495
|
+
if (n < 0) {
|
496
|
+
return NULL;
|
497
|
+
}
|
498
|
+
if (ndt_is_optional(ndt_dtype(t0))) {
|
499
|
+
n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
|
500
|
+
}
|
501
|
+
else if (ndt_is_optional(ndt_dtype(t1))) {
|
502
|
+
n = n+2;
|
503
|
+
}
|
504
|
+
|
505
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
506
|
+
|
507
|
+
if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
508
|
+
check_broadcast, ctx) < 0) {
|
509
|
+
return NULL;
|
510
|
+
}
|
511
|
+
|
512
|
+
return set;
|
513
|
+
}
|