gumath 0.2.0dev5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CONTRIBUTING.md +61 -0
- data/Gemfile +5 -0
- data/History.md +0 -0
- data/README.md +5 -0
- data/Rakefile +105 -0
- data/ext/ruby_gumath/examples.c +126 -0
- data/ext/ruby_gumath/extconf.rb +97 -0
- data/ext/ruby_gumath/functions.c +106 -0
- data/ext/ruby_gumath/gufunc_object.c +79 -0
- data/ext/ruby_gumath/gufunc_object.h +55 -0
- data/ext/ruby_gumath/gumath/AUTHORS.txt +5 -0
- data/ext/ruby_gumath/gumath/INSTALL.txt +42 -0
- data/ext/ruby_gumath/gumath/LICENSE.txt +29 -0
- data/ext/ruby_gumath/gumath/MANIFEST.in +3 -0
- data/ext/ruby_gumath/gumath/Makefile.in +62 -0
- data/ext/ruby_gumath/gumath/README.rst +20 -0
- data/ext/ruby_gumath/gumath/config.guess +1530 -0
- data/ext/ruby_gumath/gumath/config.h.in +52 -0
- data/ext/ruby_gumath/gumath/config.sub +1782 -0
- data/ext/ruby_gumath/gumath/configure +5049 -0
- data/ext/ruby_gumath/gumath/configure.ac +167 -0
- data/ext/ruby_gumath/gumath/doc/_static/copybutton.js +66 -0
- data/ext/ruby_gumath/gumath/doc/conf.py +26 -0
- data/ext/ruby_gumath/gumath/doc/gumath/functions.rst +62 -0
- data/ext/ruby_gumath/gumath/doc/gumath/index.rst +26 -0
- data/ext/ruby_gumath/gumath/doc/index.rst +45 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/data-structures.rst +130 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/functions.rst +78 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/index.rst +25 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/kernels.rst +41 -0
- data/ext/ruby_gumath/gumath/doc/releases/index.rst +11 -0
- data/ext/ruby_gumath/gumath/install-sh +527 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +170 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +160 -0
- data/ext/ruby_gumath/gumath/libgumath/apply.c +201 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +130 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/examples.c +176 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +393 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +140 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/quaternion.c +156 -0
- data/ext/ruby_gumath/gumath/libgumath/func.c +177 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +205 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +547 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +449 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.c +219 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.c +223 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +175 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +130 -0
- data/ext/ruby_gumath/gumath/python/extending.py +24 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +74 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +577 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +93 -0
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +77 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +95 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +405 -0
- data/ext/ruby_gumath/gumath/setup.py +298 -0
- data/ext/ruby_gumath/gumath/vcbuild/INSTALL.txt +36 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcbuild32.bat +21 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcbuild64.bat +21 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcclean.bat +10 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcdistclean.bat +11 -0
- data/ext/ruby_gumath/include/gumath.h +205 -0
- data/ext/ruby_gumath/include/ruby_gumath.h +41 -0
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so +1 -0
- data/ext/ruby_gumath/lib/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +295 -0
- data/ext/ruby_gumath/ruby_gumath.h +41 -0
- data/ext/ruby_gumath/ruby_gumath_internal.h +45 -0
- data/ext/ruby_gumath/util.c +68 -0
- data/ext/ruby_gumath/util.h +48 -0
- data/gumath.gemspec +47 -0
- data/lib/gumath.rb +7 -0
- data/lib/gumath/version.rb +5 -0
- data/lib/ruby_gumath.so +0 -0
- metadata +206 -0
@@ -0,0 +1,93 @@
|
|
1
|
+
#include <Python.h>
|
2
|
+
#include "ndtypes.h"
|
3
|
+
#include "pyndtypes.h"
|
4
|
+
#include "gumath.h"
|
5
|
+
#include "pygumath.h"
|
6
|
+
|
7
|
+
|
8
|
+
/****************************************************************************/
|
9
|
+
/* Module globals */
|
10
|
+
/****************************************************************************/
|
11
|
+
|
12
|
+
/* Function table */
|
13
|
+
static gm_tbl_t *table = NULL;
|
14
|
+
|
15
|
+
|
16
|
+
/****************************************************************************/
|
17
|
+
/* Module */
|
18
|
+
/****************************************************************************/
|
19
|
+
|
20
|
+
static struct PyModuleDef examples_module = {
|
21
|
+
PyModuleDef_HEAD_INIT, /* m_base */
|
22
|
+
"examples", /* m_name */
|
23
|
+
NULL, /* m_doc */
|
24
|
+
-1, /* m_size */
|
25
|
+
NULL, /* m_methods */
|
26
|
+
NULL, /* m_slots */
|
27
|
+
NULL, /* m_traverse */
|
28
|
+
NULL, /* m_clear */
|
29
|
+
NULL /* m_free */
|
30
|
+
};
|
31
|
+
|
32
|
+
|
33
|
+
PyMODINIT_FUNC
|
34
|
+
PyInit_examples(void)
|
35
|
+
{
|
36
|
+
NDT_STATIC_CONTEXT(ctx);
|
37
|
+
PyObject *m = NULL;
|
38
|
+
static int initialized = 0;
|
39
|
+
|
40
|
+
if (!initialized) {
|
41
|
+
if (import_ndtypes() < 0) {
|
42
|
+
return NULL;
|
43
|
+
}
|
44
|
+
if (import_gumath() < 0) {
|
45
|
+
return NULL;
|
46
|
+
}
|
47
|
+
|
48
|
+
table = gm_tbl_new(&ctx);
|
49
|
+
if (table == NULL) {
|
50
|
+
return Ndt_SetError(&ctx);
|
51
|
+
}
|
52
|
+
|
53
|
+
/* custom examples */
|
54
|
+
if (gm_init_example_kernels(table, &ctx) < 0) {
|
55
|
+
return Ndt_SetError(&ctx);
|
56
|
+
}
|
57
|
+
|
58
|
+
/* extending examples */
|
59
|
+
#ifndef _MSC_VER
|
60
|
+
if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
|
61
|
+
return Ndt_SetError(&ctx);
|
62
|
+
}
|
63
|
+
#endif
|
64
|
+
if (gm_init_graph_kernels(table, &ctx) < 0) {
|
65
|
+
return Ndt_SetError(&ctx);
|
66
|
+
}
|
67
|
+
#ifndef _MSC_VER
|
68
|
+
if (gm_init_quaternion_kernels(table, &ctx) < 0) {
|
69
|
+
return Ndt_SetError(&ctx);
|
70
|
+
}
|
71
|
+
#endif
|
72
|
+
if (gm_init_pdist_kernels(table, &ctx) < 0) {
|
73
|
+
return Ndt_SetError(&ctx);
|
74
|
+
}
|
75
|
+
|
76
|
+
initialized = 1;
|
77
|
+
}
|
78
|
+
|
79
|
+
m = PyModule_Create(&examples_module);
|
80
|
+
if (m == NULL) {
|
81
|
+
goto error;
|
82
|
+
}
|
83
|
+
|
84
|
+
if (Gumath_AddFunctions(m, table) < 0) {
|
85
|
+
goto error;
|
86
|
+
}
|
87
|
+
|
88
|
+
return m;
|
89
|
+
|
90
|
+
error:
|
91
|
+
Py_CLEAR(m);
|
92
|
+
return NULL;
|
93
|
+
}
|
@@ -0,0 +1,77 @@
|
|
1
|
+
#include <Python.h>
|
2
|
+
#include "ndtypes.h"
|
3
|
+
#include "pyndtypes.h"
|
4
|
+
#include "gumath.h"
|
5
|
+
#include "pygumath.h"
|
6
|
+
|
7
|
+
|
8
|
+
/****************************************************************************/
|
9
|
+
/* Module globals */
|
10
|
+
/****************************************************************************/
|
11
|
+
|
12
|
+
/* Function table */
|
13
|
+
static gm_tbl_t *table = NULL;
|
14
|
+
|
15
|
+
|
16
|
+
/****************************************************************************/
|
17
|
+
/* Module */
|
18
|
+
/****************************************************************************/
|
19
|
+
|
20
|
+
static struct PyModuleDef functions_module = {
|
21
|
+
PyModuleDef_HEAD_INIT, /* m_base */
|
22
|
+
"functions", /* m_name */
|
23
|
+
NULL, /* m_doc */
|
24
|
+
-1, /* m_size */
|
25
|
+
NULL, /* m_methods */
|
26
|
+
NULL, /* m_slots */
|
27
|
+
NULL, /* m_traverse */
|
28
|
+
NULL, /* m_clear */
|
29
|
+
NULL /* m_free */
|
30
|
+
};
|
31
|
+
|
32
|
+
|
33
|
+
PyMODINIT_FUNC
|
34
|
+
PyInit_functions(void)
|
35
|
+
{
|
36
|
+
NDT_STATIC_CONTEXT(ctx);
|
37
|
+
PyObject *m = NULL;
|
38
|
+
static int initialized = 0;
|
39
|
+
|
40
|
+
if (!initialized) {
|
41
|
+
if (import_ndtypes() < 0) {
|
42
|
+
return NULL;
|
43
|
+
}
|
44
|
+
if (import_gumath() < 0) {
|
45
|
+
return NULL;
|
46
|
+
}
|
47
|
+
|
48
|
+
table = gm_tbl_new(&ctx);
|
49
|
+
if (table == NULL) {
|
50
|
+
return Ndt_SetError(&ctx);
|
51
|
+
}
|
52
|
+
|
53
|
+
if (gm_init_unary_kernels(table, &ctx) < 0) {
|
54
|
+
return Ndt_SetError(&ctx);
|
55
|
+
}
|
56
|
+
if (gm_init_binary_kernels(table, &ctx) < 0) {
|
57
|
+
return Ndt_SetError(&ctx);
|
58
|
+
}
|
59
|
+
|
60
|
+
initialized = 1;
|
61
|
+
}
|
62
|
+
|
63
|
+
m = PyModule_Create(&functions_module);
|
64
|
+
if (m == NULL) {
|
65
|
+
goto error;
|
66
|
+
}
|
67
|
+
|
68
|
+
if (Gumath_AddFunctions(m, table) < 0) {
|
69
|
+
goto error;
|
70
|
+
}
|
71
|
+
|
72
|
+
return m;
|
73
|
+
|
74
|
+
error:
|
75
|
+
Py_CLEAR(m);
|
76
|
+
return NULL;
|
77
|
+
}
|
@@ -0,0 +1,95 @@
|
|
1
|
+
/*
|
2
|
+
* BSD 3-Clause License
|
3
|
+
*
|
4
|
+
* Copyright (c) 2017-2018, plures
|
5
|
+
* All rights reserved.
|
6
|
+
*
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
9
|
+
*
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
* this list of conditions and the following disclaimer.
|
12
|
+
*
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
15
|
+
* and/or other materials provided with the distribution.
|
16
|
+
*
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
19
|
+
* this software without specific prior written permission.
|
20
|
+
*
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
*/
|
32
|
+
|
33
|
+
|
34
|
+
#ifndef PYGUMATH_H
|
35
|
+
#define PYGUMATH_H
|
36
|
+
#ifdef __cplusplus
|
37
|
+
extern "C" {
|
38
|
+
#endif
|
39
|
+
|
40
|
+
|
41
|
+
#include <Python.h>
|
42
|
+
#include "gumath.h"
|
43
|
+
|
44
|
+
|
45
|
+
/****************************************************************************/
|
46
|
+
/* Gufunc Object */
|
47
|
+
/****************************************************************************/
|
48
|
+
|
49
|
+
/* Exposed here for the benefit of Numba. The API should not be regarded
|
50
|
+
stable across versions. */
|
51
|
+
|
52
|
+
typedef struct {
|
53
|
+
PyObject_HEAD
|
54
|
+
const gm_tbl_t *tbl; /* kernel table */
|
55
|
+
char *name; /* function name */
|
56
|
+
} GufuncObject;
|
57
|
+
|
58
|
+
|
59
|
+
/****************************************************************************/
|
60
|
+
/* Capsule API */
|
61
|
+
/****************************************************************************/
|
62
|
+
|
63
|
+
#define Gumath_AddFunctions_INDEX 0
|
64
|
+
#define Gumath_AddFunctions_RETURN int
|
65
|
+
#define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)
|
66
|
+
|
67
|
+
#define GUMATH_MAX_API 1
|
68
|
+
|
69
|
+
|
70
|
+
#ifdef GUMATH_MODULE
|
71
|
+
static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
|
72
|
+
#else
|
73
|
+
static void **_gumath_api;
|
74
|
+
|
75
|
+
#define Gumath_AddFunctions \
|
76
|
+
(*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
|
77
|
+
|
78
|
+
|
79
|
+
static int
|
80
|
+
import_gumath(void)
|
81
|
+
{
|
82
|
+
_gumath_api = (void **)PyCapsule_Import("gumath._gumath._API", 0);
|
83
|
+
if (_gumath_api == NULL) {
|
84
|
+
return -1;
|
85
|
+
}
|
86
|
+
|
87
|
+
return 0;
|
88
|
+
}
|
89
|
+
#endif
|
90
|
+
|
91
|
+
#ifdef __cplusplus
|
92
|
+
}
|
93
|
+
#endif
|
94
|
+
|
95
|
+
#endif /* PYGUMATH_H */
|
@@ -0,0 +1,405 @@
|
|
1
|
+
#
|
2
|
+
# BSD 3-Clause License
|
3
|
+
#
|
4
|
+
# Copyright (c) 2017-2018, plures
|
5
|
+
# All rights reserved.
|
6
|
+
#
|
7
|
+
# Redistribution and use in source and binary forms, with or without
|
8
|
+
# modification, are permitted provided that the following conditions are met:
|
9
|
+
#
|
10
|
+
# 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
# this list of conditions and the following disclaimer.
|
12
|
+
#
|
13
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
# this list of conditions and the following disclaimer in the documentation
|
15
|
+
# and/or other materials provided with the distribution.
|
16
|
+
#
|
17
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
# contributors may be used to endorse or promote products derived from
|
19
|
+
# this software without specific prior written permission.
|
20
|
+
#
|
21
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
#
|
32
|
+
|
33
|
+
import gumath as gm
|
34
|
+
import gumath.functions as fn
|
35
|
+
import gumath.examples as ex
|
36
|
+
from xnd import xnd
|
37
|
+
from ndtypes import ndt
|
38
|
+
from extending import Graph, bfloat16
|
39
|
+
import sys, time
|
40
|
+
import math
|
41
|
+
import unittest
|
42
|
+
import argparse
|
43
|
+
|
44
|
+
try:
|
45
|
+
import numpy as np
|
46
|
+
except ImportError:
|
47
|
+
np = None
|
48
|
+
|
49
|
+
|
50
|
+
TEST_CASES = [
|
51
|
+
([float(i)/100.0 for i in range(2000)], "2000 * float64", "float64"),
|
52
|
+
|
53
|
+
([[float(i)/100.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
|
54
|
+
"2 * 1000 * float64", "float64"),
|
55
|
+
|
56
|
+
(1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float64", "float64"),
|
57
|
+
|
58
|
+
([float(i)/10.0 for i in range(2000)], "2000 * float32", "float32"),
|
59
|
+
|
60
|
+
([[float(i)/10.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
|
61
|
+
"2 * 1000 * float32", "float32"),
|
62
|
+
|
63
|
+
(1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float32", "float32"),
|
64
|
+
]
|
65
|
+
|
66
|
+
|
67
|
+
class TestCall(unittest.TestCase):
|
68
|
+
|
69
|
+
def test_sin_scalar(self):
|
70
|
+
|
71
|
+
x1 = xnd(1.2, type="float64")
|
72
|
+
y1 = fn.sin(x1)
|
73
|
+
|
74
|
+
x2 = xnd(1.23e1, type="float32")
|
75
|
+
y2 = fn.sin(x2)
|
76
|
+
|
77
|
+
if np is not None:
|
78
|
+
a1 = np.array(1.2, dtype="float64")
|
79
|
+
b1 = np.sin(a1)
|
80
|
+
|
81
|
+
a2 = np.array(1.23e1, dtype="float32")
|
82
|
+
b2 = np.sin(a2)
|
83
|
+
|
84
|
+
np.testing.assert_equal(y1.value, b1)
|
85
|
+
np.testing.assert_equal(y2.value, b2)
|
86
|
+
|
87
|
+
def test_sin(self):
|
88
|
+
|
89
|
+
for lst, t, dtype in TEST_CASES:
|
90
|
+
x = xnd(lst, type=t)
|
91
|
+
y = fn.sin(x)
|
92
|
+
|
93
|
+
if np is not None:
|
94
|
+
a = np.array(lst, dtype=dtype)
|
95
|
+
b = np.sin(a)
|
96
|
+
np.testing.assert_equal(y, b)
|
97
|
+
|
98
|
+
def test_sin_strided(self):
|
99
|
+
|
100
|
+
for lst, t, dtype in TEST_CASES:
|
101
|
+
x = xnd(lst, type=t)
|
102
|
+
if x.type.ndim < 2:
|
103
|
+
continue
|
104
|
+
|
105
|
+
y = x[::-2, ::-2]
|
106
|
+
z = fn.sin(y)
|
107
|
+
|
108
|
+
if np is not None:
|
109
|
+
a = np.array(lst, dtype=dtype)
|
110
|
+
b = a[::-2, ::-2]
|
111
|
+
c = np.sin(b)
|
112
|
+
np.testing.assert_equal(z, c)
|
113
|
+
|
114
|
+
def test_copy(self):
|
115
|
+
|
116
|
+
for lst, t, dtype in TEST_CASES:
|
117
|
+
x = xnd(lst, type=t)
|
118
|
+
y = fn.copy(x)
|
119
|
+
|
120
|
+
if np is not None:
|
121
|
+
a = np.array(lst, dtype=dtype)
|
122
|
+
b = np.copy(a)
|
123
|
+
np.testing.assert_equal(y, b)
|
124
|
+
|
125
|
+
def test_copy_strided(self):
    """Check fn.copy on a reversed, step-2 slice of each 2-D test case.

    Cases with fewer than two dimensions are skipped.  When numpy is
    available, the result is compared with np.copy of the same slice.
    """

    for lst, t, dtype in TEST_CASES:
        x = xnd(lst, type=t)
        if x.type.ndim < 2:
            continue

        y = x[::-2, ::-2]
        z = fn.copy(y)

        if np is not None:
            a = np.array(lst, dtype=dtype)
            b = a[::-2, ::-2]
            c = np.copy(b)
            # Compare the output of fn.copy (z) with numpy's copy (c);
            # comparing the input views y and b would not exercise
            # fn.copy at all.
            np.testing.assert_equal(z, c)
|
140
|
+
|
141
|
+
@unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
|
142
|
+
def test_quaternion(self):
|
143
|
+
|
144
|
+
lst = [[[1+2j, 4+3j],
|
145
|
+
[-4+3j, 1-2j]],
|
146
|
+
[[4+2j, 1+10j],
|
147
|
+
[-1+10j, 4-2j]],
|
148
|
+
[[-4+2j, 3+10j],
|
149
|
+
[-3+10j, -4-2j]]]
|
150
|
+
|
151
|
+
x = xnd(lst, type="3 * quaternion64")
|
152
|
+
y = ex.multiply(x, x)
|
153
|
+
|
154
|
+
if np is not None:
|
155
|
+
a = np.array(lst, dtype="complex64")
|
156
|
+
b = np.einsum("ijk,ikl->ijl", a, a)
|
157
|
+
np.testing.assert_equal(y, b)
|
158
|
+
|
159
|
+
x = xnd(lst, type="3 * quaternion128")
|
160
|
+
y = ex.multiply(x, x)
|
161
|
+
|
162
|
+
if np is not None:
|
163
|
+
a = np.array(lst, dtype="complex128")
|
164
|
+
b = np.einsum("ijk,ikl->ijl", a, a)
|
165
|
+
np.testing.assert_equal(y, b)
|
166
|
+
|
167
|
+
x = xnd("xyz")
|
168
|
+
self.assertRaises(TypeError, ex.multiply, x, x)
|
169
|
+
|
170
|
+
@unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
|
171
|
+
def test_quaternion_error(self):
|
172
|
+
|
173
|
+
lst = [[[1+2j, 4+3j],
|
174
|
+
[-4+3j, 1-2j]],
|
175
|
+
[[4+2j, 1+10j],
|
176
|
+
[-1+10j, 4-2j]],
|
177
|
+
[[-4+2j, 3+10j],
|
178
|
+
[-3+10j, -4-2j]]]
|
179
|
+
|
180
|
+
x = xnd(lst, type="3 * Foo(2 * 2 * complex64)")
|
181
|
+
self.assertRaises(TypeError, ex.multiply, x, x)
|
182
|
+
|
183
|
+
def test_void(self):
|
184
|
+
|
185
|
+
x = ex.randint()
|
186
|
+
self.assertEqual(x.type, ndt("int32"))
|
187
|
+
|
188
|
+
def test_multiple_return(self):
|
189
|
+
|
190
|
+
x, y = ex.randtuple()
|
191
|
+
self.assertEqual(x.type, ndt("int32"))
|
192
|
+
self.assertEqual(y.type, ndt("int32"))
|
193
|
+
|
194
|
+
x, y = ex.divmod10(xnd(233))
|
195
|
+
self.assertEqual(x.value, 23)
|
196
|
+
self.assertEqual(y.value, 3)
|
197
|
+
|
198
|
+
|
199
|
+
class TestMissingValues(unittest.TestCase):
|
200
|
+
|
201
|
+
def test_missing_values(self):
|
202
|
+
|
203
|
+
x = [{'index': 0, 'name': 'brazil', 'value': 10},
|
204
|
+
{'index': 1, 'name': 'france', 'value': None},
|
205
|
+
{'index': 1, 'name': 'russia', 'value': 2}]
|
206
|
+
|
207
|
+
y = [{'index': 0, 'name': 'iceland', 'value': 5},
|
208
|
+
{'index': 1, 'name': 'norway', 'value': None},
|
209
|
+
{'index': 1, 'name': 'italy', 'value': None}]
|
210
|
+
|
211
|
+
z = xnd([x, y], type="2 * 3 * {index: int64, name: string, value: ?int64}")
|
212
|
+
ans = ex.count_valid_missing(z)
|
213
|
+
|
214
|
+
self.assertEqual(ans.value, [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}])
|
215
|
+
|
216
|
+
|
217
|
+
class TestRaggedArrays(unittest.TestCase):
|
218
|
+
|
219
|
+
def test_sin(self):
|
220
|
+
s = math.sin
|
221
|
+
lst = [[[1.0],
|
222
|
+
[2.0, 3.0],
|
223
|
+
[4.0, 5.0, 6.0]],
|
224
|
+
[[7.0],
|
225
|
+
[8.0, 9.0],
|
226
|
+
[10.0, 11.0, 12.0]]]
|
227
|
+
|
228
|
+
ans = [[[s(1.0)],
|
229
|
+
[s(2.0), s(3.0)],
|
230
|
+
[s(4.0), s(5.0), s(6.0)]],
|
231
|
+
[[s(7.0)],
|
232
|
+
[s(8.0), s(9.0)],
|
233
|
+
[s(10.0), s(11.0), s(12.0)]]]
|
234
|
+
|
235
|
+
x = xnd(lst)
|
236
|
+
y = fn.sin(x)
|
237
|
+
self.assertEqual(y.value, ans)
|
238
|
+
|
239
|
+
|
240
|
+
class TestGraphs(unittest.TestCase):
|
241
|
+
|
242
|
+
def test_shortest_path(self):
|
243
|
+
graphs = [[[(1, 1.2), (2, 4.4)],
|
244
|
+
[(2, 2.2)],
|
245
|
+
[(1, 2.3)]],
|
246
|
+
|
247
|
+
[[(1, 1.2), (2, 4.4)],
|
248
|
+
[(2, 2.2)],
|
249
|
+
[(1, 2.3)],
|
250
|
+
[(2, 1.1)]]]
|
251
|
+
|
252
|
+
ans = [[[[0], [0, 1], [0, 1, 2]], # graph1, start 0
|
253
|
+
[[], [1], [1, 2]], # graph1, start 1
|
254
|
+
[[], [2, 1], [2]]], # graph1, start 2
|
255
|
+
|
256
|
+
[[[0], [0, 1], [0, 1, 2], []], # graph2, start 0
|
257
|
+
[[], [1], [1, 2], []], # graph2, start 1
|
258
|
+
[[], [2, 1], [2], []], # graph2, start 2
|
259
|
+
[[], [3, 2, 1], [3, 2], [3]]]] # graph2, start 3
|
260
|
+
|
261
|
+
|
262
|
+
for i, lst in enumerate(graphs):
|
263
|
+
N = len(lst)
|
264
|
+
graph = Graph(lst)
|
265
|
+
for start in range(N):
|
266
|
+
node = xnd(start, type="node")
|
267
|
+
x = graph.shortest_paths(node)
|
268
|
+
self.assertEqual(x.value, ans[i][start])
|
269
|
+
|
270
|
+
def test_constraint(self):
|
271
|
+
lst = [[(0, 1.2)],
|
272
|
+
[(2, 2.2), (1, 0.1)]]
|
273
|
+
|
274
|
+
self.assertRaises(ValueError, Graph, lst)
|
275
|
+
|
276
|
+
|
277
|
+
@unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
|
278
|
+
class TestBFloat16(unittest.TestCase):
|
279
|
+
|
280
|
+
def test_init(self):
|
281
|
+
lst = [1.2e10, 2.1121, -3e20]
|
282
|
+
ans = [11945377792.0, 2.109375, -2.997595911977802e+20]
|
283
|
+
|
284
|
+
x = bfloat16(lst)
|
285
|
+
self.assertEqual(x.value, ans)
|
286
|
+
|
287
|
+
|
288
|
+
class TestPdist(unittest.TestCase):
|
289
|
+
|
290
|
+
def test_exceptions(self):
|
291
|
+
x = xnd([], dtype="float64")
|
292
|
+
self.assertRaises(TypeError, ex.euclidian_pdist, x)
|
293
|
+
|
294
|
+
x = xnd([[]], dtype="float64")
|
295
|
+
self.assertRaises(TypeError, ex.euclidian_pdist, x)
|
296
|
+
|
297
|
+
x = xnd([[], []], dtype="float64")
|
298
|
+
self.assertRaises(TypeError, ex.euclidian_pdist, x)
|
299
|
+
|
300
|
+
x = xnd([[1], [1]], dtype="int64")
|
301
|
+
self.assertRaises(TypeError, ex.euclidian_pdist, x)
|
302
|
+
|
303
|
+
def test_pdist(self):
|
304
|
+
x = xnd([[1]], dtype="float64")
|
305
|
+
y = ex.euclidian_pdist(x)
|
306
|
+
self.assertEqual(y.value, [])
|
307
|
+
|
308
|
+
x = xnd([[1, 2, 3]], dtype="float64")
|
309
|
+
y = ex.euclidian_pdist(x)
|
310
|
+
self.assertEqual(y.value, [])
|
311
|
+
|
312
|
+
x = xnd([[-1.2200, -100.5000, 20.1250, 30.1230],
|
313
|
+
[ 2.2200, 2.2720, -122.8400, 122.3330],
|
314
|
+
[ 2.1000, -25.0000, 100.2000, -99.5000]], dtype="float64")
|
315
|
+
y = ex.euclidian_pdist(x)
|
316
|
+
self.assertEqual(y.value, [198.78529349275314, 170.0746899276903, 315.75385646576035])
|
317
|
+
|
318
|
+
|
319
|
+
@unittest.skipIf(gm.xndvectorize is None, "test requires numpy and numba")
|
320
|
+
class TestNumba(unittest.TestCase):
|
321
|
+
|
322
|
+
def test_numba(self):
|
323
|
+
|
324
|
+
@gm.xndvectorize("... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64")
|
325
|
+
def matmul(x, y, res):
|
326
|
+
col = np.arange(y.shape[0])
|
327
|
+
for j in range(y.shape[1]):
|
328
|
+
for k in range(y.shape[0]):
|
329
|
+
col[k] = y[k, j]
|
330
|
+
for i in range(x.shape[0]):
|
331
|
+
s = 0
|
332
|
+
for k in range(x.shape[1]):
|
333
|
+
s += x[i, k] * col[k]
|
334
|
+
res[i, j] = s
|
335
|
+
|
336
|
+
a = np.arange(50000.0).reshape(1000, 5, 10)
|
337
|
+
b = np.arange(70000.0).reshape(1000, 10, 7)
|
338
|
+
c = np.einsum("ijk,ikl->ijl", a, b)
|
339
|
+
|
340
|
+
x = xnd(a.tolist(), type="1000 * 5 * 10 * float64")
|
341
|
+
y = xnd(b.tolist(), type="1000 * 10 * 7 * float64")
|
342
|
+
z = matmul(x, y)
|
343
|
+
|
344
|
+
np.testing.assert_equal(z, c)
|
345
|
+
|
346
|
+
def test_numba_add_scalar(self):
|
347
|
+
|
348
|
+
import numba as nb
|
349
|
+
|
350
|
+
@nb.guvectorize(["void(int64[:], int64, int64[:])"], '(n),()->(n)')
|
351
|
+
def g(x, y, res):
|
352
|
+
for i in range(x.shape[0]):
|
353
|
+
res[i] = x[i] + y
|
354
|
+
|
355
|
+
a = np.arange(5000).reshape(100, 5, 10)
|
356
|
+
b = np.arange(500).reshape(100, 5)
|
357
|
+
c = g(a, b)
|
358
|
+
|
359
|
+
x = xnd(a.tolist(), type="100 * 5 * 10 * int64")
|
360
|
+
y = xnd(b.tolist(), type="100 * 5 * int64")
|
361
|
+
z = ex.add_scalar(x, y)
|
362
|
+
|
363
|
+
np.testing.assert_equal(z, c)
|
364
|
+
|
365
|
+
a = np.arange(500)
|
366
|
+
b = np.array(100)
|
367
|
+
c = g(a, b)
|
368
|
+
|
369
|
+
x = xnd(a.tolist(), type="500 * int64")
|
370
|
+
y = xnd(b.tolist(), type="int64")
|
371
|
+
z = ex.add_scalar(x, y)
|
372
|
+
|
373
|
+
np.testing.assert_equal(z, c)
|
374
|
+
|
375
|
+
|
376
|
+
|
377
|
+
ALL_TESTS = [
|
378
|
+
TestCall,
|
379
|
+
TestRaggedArrays,
|
380
|
+
TestMissingValues,
|
381
|
+
TestGraphs,
|
382
|
+
TestBFloat16,
|
383
|
+
TestPdist,
|
384
|
+
TestNumba,
|
385
|
+
]
|
386
|
+
|
387
|
+
|
388
|
+
if __name__ == '__main__':
    # Command line: the only option is -f/--failfast.
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--failfast", action="store_true",
                        help="stop the test run on first error")
    args = parser.parse_args()

    # Collect the tests of every class in ALL_TESTS, in list order.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in ALL_TESTS)

    runner = unittest.TextTestRunner(failfast=args.failfast, verbosity=2)
    result = runner.run(suite)

    # The exit status is the negated success flag: False on success,
    # True on failure.
    sys.exit(not result.wasSuccessful())
|