gumath 0.2.0dev5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/CONTRIBUTING.md +61 -0
  3. data/Gemfile +5 -0
  4. data/History.md +0 -0
  5. data/README.md +5 -0
  6. data/Rakefile +105 -0
  7. data/ext/ruby_gumath/examples.c +126 -0
  8. data/ext/ruby_gumath/extconf.rb +97 -0
  9. data/ext/ruby_gumath/functions.c +106 -0
  10. data/ext/ruby_gumath/gufunc_object.c +79 -0
  11. data/ext/ruby_gumath/gufunc_object.h +55 -0
  12. data/ext/ruby_gumath/gumath/AUTHORS.txt +5 -0
  13. data/ext/ruby_gumath/gumath/INSTALL.txt +42 -0
  14. data/ext/ruby_gumath/gumath/LICENSE.txt +29 -0
  15. data/ext/ruby_gumath/gumath/MANIFEST.in +3 -0
  16. data/ext/ruby_gumath/gumath/Makefile.in +62 -0
  17. data/ext/ruby_gumath/gumath/README.rst +20 -0
  18. data/ext/ruby_gumath/gumath/config.guess +1530 -0
  19. data/ext/ruby_gumath/gumath/config.h.in +52 -0
  20. data/ext/ruby_gumath/gumath/config.sub +1782 -0
  21. data/ext/ruby_gumath/gumath/configure +5049 -0
  22. data/ext/ruby_gumath/gumath/configure.ac +167 -0
  23. data/ext/ruby_gumath/gumath/doc/_static/copybutton.js +66 -0
  24. data/ext/ruby_gumath/gumath/doc/conf.py +26 -0
  25. data/ext/ruby_gumath/gumath/doc/gumath/functions.rst +62 -0
  26. data/ext/ruby_gumath/gumath/doc/gumath/index.rst +26 -0
  27. data/ext/ruby_gumath/gumath/doc/index.rst +45 -0
  28. data/ext/ruby_gumath/gumath/doc/libgumath/data-structures.rst +130 -0
  29. data/ext/ruby_gumath/gumath/doc/libgumath/functions.rst +78 -0
  30. data/ext/ruby_gumath/gumath/doc/libgumath/index.rst +25 -0
  31. data/ext/ruby_gumath/gumath/doc/libgumath/kernels.rst +41 -0
  32. data/ext/ruby_gumath/gumath/doc/releases/index.rst +11 -0
  33. data/ext/ruby_gumath/gumath/install-sh +527 -0
  34. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +170 -0
  35. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +160 -0
  36. data/ext/ruby_gumath/gumath/libgumath/apply.c +201 -0
  37. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +130 -0
  38. data/ext/ruby_gumath/gumath/libgumath/extending/examples.c +176 -0
  39. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +393 -0
  40. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +140 -0
  41. data/ext/ruby_gumath/gumath/libgumath/extending/quaternion.c +156 -0
  42. data/ext/ruby_gumath/gumath/libgumath/func.c +177 -0
  43. data/ext/ruby_gumath/gumath/libgumath/gumath.h +205 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +547 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +449 -0
  46. data/ext/ruby_gumath/gumath/libgumath/nploops.c +219 -0
  47. data/ext/ruby_gumath/gumath/libgumath/tbl.c +223 -0
  48. data/ext/ruby_gumath/gumath/libgumath/thread.c +175 -0
  49. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +130 -0
  50. data/ext/ruby_gumath/gumath/python/extending.py +24 -0
  51. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +74 -0
  52. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +577 -0
  53. data/ext/ruby_gumath/gumath/python/gumath/examples.c +93 -0
  54. data/ext/ruby_gumath/gumath/python/gumath/functions.c +77 -0
  55. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +95 -0
  56. data/ext/ruby_gumath/gumath/python/test_gumath.py +405 -0
  57. data/ext/ruby_gumath/gumath/setup.py +298 -0
  58. data/ext/ruby_gumath/gumath/vcbuild/INSTALL.txt +36 -0
  59. data/ext/ruby_gumath/gumath/vcbuild/vcbuild32.bat +21 -0
  60. data/ext/ruby_gumath/gumath/vcbuild/vcbuild64.bat +21 -0
  61. data/ext/ruby_gumath/gumath/vcbuild/vcclean.bat +10 -0
  62. data/ext/ruby_gumath/gumath/vcbuild/vcdistclean.bat +11 -0
  63. data/ext/ruby_gumath/include/gumath.h +205 -0
  64. data/ext/ruby_gumath/include/ruby_gumath.h +41 -0
  65. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  66. data/ext/ruby_gumath/lib/libgumath.so +1 -0
  67. data/ext/ruby_gumath/lib/libgumath.so.0 +1 -0
  68. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  69. data/ext/ruby_gumath/ruby_gumath.c +295 -0
  70. data/ext/ruby_gumath/ruby_gumath.h +41 -0
  71. data/ext/ruby_gumath/ruby_gumath_internal.h +45 -0
  72. data/ext/ruby_gumath/util.c +68 -0
  73. data/ext/ruby_gumath/util.h +48 -0
  74. data/gumath.gemspec +47 -0
  75. data/lib/gumath.rb +7 -0
  76. data/lib/gumath/version.rb +5 -0
  77. data/lib/ruby_gumath.so +0 -0
  78. metadata +206 -0
@@ -0,0 +1,205 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #ifndef GUMATH_H
35
+ #define GUMATH_H
36
+
37
+ #include "ndtypes.h"
38
+ #include "xnd.h"
39
+
40
+
41
/*
 * Platform glue.
 *
 * GM_API    - symbol import/export decoration.  Only MSVC builds decorate
 *             symbols (GM_EXPORT -> dllexport, GM_IMPORT -> dllimport);
 *             everywhere else the macro expands to nothing.
 * GM_UNUSED - marks a definition as intentionally unused.  Expands to
 *             __attribute__((unused)) for GCC-compatible compilers other
 *             than the Intel compiler, and to nothing otherwise.
 * ALLOCA    - declares `name` as storage for `nmemb` objects of `type` with
 *             automatic lifetime: _alloca() on MSVC, a variable length
 *             array elsewhere.  The element count is parenthesized so that
 *             an expression such as `n + 1` is multiplied as a whole.
 */
#ifdef _MSC_VER
  #if defined (GM_EXPORT)
    #define GM_API __declspec(dllexport)
  #elif defined(GM_IMPORT)
    #define GM_API __declspec(dllimport)
  #else
    #define GM_API
  #endif

  #ifndef GM_UNUSED
    #define GM_UNUSED
  #endif

  #include "malloc.h"
  #define ALLOCA(type, name, nmemb) type *name = _alloca((nmemb) * sizeof(type))
#else
  #define GM_API
  #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
    #define GM_UNUSED __attribute__((unused))
  #else
    #define GM_UNUSED
  #endif

  #define ALLOCA(type, name, nmemb) type name[(nmemb)]
#endif
66
+
67
+
68
+ #define GM_MAX_KERNELS 512
69
+
70
+ typedef float float32_t;
71
+ typedef double float64_t;
72
+
73
+
74
+ typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
75
+ typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
76
+
77
+ /* Collection of specialized kernels for a single function signature. */
78
+ typedef struct {
79
+ ndt_t *sig;
80
+ const ndt_constraint_t *constraint;
81
+
82
+ /* Xnd signatures */
83
+ gm_xnd_kernel_t Opt; /* dispatch ensures elementwise, at least 1D, contiguous in last dimensions */
84
+ gm_xnd_kernel_t C; /* dispatch ensures c-contiguous in inner dimensions */
85
+ gm_xnd_kernel_t Fortran; /* dispatch ensures f-contiguous in inner dimensions */
86
+ gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
87
+
88
+ /* NumPy signature */
89
+ gm_strided_kernel_t Strided;
90
+ } gm_kernel_set_t;
91
+
92
+ typedef struct {
93
+ const char *name;
94
+ const char *type;
95
+ const ndt_methods_t *meth;
96
+ } gm_typedef_init_t;
97
+
98
+ typedef struct {
99
+ const char *name;
100
+ const char *sig;
101
+ const ndt_constraint_t *constraint;
102
+
103
+ gm_xnd_kernel_t Opt;
104
+ gm_xnd_kernel_t C;
105
+ gm_xnd_kernel_t Fortran;
106
+ gm_xnd_kernel_t Xnd;
107
+ gm_strided_kernel_t Strided;
108
+ } gm_kernel_init_t;
109
+
110
+ /* Actual kernel selected for application */
111
+ typedef struct {
112
+ uint32_t flag;
113
+ const gm_kernel_set_t *set;
114
+ } gm_kernel_t;
115
+
116
+ /* Multimethod with associated kernels */
117
+ typedef struct gm_func gm_func_t;
118
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *in[], int nin, ndt_context_t *ctx);
119
+ struct gm_func {
120
+ char *name;
121
+ gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
122
+ int nkernels;
123
+ gm_kernel_set_t kernels[GM_MAX_KERNELS];
124
+ };
125
+
126
+
127
+ typedef struct _gm_tbl gm_tbl_t;
128
+
129
+
130
+ /******************************************************************************/
131
+ /* Functions */
132
+ /******************************************************************************/
133
+
134
+ GM_API gm_func_t *gm_func_new(const char *name, ndt_context_t *ctx);
135
+ GM_API void gm_func_del(gm_func_t *f);
136
+
137
+ GM_API gm_func_t *gm_add_func(gm_tbl_t *tbl, const char *name, ndt_context_t *ctx);
138
+ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx);
139
+ GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
140
+
141
+ GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
142
+ const ndt_t *in_types[], int nin, const xnd_t args[],
143
+ ndt_context_t *ctx);
144
+ GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
145
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, uint32_t flags, const int64_t nthreads, ndt_context_t *ctx);
146
+
147
+
148
+ /******************************************************************************/
149
+ /* NumPy loops */
150
+ /******************************************************************************/
151
+
152
+ GM_API int gm_np_flatten(char **args, const int nargs,
153
+ int64_t *dimensions, int64_t *strides, const xnd_t stack[],
154
+ ndt_context_t *ctx);
155
+
156
+ GM_API int gm_np_convert_xnd(char **args, const int nargs,
157
+ intptr_t *dimensions, const int dims_size,
158
+ intptr_t *steps, const int steps_size,
159
+ xnd_t stack[], const int outer_dims,
160
+ ndt_context_t *ctx);
161
+
162
+ GM_API int gm_np_map(const gm_strided_kernel_t f,
163
+ char **args, int nargs,
164
+ intptr_t *dimensions,
165
+ intptr_t *steps,
166
+ void *data,
167
+ int outer_dims);
168
+
169
+
170
+ /******************************************************************************/
171
+ /* Xnd loops */
172
+ /******************************************************************************/
173
+
174
+ GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
175
+ const int outer_dims, ndt_context_t *ctx);
176
+
177
+
178
+ /******************************************************************************/
179
+ /* Gufunc table */
180
+ /******************************************************************************/
181
+ GM_API gm_tbl_t *gm_tbl_new(ndt_context_t *ctx);
182
+ GM_API void gm_tbl_del(gm_tbl_t *t);
183
+
184
+ GM_API int gm_tbl_add(gm_tbl_t *tbl, const char *key, gm_func_t *value, ndt_context_t *ctx);
185
+ GM_API gm_func_t *gm_tbl_find(const gm_tbl_t *tbl, const char *key, ndt_context_t *ctx);
186
+ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *state), void *state);
187
+
188
+
189
+ /******************************************************************************/
190
+ /* Library initialization and tables */
191
+ /******************************************************************************/
192
+
193
+ GM_API void gm_init(void);
194
+ GM_API int gm_init_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
195
+ GM_API int gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
196
+ GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
197
+ GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
198
+ GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
199
+ GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
200
+ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
201
+
202
+ GM_API void gm_finalize(void);
203
+
204
+
205
+ #endif /* GUMATH_H */
@@ -0,0 +1,547 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdlib.h>
35
+ #include <stdint.h>
36
+ #include <string.h>
37
+ #include <math.h>
38
+ #include <complex.h>
39
+ #include <inttypes.h>
40
+ #include "ndtypes.h"
41
+ #include "xnd.h"
42
+ #include "gumath.h"
43
+
44
+
45
+ /****************************************************************************/
46
+ /* Optimized dispatch (exact casting) */
47
+ /****************************************************************************/
48
+
49
+ /* Structured kernel locations for fast lookup. */
50
+ static ndt_t *
51
+ infer_return_type(int *base, const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx)
52
+ {
53
+ const ndt_t *t0 = ndt_dtype(in0);
54
+ const ndt_t *t1 = ndt_dtype(in1);
55
+ enum ndt tag;
56
+
57
+ switch (t0->tag) {
58
+ case Int8: {
59
+ switch (t1->tag) {
60
+ case Int8: *base = 0; tag = Int8; break;
61
+ case Int16: *base = 2; tag = Int16; break;
62
+ case Int32: *base = 4; tag = Int32; break;
63
+ case Int64: *base = 6; tag = Int64; break;
64
+ case Uint8: *base = 8; tag = Int16; break;
65
+ case Uint16: *base = 10; tag = Int32; break;
66
+ case Uint32: *base = 12; tag = Int64; break;
67
+ case Float32: *base = 14; tag = Float32; break;
68
+ case Float64: *base = 16; tag = Float64; break;
69
+ default: goto invalid_combination;
70
+ }
71
+ break;
72
+ }
73
+ case Int16: {
74
+ switch (t1->tag) {
75
+ case Int8: *base = 18; tag = Int16; break;
76
+ case Int16: *base = 20; tag = Int16; break;
77
+ case Int32: *base = 22; tag = Int32; break;
78
+ case Int64: *base = 24; tag = Int64; break;
79
+ case Uint8: *base = 26; tag = Int16; break;
80
+ case Uint16: *base = 28; tag = Int32; break;
81
+ case Uint32: *base = 30; tag = Int64; break;
82
+ case Float32: *base = 32; tag = Float32; break;
83
+ case Float64: *base = 34; tag = Float64; break;
84
+ default: goto invalid_combination;
85
+ }
86
+ break;
87
+ }
88
+ case Int32: {
89
+ switch (t1->tag) {
90
+ case Int8: *base = 36; tag = Int32; break;
91
+ case Int16: *base = 38; tag = Int32; break;
92
+ case Int32: *base = 40; tag = Int32; break;
93
+ case Int64: *base = 42; tag = Int64; break;
94
+ case Uint8: *base = 44; tag = Int32; break;
95
+ case Uint16: *base = 46; tag = Int32; break;
96
+ case Uint32: *base = 48; tag = Int64; break;
97
+ case Float64: *base = 50; tag = Float64; break;
98
+ default: goto invalid_combination;
99
+ }
100
+ break;
101
+ }
102
+ case Int64: {
103
+ switch (t1->tag) {
104
+ case Int8: *base = 52; tag = Int64; break;
105
+ case Int16: *base = 54; tag = Int64; break;
106
+ case Int32: *base = 56; tag = Int64; break;
107
+ case Int64: *base = 58; tag = Int64; break;
108
+ case Uint8: *base = 60; tag = Int64; break;
109
+ case Uint16: *base = 62; tag = Int64; break;
110
+ case Uint32: *base = 64; tag = Int64; break;
111
+ default: goto invalid_combination;
112
+ }
113
+ break;
114
+ }
115
+ case Uint8: {
116
+ switch (t1->tag) {
117
+ case Int8: *base = 66; tag = Int16; break;
118
+ case Int16: *base = 68; tag = Int16; break;
119
+ case Int32: *base = 70; tag = Int32; break;
120
+ case Int64: *base = 72; tag = Int64; break;
121
+ case Uint8: *base = 74; tag = Uint8; break;
122
+ case Uint16: *base = 76; tag = Uint16; break;
123
+ case Uint32: *base = 78; tag = Uint32; break;
124
+ case Uint64: *base = 80; tag = Uint64; break;
125
+ case Float32: *base = 82; tag = Float32; break;
126
+ case Float64: *base = 84; tag = Float64; break;
127
+ default: goto invalid_combination;
128
+ }
129
+ break;
130
+ }
131
+ case Uint16: {
132
+ switch (t1->tag) {
133
+ case Int8: *base = 86; tag = Int32; break;
134
+ case Int16: *base = 88; tag = Int32; break;
135
+ case Int32: *base = 90; tag = Int32; break;
136
+ case Int64: *base = 92; tag = Int64; break;
137
+ case Uint8: *base = 94; tag = Uint16; break;
138
+ case Uint16: *base = 96; tag = Uint32; break;
139
+ case Uint32: *base = 98; tag = Uint64; break;
140
+ case Uint64: *base = 100; tag = Uint64; break;
141
+ case Float32: *base = 102; tag = Float32; break;
142
+ case Float64: *base = 104; tag = Float64; break;
143
+ default: goto invalid_combination;
144
+ }
145
+ break;
146
+ }
147
+ case Uint32: {
148
+ switch (t1->tag) {
149
+ case Int8: *base = 106; tag = Int64; break;
150
+ case Int16: *base = 108; tag = Int64; break;
151
+ case Int32: *base = 110; tag = Int64; break;
152
+ case Int64: *base = 112; tag = Int64; break;
153
+ case Uint8: *base = 114; tag = Uint32; break;
154
+ case Uint16: *base = 116; tag = Uint32; break;
155
+ case Uint32: *base = 118; tag = Uint32; break;
156
+ case Uint64: *base = 120; tag = Uint64; break;
157
+ case Float64: *base = 122; tag = Float64; break;
158
+ default: goto invalid_combination;
159
+ }
160
+ break;
161
+ }
162
+ case Uint64: {
163
+ switch (t1->tag) {
164
+ case Uint8: *base = 124; tag = Uint64; break;
165
+ case Uint16: *base = 126; tag = Uint64; break;
166
+ case Uint32: *base = 128; tag = Uint64; break;
167
+ case Uint64: *base = 130; tag = Uint64; break;
168
+ default: goto invalid_combination;
169
+ }
170
+ break;
171
+ }
172
+ case Float32: {
173
+ switch (t1->tag) {
174
+ case Int8: *base = 132; tag = Float32; break;
175
+ case Int16: *base = 134; tag = Float32; break;
176
+ case Uint8: *base = 136; tag = Float32; break;
177
+ case Uint16: *base = 138; tag = Float32; break;
178
+ case Float32: *base = 140; tag = Float32; break;
179
+ case Float64: *base = 142; tag = Float64; break;
180
+ default: goto invalid_combination;
181
+ }
182
+ break;
183
+ }
184
+ case Float64: {
185
+ switch (t1->tag) {
186
+ case Int8: *base = 144; tag = Float64; break;
187
+ case Int16: *base = 146; tag = Float64; break;
188
+ case Int32: *base = 148; tag = Float64; break;
189
+ case Uint8: *base = 150; tag = Float64; break;
190
+ case Uint16: *base = 152; tag = Float64; break;
191
+ case Uint32: *base = 154; tag = Float64; break;
192
+ case Float32: *base = 156; tag = Float64; break;
193
+ case Float64: *base = 158; tag = Float64; break;
194
+ default: goto invalid_combination;
195
+ }
196
+ break;
197
+ }
198
+ default:
199
+ goto invalid_combination;
200
+ }
201
+
202
+ return ndt_primitive(tag, 0, ctx);
203
+
204
+ invalid_combination:
205
+ ndt_err_format(ctx, NDT_RuntimeError, "invalid dtype");
206
+ return NULL;
207
+ }
208
+
209
+
210
+ /****************************************************************************/
211
+ /* Optimized typecheck */
212
+ /****************************************************************************/
213
+
214
+ static const gm_kernel_set_t *
215
+ binary_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f,
216
+ const ndt_t *in[], int nin,
217
+ ndt_context_t *ctx)
218
+ {
219
+ const ndt_t *t0;
220
+ const ndt_t *t1;
221
+ ndt_t *dtype;
222
+ int n;
223
+
224
+ if (nin != 2) {
225
+ ndt_err_format(ctx, NDT_ValueError,
226
+ "invalid number of arguments for %s(x, y): expected 2, got %d",
227
+ f->name, nin);
228
+ return NULL;
229
+ }
230
+ t0 = in[0];
231
+ t1 = in[1];
232
+ assert(ndt_is_concrete(t0));
233
+ assert(ndt_is_concrete(t1));
234
+
235
+ dtype = infer_return_type(&n, t0, t1, ctx);
236
+ if (dtype == NULL) {
237
+ return NULL;
238
+ }
239
+
240
+ if (t0->tag == VarDim || t1->tag == VarDim) {
241
+ const gm_kernel_set_t *set = &f->kernels[n+1];
242
+ ndt_del(dtype); /* temporary hack */
243
+ if (ndt_typecheck(spec, set->sig, in, nin, NULL, NULL, ctx) < 0) {
244
+ return NULL;
245
+ }
246
+ return set;
247
+ }
248
+
249
+ const gm_kernel_set_t *set = &f->kernels[n];
250
+ if (ndt_fast_binary_fixed_typecheck(spec, set->sig, in, nin, dtype, ctx) < 0) {
251
+ return NULL;
252
+ }
253
+
254
+ return set;
255
+ }
256
+
257
+
258
+ /****************************************************************************/
259
+ /* Generated Xnd kernels */
260
+ /****************************************************************************/
261
+
262
+ #define XSTRINGIZE(v) #v
263
+ #define STRINGIZE(v) XSTRINGIZE(v)
264
+
265
+ static inline char *
266
+ apply_index(const xnd_t *x)
267
+ {
268
+ return xnd_fixed_apply_index(x);
269
+ }
270
+
271
+
272
/*
 * Generate the two kernels for one operand combination of a binary function:
 *
 *   gm_fixed_<func>_1D_C_<t0>_<t1>_<t2>  loops over xnd_fixed_shape(&stack[0])
 *                                        elements, eight per iteration
 *                                        followed by a scalar tail loop.
 *   gm_<func>_0D_<t0>_<t1>_<t2>          applies func to a single element
 *                                        pair.
 *
 * stack[0] and stack[1] are the inputs, stack[2] is the output; both kernels
 * always return 0 and do not use ctx.
 *
 * Both operands are converted to the result type t2 before func is applied.
 * Without the conversion C's usual arithmetic conversions pick the type of
 * the operation, so e.g. int8 + uint32 -> int64 would be evaluated in
 * unsigned 32-bit arithmetic and (-1) + 0 would yield 4294967295.
 */
#define XND_BINARY(func, t0, t1, t2) \
static int                                                                     \
gm_fixed_##func##_1D_C_##t0##_##t1##_##t2(xnd_t stack[], ndt_context_t *ctx)   \
{                                                                              \
    const t0##_t *in0 = (const t0##_t *)apply_index(&stack[0]);                \
    const t1##_t *in1 = (const t1##_t *)apply_index(&stack[1]);                \
    t2##_t *out = (t2##_t *)apply_index(&stack[2]);                            \
    int64_t N = xnd_fixed_shape(&stack[0]);                                    \
    (void)ctx;                                                                 \
    int64_t i;                                                                 \
                                                                               \
    for (i = 0; i < N-7; i += 8) {                                             \
        out[i] = func((t2##_t)in0[i], (t2##_t)in1[i]);                         \
        out[i+1] = func((t2##_t)in0[i+1], (t2##_t)in1[i+1]);                   \
        out[i+2] = func((t2##_t)in0[i+2], (t2##_t)in1[i+2]);                   \
        out[i+3] = func((t2##_t)in0[i+3], (t2##_t)in1[i+3]);                   \
        out[i+4] = func((t2##_t)in0[i+4], (t2##_t)in1[i+4]);                   \
        out[i+5] = func((t2##_t)in0[i+5], (t2##_t)in1[i+5]);                   \
        out[i+6] = func((t2##_t)in0[i+6], (t2##_t)in1[i+6]);                   \
        out[i+7] = func((t2##_t)in0[i+7], (t2##_t)in1[i+7]);                   \
    }                                                                          \
    for (; i < N; i++) {                                                       \
        out[i] = func((t2##_t)in0[i], (t2##_t)in1[i]);                         \
    }                                                                          \
                                                                               \
    return 0;                                                                  \
}                                                                              \
                                                                               \
static int                                                                     \
gm_##func##_0D_##t0##_##t1##_##t2(xnd_t stack[], ndt_context_t *ctx)           \
{                                                                              \
    const xnd_t *in0 = &stack[0];                                              \
    const xnd_t *in1 = &stack[1];                                              \
    xnd_t *out = &stack[2];                                                    \
    (void)ctx;                                                                 \
                                                                               \
    const t0##_t x = *(const t0##_t *)in0->ptr;                                \
    const t1##_t y = *(const t1##_t *)in1->ptr;                                \
    *(t2##_t *)out->ptr = func((t2##_t)x, (t2##_t)y);                          \
                                                                               \
    return 0;                                                                  \
}
314
+
315
/*
 * Expand to the two gm_kernel_init_t entries of one operand combination, in
 * the order binary_typecheck() relies on:
 *
 *   1. "... * t0, ... * t1 -> ... * t2"          (1D kernel as Opt, 0D kernel as C)
 *   2. "var... * t0, var... * t1 -> var... * t2" (0D kernel as C only)
 *
 * Fields without a kernel are spelled out as NULL.
 */
#define XND_BINARY_INIT(func, t0, t1, t2) \
  { .name = STRINGIZE(func),                                                                     \
    .sig = "... * " STRINGIZE(t0) ", ... * " STRINGIZE(t1) " -> ... * " STRINGIZE(t2),           \
    .constraint = NULL,                                                                          \
    .Opt = gm_fixed_##func##_1D_C_##t0##_##t1##_##t2,                                            \
    .C = gm_##func##_0D_##t0##_##t1##_##t2,                                                      \
    .Fortran = NULL,                                                                             \
    .Xnd = NULL,                                                                                 \
    .Strided = NULL },                                                                           \
                                                                                                 \
  { .name = STRINGIZE(func),                                                                     \
    .sig = "var... * " STRINGIZE(t0) ", var... * " STRINGIZE(t1) " -> var... * " STRINGIZE(t2),  \
    .constraint = NULL,                                                                          \
    .Opt = NULL,                                                                                 \
    .C = gm_##func##_0D_##t0##_##t1##_##t2,                                                      \
    .Fortran = NULL,                                                                             \
    .Xnd = NULL,                                                                                 \
    .Strided = NULL }
324
+
325
/*
 * Table of all supported (t0, t1) -> t2 operand combinations.
 *
 * The table is written once and expanded twice: XND_ALL_BINARY generates the
 * kernel functions, XND_ALL_BINARY_INIT the matching initializer entries.
 * Keeping a single list guarantees that both expansions contain the same
 * combinations in the same order.
 *
 * The ORDER IS SIGNIFICANT: infer_return_type() hard-codes the position of
 * every combination (two kernel slots per row, first row at index 0).  Rows
 * must only be added or moved together with the indices there.
 *
 * X is invoked as X(name, t0, t1, t2) for every row; SEP() is expanded
 * between rows (nothing for definitions, a comma for initializers).
 */
#define XND_BINARY_SEP_NONE()
#define XND_BINARY_SEP_COMMA() ,

#define XND_BINARY_TABLE(X, SEP, name) \
    X(name, int8, int8, int8) SEP()             \
    X(name, int8, int16, int16) SEP()           \
    X(name, int8, int32, int32) SEP()           \
    X(name, int8, int64, int64) SEP()           \
    X(name, int8, uint8, int16) SEP()           \
    X(name, int8, uint16, int32) SEP()          \
    X(name, int8, uint32, int64) SEP()          \
    X(name, int8, float32, float32) SEP()       \
    X(name, int8, float64, float64) SEP()       \
                                                \
    X(name, int16, int8, int16) SEP()           \
    X(name, int16, int16, int16) SEP()          \
    X(name, int16, int32, int32) SEP()          \
    X(name, int16, int64, int64) SEP()          \
    X(name, int16, uint8, int16) SEP()          \
    X(name, int16, uint16, int32) SEP()         \
    X(name, int16, uint32, int64) SEP()         \
    X(name, int16, float32, float32) SEP()      \
    X(name, int16, float64, float64) SEP()      \
                                                \
    X(name, int32, int8, int32) SEP()           \
    X(name, int32, int16, int32) SEP()          \
    X(name, int32, int32, int32) SEP()          \
    X(name, int32, int64, int64) SEP()          \
    X(name, int32, uint8, int32) SEP()          \
    X(name, int32, uint16, int32) SEP()         \
    X(name, int32, uint32, int64) SEP()         \
    X(name, int32, float64, float64) SEP()      \
                                                \
    X(name, int64, int8, int64) SEP()           \
    X(name, int64, int16, int64) SEP()          \
    X(name, int64, int32, int64) SEP()          \
    X(name, int64, int64, int64) SEP()          \
    X(name, int64, uint8, int64) SEP()          \
    X(name, int64, uint16, int64) SEP()         \
    X(name, int64, uint32, int64) SEP()         \
                                                \
    X(name, uint8, int8, int16) SEP()           \
    X(name, uint8, int16, int16) SEP()          \
    X(name, uint8, int32, int32) SEP()          \
    X(name, uint8, int64, int64) SEP()          \
    X(name, uint8, uint8, uint8) SEP()          \
    X(name, uint8, uint16, uint16) SEP()        \
    X(name, uint8, uint32, uint32) SEP()        \
    X(name, uint8, uint64, uint64) SEP()        \
    X(name, uint8, float32, float32) SEP()      \
    X(name, uint8, float64, float64) SEP()      \
                                                \
    X(name, uint16, int8, int32) SEP()          \
    X(name, uint16, int16, int32) SEP()         \
    X(name, uint16, int32, int32) SEP()         \
    X(name, uint16, int64, int64) SEP()         \
    X(name, uint16, uint8, uint16) SEP()        \
    X(name, uint16, uint16, uint16) SEP()       \
    X(name, uint16, uint32, uint32) SEP()       \
    X(name, uint16, uint64, uint64) SEP()       \
    X(name, uint16, float32, float32) SEP()     \
    X(name, uint16, float64, float64) SEP()     \
                                                \
    X(name, uint32, int8, int64) SEP()          \
    X(name, uint32, int16, int64) SEP()         \
    X(name, uint32, int32, int64) SEP()         \
    X(name, uint32, int64, int64) SEP()         \
    X(name, uint32, uint8, uint32) SEP()        \
    X(name, uint32, uint16, uint32) SEP()       \
    X(name, uint32, uint32, uint32) SEP()       \
    X(name, uint32, uint64, uint64) SEP()       \
    X(name, uint32, float64, float64) SEP()     \
                                                \
    X(name, uint64, uint8, uint64) SEP()        \
    X(name, uint64, uint16, uint64) SEP()       \
    X(name, uint64, uint32, uint64) SEP()       \
    X(name, uint64, uint64, uint64) SEP()       \
                                                \
    X(name, float32, int8, float32) SEP()       \
    X(name, float32, int16, float32) SEP()      \
    X(name, float32, uint8, float32) SEP()      \
    X(name, float32, uint16, float32) SEP()     \
    X(name, float32, float32, float32) SEP()    \
    X(name, float32, float64, float64) SEP()    \
                                                \
    X(name, float64, int8, float64) SEP()       \
    X(name, float64, int16, float64) SEP()      \
    X(name, float64, int32, float64) SEP()      \
    X(name, float64, uint8, float64) SEP()      \
    X(name, float64, uint16, float64) SEP()     \
    X(name, float64, uint32, float64) SEP()     \
    X(name, float64, float32, float64) SEP()    \
    X(name, float64, float64, float64)

/* Define the kernel functions of `name` for every row of the table. */
#define XND_ALL_BINARY(name) \
    XND_BINARY_TABLE(XND_BINARY, XND_BINARY_SEP_NONE, name)

/* Comma-separated initializer entries of `name`, without a trailing comma. */
#define XND_ALL_BINARY_INIT(name) \
    XND_BINARY_TABLE(XND_BINARY_INIT, XND_BINARY_SEP_COMMA, name)
506
+
507
+
508
+ #define add(x, y) x + y
509
+ XND_ALL_BINARY(add)
510
+
511
+ #define subtract(x, y) x - y
512
+ XND_ALL_BINARY(subtract)
513
+
514
+ #define multiply(x, y) x * y
515
+ XND_ALL_BINARY(multiply)
516
+
517
+ #define divide(x, y) x / y
518
+ XND_ALL_BINARY(divide)
519
+
520
+
521
+ static const gm_kernel_init_t kernels[] = {
522
+ XND_ALL_BINARY_INIT(add),
523
+ XND_ALL_BINARY_INIT(subtract),
524
+ XND_ALL_BINARY_INIT(multiply),
525
+ XND_ALL_BINARY_INIT(divide),
526
+
527
+ { .name = NULL, .sig = NULL }
528
+ };
529
+
530
+
531
+ /****************************************************************************/
532
+ /* Initialize kernel table */
533
+ /****************************************************************************/
534
+
535
+ int
536
+ gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx)
537
+ {
538
+ const gm_kernel_init_t *k;
539
+
540
+ for (k = kernels; k->name != NULL; k++) {
541
+ if (gm_add_kernel_typecheck(tbl, k, ctx, &binary_typecheck) < 0) {
542
+ return -1;
543
+ }
544
+ }
545
+
546
+ return 0;
547
+ }