gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
Binary file
|
|
Binary file
|
|
@@ -34,6 +34,17 @@
|
|
|
34
34
|
#ifndef GUMATH_H
|
|
35
35
|
#define GUMATH_H
|
|
36
36
|
|
|
37
|
+
|
|
38
|
+
#ifdef __cplusplus
|
|
39
|
+
extern "C" {
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#ifdef __cplusplus
|
|
43
|
+
#include <cstdint>
|
|
44
|
+
#else
|
|
45
|
+
#include <stdint.h>
|
|
46
|
+
#endif
|
|
47
|
+
|
|
37
48
|
#include "ndtypes.h"
|
|
38
49
|
#include "xnd.h"
|
|
39
50
|
|
|
@@ -65,7 +76,8 @@
|
|
|
65
76
|
#endif
|
|
66
77
|
|
|
67
78
|
|
|
68
|
-
#define GM_MAX_KERNELS
|
|
79
|
+
#define GM_MAX_KERNELS 8192
|
|
80
|
+
#define GM_THREAD_CUTOFF 1000000
|
|
69
81
|
|
|
70
82
|
typedef float float32_t;
|
|
71
83
|
typedef double float64_t;
|
|
@@ -74,15 +86,25 @@ typedef double float64_t;
|
|
|
74
86
|
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
|
|
75
87
|
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
|
|
76
88
|
|
|
77
|
-
/*
|
|
89
|
+
/*
|
|
90
|
+
* Collection of specialized kernels for a single function signature.
|
|
91
|
+
*
|
|
92
|
+
* NOTE: The specialized kernel lookup scheme is transitional and may
|
|
93
|
+
* be replaced by something else.
|
|
94
|
+
*
|
|
95
|
+
* This should be considered as a first version of a kernel request
|
|
96
|
+
* protocol.
|
|
97
|
+
*/
|
|
78
98
|
typedef struct {
|
|
79
|
-
ndt_t *sig;
|
|
99
|
+
const ndt_t *sig;
|
|
80
100
|
const ndt_constraint_t *constraint;
|
|
81
101
|
|
|
82
102
|
/* Xnd signatures */
|
|
83
|
-
gm_xnd_kernel_t
|
|
84
|
-
gm_xnd_kernel_t
|
|
85
|
-
gm_xnd_kernel_t
|
|
103
|
+
gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
|
|
104
|
+
gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
|
|
105
|
+
gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
|
|
106
|
+
gm_xnd_kernel_t C; /* C in inner dimensions */
|
|
107
|
+
gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
|
|
86
108
|
gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
|
|
87
109
|
|
|
88
110
|
/* NumPy signature */
|
|
@@ -99,11 +121,17 @@ typedef struct {
|
|
|
99
121
|
const char *name;
|
|
100
122
|
const char *sig;
|
|
101
123
|
const ndt_constraint_t *constraint;
|
|
124
|
+
uint32_t cap;
|
|
102
125
|
|
|
103
|
-
|
|
126
|
+
/* Xnd signatures */
|
|
127
|
+
gm_xnd_kernel_t OptC;
|
|
128
|
+
gm_xnd_kernel_t OptZ;
|
|
129
|
+
gm_xnd_kernel_t OptS;
|
|
104
130
|
gm_xnd_kernel_t C;
|
|
105
131
|
gm_xnd_kernel_t Fortran;
|
|
106
132
|
gm_xnd_kernel_t Xnd;
|
|
133
|
+
|
|
134
|
+
/* NumPy signature */
|
|
107
135
|
gm_strided_kernel_t Strided;
|
|
108
136
|
} gm_kernel_init_t;
|
|
109
137
|
|
|
@@ -115,7 +143,10 @@ typedef struct {
|
|
|
115
143
|
|
|
116
144
|
/* Multimethod with associated kernels */
|
|
117
145
|
typedef struct gm_func gm_func_t;
|
|
118
|
-
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
|
146
|
+
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
|
147
|
+
const ndt_t *in[], const int64_t li[],
|
|
148
|
+
int nin, int nout, bool check_broadcast,
|
|
149
|
+
ndt_context_t *ctx);
|
|
119
150
|
struct gm_func {
|
|
120
151
|
char *name;
|
|
121
152
|
gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
|
|
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
|
|
|
139
170
|
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
|
|
140
171
|
|
|
141
172
|
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
142
|
-
const ndt_t *
|
|
143
|
-
ndt_context_t *ctx);
|
|
173
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
|
174
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
|
|
144
175
|
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
|
|
145
|
-
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
176
|
+
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
|
|
146
177
|
|
|
147
178
|
|
|
148
179
|
/******************************************************************************/
|
|
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
|
|
|
171
202
|
/* Xnd loops */
|
|
172
203
|
/******************************************************************************/
|
|
173
204
|
|
|
205
|
+
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
|
|
174
206
|
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
175
207
|
const int outer_dims, ndt_context_t *ctx);
|
|
176
208
|
|
|
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
|
|
|
191
223
|
/******************************************************************************/
|
|
192
224
|
|
|
193
225
|
GM_API void gm_init(void);
|
|
194
|
-
GM_API int
|
|
195
|
-
GM_API int
|
|
226
|
+
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
227
|
+
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
228
|
+
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
229
|
+
|
|
230
|
+
GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
231
|
+
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
232
|
+
|
|
196
233
|
GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
197
|
-
GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
198
234
|
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
199
235
|
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
200
236
|
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
|
202
238
|
GM_API void gm_finalize(void);
|
|
203
239
|
|
|
204
240
|
|
|
241
|
+
#ifdef __cplusplus
|
|
242
|
+
} /* END extern "C" */
|
|
243
|
+
#endif
|
|
244
|
+
|
|
245
|
+
|
|
205
246
|
#endif /* GUMATH_H */
|
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* BSD 3-Clause License
|
|
3
|
+
*
|
|
4
|
+
* Copyright (c) 2017-2018, plures
|
|
5
|
+
* All rights reserved.
|
|
6
|
+
*
|
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
|
9
|
+
*
|
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
|
11
|
+
* this list of conditions and the following disclaimer.
|
|
12
|
+
*
|
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
|
15
|
+
* and/or other materials provided with the distribution.
|
|
16
|
+
*
|
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
|
19
|
+
* this software without specific prior written permission.
|
|
20
|
+
*
|
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
31
|
+
*/
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
#include <stdlib.h>
|
|
35
|
+
#include <stdint.h>
|
|
36
|
+
#include <string.h>
|
|
37
|
+
#include <math.h>
|
|
38
|
+
#include <complex.h>
|
|
39
|
+
#include <inttypes.h>
|
|
40
|
+
#include "ndtypes.h"
|
|
41
|
+
#include "xnd.h"
|
|
42
|
+
#include "gumath.h"
|
|
43
|
+
#include "common.h"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
/****************************************************************************/
|
|
47
|
+
/* Unary bitmap kernels */
|
|
48
|
+
/****************************************************************************/
|
|
49
|
+
|
|
50
|
+
void
|
|
51
|
+
unary_update_bitmap_1D_S(xnd_t stack[])
|
|
52
|
+
{
|
|
53
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
|
54
|
+
const int64_t li0 = stack[0].index;
|
|
55
|
+
const int64_t li1 = stack[1].index;
|
|
56
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
|
57
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
|
58
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
|
59
|
+
uint8_t *b1 = get_bitmap1D(&stack[1]);
|
|
60
|
+
int64_t i, k0, k1;
|
|
61
|
+
|
|
62
|
+
assert(b0 != NULL);
|
|
63
|
+
assert(b1 != NULL);
|
|
64
|
+
|
|
65
|
+
for (i=0, k0=li0, k1=li1; i<N; i++, k0+=s0, k1+=s1) {
|
|
66
|
+
bool x = is_valid(b0, k0);
|
|
67
|
+
set_bit(b1, k1, x);
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
void
|
|
72
|
+
unary_reduce_bitmap_1D_S(xnd_t stack[])
|
|
73
|
+
{
|
|
74
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
|
75
|
+
const int64_t li0 = stack[0].index;
|
|
76
|
+
const int64_t li1 = stack[1].index;
|
|
77
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
|
78
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
|
79
|
+
uint8_t *b1 = get_bitmap(&stack[1]);
|
|
80
|
+
int64_t i, k0;
|
|
81
|
+
|
|
82
|
+
assert(b0 != NULL);
|
|
83
|
+
assert(b1 != NULL);
|
|
84
|
+
|
|
85
|
+
for (i=0, k0=li0; i<N; i++, k0+=s0) {
|
|
86
|
+
bool x = is_valid(b0, k0) && is_valid(b1, li1);
|
|
87
|
+
set_bit(b1, li1, x);
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
void
|
|
92
|
+
unary_update_bitmap_0D(xnd_t stack[])
|
|
93
|
+
{
|
|
94
|
+
const int64_t li0 = stack[0].index;
|
|
95
|
+
const int64_t li1 = stack[1].index;
|
|
96
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
|
97
|
+
uint8_t *b1 = get_bitmap(&stack[1]);
|
|
98
|
+
|
|
99
|
+
assert(b0 != NULL);
|
|
100
|
+
assert(b1 != NULL);
|
|
101
|
+
|
|
102
|
+
bool x = is_valid(b0, li0);
|
|
103
|
+
set_bit(b1, li1, x);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
/****************************************************************************/
|
|
108
|
+
/* Binary bitmap kernels */
|
|
109
|
+
/****************************************************************************/
|
|
110
|
+
|
|
111
|
+
void
|
|
112
|
+
binary_update_bitmap_1D_S(xnd_t stack[])
|
|
113
|
+
{
|
|
114
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
|
115
|
+
const int64_t li0 = stack[0].index;
|
|
116
|
+
const int64_t li1 = stack[1].index;
|
|
117
|
+
const int64_t li2 = stack[2].index;
|
|
118
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
|
119
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
|
120
|
+
const int64_t s2 = xnd_fixed_step(&stack[2]);
|
|
121
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
|
122
|
+
const uint8_t *b1 = get_bitmap1D(&stack[1]);
|
|
123
|
+
uint8_t *b2 = get_bitmap1D(&stack[2]);
|
|
124
|
+
int64_t i, k0, k1, k2;
|
|
125
|
+
|
|
126
|
+
if (b0 && b1) {
|
|
127
|
+
for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
|
|
128
|
+
bool x = is_valid(b0, k0) && is_valid(b1, k1);
|
|
129
|
+
set_bit(b2, k2, x);
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
else if (b0) {
|
|
133
|
+
for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
|
|
134
|
+
bool x = is_valid(b0, k0);
|
|
135
|
+
set_bit(b2, k2, x);
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
else if (b1) {
|
|
139
|
+
for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
|
|
140
|
+
bool x = is_valid(b1, k1);
|
|
141
|
+
set_bit(b2, k2, x);
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
void
|
|
147
|
+
binary_update_bitmap_0D(xnd_t stack[])
|
|
148
|
+
{
|
|
149
|
+
const int64_t li0 = stack[0].index;
|
|
150
|
+
const int64_t li1 = stack[1].index;
|
|
151
|
+
const int64_t li2 = stack[2].index;
|
|
152
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
|
153
|
+
const uint8_t *b1 = get_bitmap(&stack[1]);
|
|
154
|
+
uint8_t *b2 = get_bitmap(&stack[2]);
|
|
155
|
+
|
|
156
|
+
assert(b2 != NULL);
|
|
157
|
+
|
|
158
|
+
if (b0 && b1) {
|
|
159
|
+
bool x = is_valid(b0, li0) && is_valid(b1, li1);
|
|
160
|
+
set_bit(b2, li2, x);
|
|
161
|
+
}
|
|
162
|
+
else if (b0) {
|
|
163
|
+
bool x = is_valid(b0, li0);
|
|
164
|
+
set_bit(b2, li2, x);
|
|
165
|
+
}
|
|
166
|
+
else if (b1) {
|
|
167
|
+
bool x = is_valid(b1, li1);
|
|
168
|
+
set_bit(b2, li2, x);
|
|
169
|
+
}
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
void
|
|
173
|
+
binary_update_bitmap_1D_S_bool(xnd_t stack[])
|
|
174
|
+
{
|
|
175
|
+
const int64_t N = xnd_fixed_shape(&stack[0]);
|
|
176
|
+
const int64_t li0 = stack[0].index;
|
|
177
|
+
const int64_t li1 = stack[1].index;
|
|
178
|
+
const int64_t li2 = stack[2].index;
|
|
179
|
+
const int64_t s0 = xnd_fixed_step(&stack[0]);
|
|
180
|
+
const int64_t s1 = xnd_fixed_step(&stack[1]);
|
|
181
|
+
const int64_t s2 = xnd_fixed_step(&stack[2]);
|
|
182
|
+
const uint8_t *b0 = get_bitmap1D(&stack[0]);
|
|
183
|
+
const uint8_t *b1 = get_bitmap1D(&stack[1]);
|
|
184
|
+
bool *x2 = (bool *)apply_index(&stack[2]);
|
|
185
|
+
int64_t i, k0, k1, k2;
|
|
186
|
+
|
|
187
|
+
assert(!ndt_is_optional(stack[2].type));
|
|
188
|
+
|
|
189
|
+
if (b0 && b1) {
|
|
190
|
+
for (i=0, k0=li0, k1=li1, k2=li2; i<N; i++, k0+=s0, k1+=s1, k2+=s2) {
|
|
191
|
+
bool x = is_valid(b0, k0);
|
|
192
|
+
bool y = is_valid(b1, k1);
|
|
193
|
+
bool z = x2[k2];
|
|
194
|
+
z = x && y ? z : !x && !y;
|
|
195
|
+
x2[k2] = z;
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
else if (b0) {
|
|
199
|
+
for (i=0, k0=li0, k2=li2; i<N; i++, k0+=s0, k2+=s2) {
|
|
200
|
+
bool x = is_valid(b0, k0);
|
|
201
|
+
bool z = x2[k2];
|
|
202
|
+
z = x ? z : x;
|
|
203
|
+
x2[k2] = z;
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
else if (b1) {
|
|
207
|
+
for (i=0, k1=li1, k2=li2; i<N; i++, k1+=s1, k2+=s2) {
|
|
208
|
+
bool x = is_valid(b1, k1);
|
|
209
|
+
bool z = x2[k2];
|
|
210
|
+
z = x ? z : x;
|
|
211
|
+
x2[k2] = z;
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
void
|
|
217
|
+
binary_update_bitmap_0D_bool(xnd_t stack[])
|
|
218
|
+
{
|
|
219
|
+
const int64_t li0 = stack[0].index;
|
|
220
|
+
const int64_t li1 = stack[1].index;
|
|
221
|
+
const int64_t li2 = stack[2].index;
|
|
222
|
+
const uint8_t *b0 = get_bitmap(&stack[0]);
|
|
223
|
+
const uint8_t *b1 = get_bitmap(&stack[1]);
|
|
224
|
+
bool *x2 = (bool *)stack[2].ptr;
|
|
225
|
+
|
|
226
|
+
assert(!ndt_is_optional(stack[2].type));
|
|
227
|
+
|
|
228
|
+
if (b0 && b1) {
|
|
229
|
+
bool x = is_valid(b0, li0);
|
|
230
|
+
bool y = is_valid(b1, li1);
|
|
231
|
+
bool z = x2[li2];
|
|
232
|
+
z = x && y ? z : !x && !y;
|
|
233
|
+
x2[li2] = z;
|
|
234
|
+
}
|
|
235
|
+
else if (b0) {
|
|
236
|
+
bool x = is_valid(b0, li0);
|
|
237
|
+
bool z = x2[li2];
|
|
238
|
+
z = x ? z : x;
|
|
239
|
+
x2[li2] = z;
|
|
240
|
+
}
|
|
241
|
+
else if (b1) {
|
|
242
|
+
bool x = is_valid(b1, li1);
|
|
243
|
+
bool z = x2[li2];
|
|
244
|
+
z = x ? z : x;
|
|
245
|
+
x2[li2] = z;
|
|
246
|
+
}
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
/****************************************************************************/
|
|
251
|
+
/* Optimized unary typecheck */
|
|
252
|
+
/****************************************************************************/
|
|
253
|
+
|
|
254
|
+
const gm_kernel_set_t *
|
|
255
|
+
cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
|
256
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
257
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
258
|
+
ndt_context_t *ctx)
|
|
259
|
+
{
|
|
260
|
+
const ndt_t *t;
|
|
261
|
+
const ndt_t *u;
|
|
262
|
+
int n;
|
|
263
|
+
|
|
264
|
+
assert(spec->flags == 0);
|
|
265
|
+
assert(spec->outer_dims == 0);
|
|
266
|
+
assert(spec->nin == 0);
|
|
267
|
+
assert(spec->nout == 0);
|
|
268
|
+
assert(spec->nargs == 0);
|
|
269
|
+
|
|
270
|
+
if (nin != 1) {
|
|
271
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
272
|
+
"invalid number of arguments for %s(x): expected 1, got %d",
|
|
273
|
+
f->name, nin);
|
|
274
|
+
return NULL;
|
|
275
|
+
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
t = types[0];
|
|
279
|
+
|
|
280
|
+
if (nout) {
|
|
281
|
+
if (nout != 1) {
|
|
282
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
283
|
+
"%s(x) expects at most one 'out' argument, got %d",
|
|
284
|
+
f->name, nout);
|
|
285
|
+
return NULL;
|
|
286
|
+
}
|
|
287
|
+
u = types[1];
|
|
288
|
+
}
|
|
289
|
+
else {
|
|
290
|
+
u = types[0];
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
assert(ndt_is_concrete(t));
|
|
294
|
+
assert(ndt_is_concrete(u));
|
|
295
|
+
|
|
296
|
+
n = kernel_location(t, u, ctx);
|
|
297
|
+
if (n < 0) {
|
|
298
|
+
return NULL;
|
|
299
|
+
}
|
|
300
|
+
if (ndt_is_optional(ndt_dtype(t))) {
|
|
301
|
+
n++;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
if (t->tag == VarDim || t->tag == VarDimElem) {
|
|
305
|
+
const gm_kernel_set_t *set = &f->kernels[n+2];
|
|
306
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
|
307
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
|
308
|
+
return NULL;
|
|
309
|
+
}
|
|
310
|
+
return set;
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
if (t->tag == Array) {
|
|
314
|
+
const gm_kernel_set_t *set = &f->kernels[n+4];
|
|
315
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
|
316
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
|
317
|
+
return NULL;
|
|
318
|
+
}
|
|
319
|
+
return set;
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
|
323
|
+
|
|
324
|
+
if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
|
325
|
+
check_broadcast, ctx) < 0) {
|
|
326
|
+
return NULL;
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
return set;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
const gm_kernel_set_t *
|
|
333
|
+
cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
|
334
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
335
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
336
|
+
ndt_context_t *ctx)
|
|
337
|
+
{
|
|
338
|
+
const ndt_t *t;
|
|
339
|
+
const ndt_t *u;
|
|
340
|
+
int n;
|
|
341
|
+
(void)li;
|
|
342
|
+
|
|
343
|
+
assert(spec->flags == 0);
|
|
344
|
+
assert(spec->outer_dims == 0);
|
|
345
|
+
assert(spec->nin == 0);
|
|
346
|
+
assert(spec->nout == 0);
|
|
347
|
+
assert(spec->nargs == 0);
|
|
348
|
+
|
|
349
|
+
if (nin != 1) {
|
|
350
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
351
|
+
"invalid number of arguments for %s(x): expected 1, got %d",
|
|
352
|
+
f->name, nin);
|
|
353
|
+
return NULL;
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
t = types[0];
|
|
357
|
+
|
|
358
|
+
if (nout) {
|
|
359
|
+
if (nout != 1) {
|
|
360
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
361
|
+
"%s(x) expects at most one 'out' argument, got %d",
|
|
362
|
+
f->name, nout);
|
|
363
|
+
return NULL;
|
|
364
|
+
}
|
|
365
|
+
u = types[1];
|
|
366
|
+
}
|
|
367
|
+
else {
|
|
368
|
+
u = types[0];
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
assert(ndt_is_concrete(t));
|
|
372
|
+
assert(ndt_is_concrete(u));
|
|
373
|
+
|
|
374
|
+
n = kernel_location(t, u, ctx);
|
|
375
|
+
if (n < 0) {
|
|
376
|
+
return NULL;
|
|
377
|
+
}
|
|
378
|
+
if (ndt_is_optional(ndt_dtype(t))) {
|
|
379
|
+
n++;
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
|
383
|
+
|
|
384
|
+
if (ndt_fast_unary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
|
385
|
+
check_broadcast, ctx) < 0) {
|
|
386
|
+
return NULL;
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
return set;
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
/****************************************************************************/
|
|
394
|
+
/* Optimized binary typecheck */
|
|
395
|
+
/****************************************************************************/
|
|
396
|
+
|
|
397
|
+
const gm_kernel_set_t *
|
|
398
|
+
cpu_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
|
399
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
400
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
401
|
+
ndt_context_t *ctx)
|
|
402
|
+
{
|
|
403
|
+
const ndt_t *t0;
|
|
404
|
+
const ndt_t *t1;
|
|
405
|
+
int n;
|
|
406
|
+
|
|
407
|
+
assert(spec->flags == 0);
|
|
408
|
+
assert(spec->outer_dims == 0);
|
|
409
|
+
assert(spec->nin == 0);
|
|
410
|
+
assert(spec->nout == 0);
|
|
411
|
+
assert(spec->nargs == 0);
|
|
412
|
+
|
|
413
|
+
if (nin != 2) {
|
|
414
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
415
|
+
"invalid number of arguments for %s(x, y): expected 2, got %d",
|
|
416
|
+
f->name, nin);
|
|
417
|
+
return NULL;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
t0 = types[0];
|
|
421
|
+
t1 = types[1];
|
|
422
|
+
assert(ndt_is_concrete(t0));
|
|
423
|
+
assert(ndt_is_concrete(t1));
|
|
424
|
+
|
|
425
|
+
n = kernel_location(t0, t1, ctx);
|
|
426
|
+
if (n < 0) {
|
|
427
|
+
return NULL;
|
|
428
|
+
}
|
|
429
|
+
if (ndt_is_optional(ndt_dtype(t0))) {
|
|
430
|
+
n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
|
|
431
|
+
}
|
|
432
|
+
else if (ndt_is_optional(ndt_dtype(t1))) {
|
|
433
|
+
n = n+2;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
if (t0->tag == VarDim || t0->tag == VarDimElem ||
|
|
437
|
+
t1->tag == VarDim || t1->tag == VarDimElem) {
|
|
438
|
+
const gm_kernel_set_t *set = &f->kernels[n+4];
|
|
439
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
|
440
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
|
441
|
+
return NULL;
|
|
442
|
+
}
|
|
443
|
+
return set;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
if (t0->tag == Array || t1->tag == Array) {
|
|
447
|
+
const gm_kernel_set_t *set = &f->kernels[n+8];
|
|
448
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
|
449
|
+
check_broadcast, NULL, NULL, ctx) < 0) {
|
|
450
|
+
return NULL;
|
|
451
|
+
}
|
|
452
|
+
return set;
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
|
456
|
+
|
|
457
|
+
if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
|
458
|
+
check_broadcast, ctx) < 0) {
|
|
459
|
+
return NULL;
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
return set;
|
|
463
|
+
}
|
|
464
|
+
|
|
465
|
+
const gm_kernel_set_t *
|
|
466
|
+
cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
|
467
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
468
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
469
|
+
ndt_context_t *ctx)
|
|
470
|
+
{
|
|
471
|
+
const ndt_t *t0;
|
|
472
|
+
const ndt_t *t1;
|
|
473
|
+
int n;
|
|
474
|
+
(void)li;
|
|
475
|
+
|
|
476
|
+
assert(spec->flags == 0);
|
|
477
|
+
assert(spec->outer_dims == 0);
|
|
478
|
+
assert(spec->nin == 0);
|
|
479
|
+
assert(spec->nout == 0);
|
|
480
|
+
assert(spec->nargs == 0);
|
|
481
|
+
|
|
482
|
+
if (nin != 2) {
|
|
483
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
484
|
+
"invalid number of arguments for %s(x, y): expected 2, got %d",
|
|
485
|
+
f->name, nin);
|
|
486
|
+
return NULL;
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
t0 = types[0];
|
|
490
|
+
t1 = types[1];
|
|
491
|
+
assert(ndt_is_concrete(t0));
|
|
492
|
+
assert(ndt_is_concrete(t1));
|
|
493
|
+
|
|
494
|
+
n = kernel_location(t0, t1, ctx);
|
|
495
|
+
if (n < 0) {
|
|
496
|
+
return NULL;
|
|
497
|
+
}
|
|
498
|
+
if (ndt_is_optional(ndt_dtype(t0))) {
|
|
499
|
+
n = ndt_is_optional(ndt_dtype(t1)) ? n+3 : n+1;
|
|
500
|
+
}
|
|
501
|
+
else if (ndt_is_optional(ndt_dtype(t1))) {
|
|
502
|
+
n = n+2;
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
const gm_kernel_set_t *set = &f->kernels[n];
|
|
506
|
+
|
|
507
|
+
if (ndt_fast_binary_fixed_typecheck(spec, set->sig, types, nin, nout,
|
|
508
|
+
check_broadcast, ctx) < 0) {
|
|
509
|
+
return NULL;
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
return set;
|
|
513
|
+
}
|