gumath 0.2.0dev5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CONTRIBUTING.md +61 -0
- data/Gemfile +5 -0
- data/History.md +0 -0
- data/README.md +5 -0
- data/Rakefile +105 -0
- data/ext/ruby_gumath/examples.c +126 -0
- data/ext/ruby_gumath/extconf.rb +97 -0
- data/ext/ruby_gumath/functions.c +106 -0
- data/ext/ruby_gumath/gufunc_object.c +79 -0
- data/ext/ruby_gumath/gufunc_object.h +55 -0
- data/ext/ruby_gumath/gumath/AUTHORS.txt +5 -0
- data/ext/ruby_gumath/gumath/INSTALL.txt +42 -0
- data/ext/ruby_gumath/gumath/LICENSE.txt +29 -0
- data/ext/ruby_gumath/gumath/MANIFEST.in +3 -0
- data/ext/ruby_gumath/gumath/Makefile.in +62 -0
- data/ext/ruby_gumath/gumath/README.rst +20 -0
- data/ext/ruby_gumath/gumath/config.guess +1530 -0
- data/ext/ruby_gumath/gumath/config.h.in +52 -0
- data/ext/ruby_gumath/gumath/config.sub +1782 -0
- data/ext/ruby_gumath/gumath/configure +5049 -0
- data/ext/ruby_gumath/gumath/configure.ac +167 -0
- data/ext/ruby_gumath/gumath/doc/_static/copybutton.js +66 -0
- data/ext/ruby_gumath/gumath/doc/conf.py +26 -0
- data/ext/ruby_gumath/gumath/doc/gumath/functions.rst +62 -0
- data/ext/ruby_gumath/gumath/doc/gumath/index.rst +26 -0
- data/ext/ruby_gumath/gumath/doc/index.rst +45 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/data-structures.rst +130 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/functions.rst +78 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/index.rst +25 -0
- data/ext/ruby_gumath/gumath/doc/libgumath/kernels.rst +41 -0
- data/ext/ruby_gumath/gumath/doc/releases/index.rst +11 -0
- data/ext/ruby_gumath/gumath/install-sh +527 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +170 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +160 -0
- data/ext/ruby_gumath/gumath/libgumath/apply.c +201 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +130 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/examples.c +176 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +393 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +140 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/quaternion.c +156 -0
- data/ext/ruby_gumath/gumath/libgumath/func.c +177 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +205 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +547 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +449 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.c +219 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.c +223 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +175 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +130 -0
- data/ext/ruby_gumath/gumath/python/extending.py +24 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +74 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +577 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +93 -0
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +77 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +95 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +405 -0
- data/ext/ruby_gumath/gumath/setup.py +298 -0
- data/ext/ruby_gumath/gumath/vcbuild/INSTALL.txt +36 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcbuild32.bat +21 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcbuild64.bat +21 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcclean.bat +10 -0
- data/ext/ruby_gumath/gumath/vcbuild/vcdistclean.bat +11 -0
- data/ext/ruby_gumath/include/gumath.h +205 -0
- data/ext/ruby_gumath/include/ruby_gumath.h +41 -0
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so +1 -0
- data/ext/ruby_gumath/lib/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +295 -0
- data/ext/ruby_gumath/ruby_gumath.h +41 -0
- data/ext/ruby_gumath/ruby_gumath_internal.h +45 -0
- data/ext/ruby_gumath/util.c +68 -0
- data/ext/ruby_gumath/util.h +48 -0
- data/gumath.gemspec +47 -0
- data/lib/gumath.rb +7 -0
- data/lib/gumath/version.rb +5 -0
- data/lib/ruby_gumath.so +0 -0
- metadata +206 -0
@@ -0,0 +1,93 @@
|
|
1
|
+
#include <Python.h>
|
2
|
+
#include "ndtypes.h"
|
3
|
+
#include "pyndtypes.h"
|
4
|
+
#include "gumath.h"
|
5
|
+
#include "pygumath.h"
|
6
|
+
|
7
|
+
|
8
|
+
/****************************************************************************/
|
9
|
+
/* Module globals */
|
10
|
+
/****************************************************************************/
|
11
|
+
|
12
|
+
/* Function table */
|
13
|
+
static gm_tbl_t *table = NULL;
|
14
|
+
|
15
|
+
|
16
|
+
/****************************************************************************/
|
17
|
+
/* Module */
|
18
|
+
/****************************************************************************/
|
19
|
+
|
20
|
+
static struct PyModuleDef examples_module = {
|
21
|
+
PyModuleDef_HEAD_INIT, /* m_base */
|
22
|
+
"examples", /* m_name */
|
23
|
+
NULL, /* m_doc */
|
24
|
+
-1, /* m_size */
|
25
|
+
NULL, /* m_methods */
|
26
|
+
NULL, /* m_slots */
|
27
|
+
NULL, /* m_traverse */
|
28
|
+
NULL, /* m_clear */
|
29
|
+
NULL /* m_free */
|
30
|
+
};
|
31
|
+
|
32
|
+
|
33
|
+
PyMODINIT_FUNC
|
34
|
+
PyInit_examples(void)
|
35
|
+
{
|
36
|
+
NDT_STATIC_CONTEXT(ctx);
|
37
|
+
PyObject *m = NULL;
|
38
|
+
static int initialized = 0;
|
39
|
+
|
40
|
+
if (!initialized) {
|
41
|
+
if (import_ndtypes() < 0) {
|
42
|
+
return NULL;
|
43
|
+
}
|
44
|
+
if (import_gumath() < 0) {
|
45
|
+
return NULL;
|
46
|
+
}
|
47
|
+
|
48
|
+
table = gm_tbl_new(&ctx);
|
49
|
+
if (table == NULL) {
|
50
|
+
return Ndt_SetError(&ctx);
|
51
|
+
}
|
52
|
+
|
53
|
+
/* custom examples */
|
54
|
+
if (gm_init_example_kernels(table, &ctx) < 0) {
|
55
|
+
return Ndt_SetError(&ctx);
|
56
|
+
}
|
57
|
+
|
58
|
+
/* extending examples */
|
59
|
+
#ifndef _MSC_VER
|
60
|
+
if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
|
61
|
+
return Ndt_SetError(&ctx);
|
62
|
+
}
|
63
|
+
#endif
|
64
|
+
if (gm_init_graph_kernels(table, &ctx) < 0) {
|
65
|
+
return Ndt_SetError(&ctx);
|
66
|
+
}
|
67
|
+
#ifndef _MSC_VER
|
68
|
+
if (gm_init_quaternion_kernels(table, &ctx) < 0) {
|
69
|
+
return Ndt_SetError(&ctx);
|
70
|
+
}
|
71
|
+
#endif
|
72
|
+
if (gm_init_pdist_kernels(table, &ctx) < 0) {
|
73
|
+
return Ndt_SetError(&ctx);
|
74
|
+
}
|
75
|
+
|
76
|
+
initialized = 1;
|
77
|
+
}
|
78
|
+
|
79
|
+
m = PyModule_Create(&examples_module);
|
80
|
+
if (m == NULL) {
|
81
|
+
goto error;
|
82
|
+
}
|
83
|
+
|
84
|
+
if (Gumath_AddFunctions(m, table) < 0) {
|
85
|
+
goto error;
|
86
|
+
}
|
87
|
+
|
88
|
+
return m;
|
89
|
+
|
90
|
+
error:
|
91
|
+
Py_CLEAR(m);
|
92
|
+
return NULL;
|
93
|
+
}
|
@@ -0,0 +1,77 @@
|
|
1
|
+
#include <Python.h>
|
2
|
+
#include "ndtypes.h"
|
3
|
+
#include "pyndtypes.h"
|
4
|
+
#include "gumath.h"
|
5
|
+
#include "pygumath.h"
|
6
|
+
|
7
|
+
|
8
|
+
/****************************************************************************/
|
9
|
+
/* Module globals */
|
10
|
+
/****************************************************************************/
|
11
|
+
|
12
|
+
/* Function table */
|
13
|
+
static gm_tbl_t *table = NULL;
|
14
|
+
|
15
|
+
|
16
|
+
/****************************************************************************/
|
17
|
+
/* Module */
|
18
|
+
/****************************************************************************/
|
19
|
+
|
20
|
+
static struct PyModuleDef functions_module = {
|
21
|
+
PyModuleDef_HEAD_INIT, /* m_base */
|
22
|
+
"functions", /* m_name */
|
23
|
+
NULL, /* m_doc */
|
24
|
+
-1, /* m_size */
|
25
|
+
NULL, /* m_methods */
|
26
|
+
NULL, /* m_slots */
|
27
|
+
NULL, /* m_traverse */
|
28
|
+
NULL, /* m_clear */
|
29
|
+
NULL /* m_free */
|
30
|
+
};
|
31
|
+
|
32
|
+
|
33
|
+
PyMODINIT_FUNC
|
34
|
+
PyInit_functions(void)
|
35
|
+
{
|
36
|
+
NDT_STATIC_CONTEXT(ctx);
|
37
|
+
PyObject *m = NULL;
|
38
|
+
static int initialized = 0;
|
39
|
+
|
40
|
+
if (!initialized) {
|
41
|
+
if (import_ndtypes() < 0) {
|
42
|
+
return NULL;
|
43
|
+
}
|
44
|
+
if (import_gumath() < 0) {
|
45
|
+
return NULL;
|
46
|
+
}
|
47
|
+
|
48
|
+
table = gm_tbl_new(&ctx);
|
49
|
+
if (table == NULL) {
|
50
|
+
return Ndt_SetError(&ctx);
|
51
|
+
}
|
52
|
+
|
53
|
+
if (gm_init_unary_kernels(table, &ctx) < 0) {
|
54
|
+
return Ndt_SetError(&ctx);
|
55
|
+
}
|
56
|
+
if (gm_init_binary_kernels(table, &ctx) < 0) {
|
57
|
+
return Ndt_SetError(&ctx);
|
58
|
+
}
|
59
|
+
|
60
|
+
initialized = 1;
|
61
|
+
}
|
62
|
+
|
63
|
+
m = PyModule_Create(&functions_module);
|
64
|
+
if (m == NULL) {
|
65
|
+
goto error;
|
66
|
+
}
|
67
|
+
|
68
|
+
if (Gumath_AddFunctions(m, table) < 0) {
|
69
|
+
goto error;
|
70
|
+
}
|
71
|
+
|
72
|
+
return m;
|
73
|
+
|
74
|
+
error:
|
75
|
+
Py_CLEAR(m);
|
76
|
+
return NULL;
|
77
|
+
}
|
@@ -0,0 +1,95 @@
|
|
1
|
+
/*
|
2
|
+
* BSD 3-Clause License
|
3
|
+
*
|
4
|
+
* Copyright (c) 2017-2018, plures
|
5
|
+
* All rights reserved.
|
6
|
+
*
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
9
|
+
*
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
* this list of conditions and the following disclaimer.
|
12
|
+
*
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
15
|
+
* and/or other materials provided with the distribution.
|
16
|
+
*
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
19
|
+
* this software without specific prior written permission.
|
20
|
+
*
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
*/
|
32
|
+
|
33
|
+
|
34
|
+
#ifndef PYGUMATH_H
|
35
|
+
#define PYGUMATH_H
|
36
|
+
#ifdef __cplusplus
|
37
|
+
extern "C" {
|
38
|
+
#endif
|
39
|
+
|
40
|
+
|
41
|
+
#include <Python.h>
|
42
|
+
#include "gumath.h"
|
43
|
+
|
44
|
+
|
45
|
+
/****************************************************************************/
|
46
|
+
/* Gufunc Object */
|
47
|
+
/****************************************************************************/
|
48
|
+
|
49
|
+
/* Exposed here for the benefit of Numba. The API should not be regarded
|
50
|
+
as stable across versions. */
|
51
|
+
|
52
|
+
typedef struct {
|
53
|
+
PyObject_HEAD
|
54
|
+
const gm_tbl_t *tbl; /* kernel table */
|
55
|
+
char *name; /* function name */
|
56
|
+
} GufuncObject;
|
57
|
+
|
58
|
+
|
59
|
+
/****************************************************************************/
|
60
|
+
/* Capsule API */
|
61
|
+
/****************************************************************************/
|
62
|
+
|
63
|
+
#define Gumath_AddFunctions_INDEX 0
|
64
|
+
#define Gumath_AddFunctions_RETURN int
|
65
|
+
#define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)
|
66
|
+
|
67
|
+
#define GUMATH_MAX_API 1
|
68
|
+
|
69
|
+
|
70
|
+
#ifdef GUMATH_MODULE
|
71
|
+
static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
|
72
|
+
#else
|
73
|
+
static void **_gumath_api;
|
74
|
+
|
75
|
+
#define Gumath_AddFunctions \
|
76
|
+
(*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
|
77
|
+
|
78
|
+
|
79
|
+
static int
|
80
|
+
import_gumath(void)
|
81
|
+
{
|
82
|
+
_gumath_api = (void **)PyCapsule_Import("gumath._gumath._API", 0);
|
83
|
+
if (_gumath_api == NULL) {
|
84
|
+
return -1;
|
85
|
+
}
|
86
|
+
|
87
|
+
return 0;
|
88
|
+
}
|
89
|
+
#endif
|
90
|
+
|
91
|
+
#ifdef __cplusplus
|
92
|
+
}
|
93
|
+
#endif
|
94
|
+
|
95
|
+
#endif /* PYGUMATH_H */
|
@@ -0,0 +1,405 @@
|
|
1
|
+
#
|
2
|
+
# BSD 3-Clause License
|
3
|
+
#
|
4
|
+
# Copyright (c) 2017-2018, plures
|
5
|
+
# All rights reserved.
|
6
|
+
#
|
7
|
+
# Redistribution and use in source and binary forms, with or without
|
8
|
+
# modification, are permitted provided that the following conditions are met:
|
9
|
+
#
|
10
|
+
# 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
# this list of conditions and the following disclaimer.
|
12
|
+
#
|
13
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
# this list of conditions and the following disclaimer in the documentation
|
15
|
+
# and/or other materials provided with the distribution.
|
16
|
+
#
|
17
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
# contributors may be used to endorse or promote products derived from
|
19
|
+
# this software without specific prior written permission.
|
20
|
+
#
|
21
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
#
|
32
|
+
|
33
|
+
import gumath as gm
|
34
|
+
import gumath.functions as fn
|
35
|
+
import gumath.examples as ex
|
36
|
+
from xnd import xnd
|
37
|
+
from ndtypes import ndt
|
38
|
+
from extending import Graph, bfloat16
|
39
|
+
import sys, time
|
40
|
+
import math
|
41
|
+
import unittest
|
42
|
+
import argparse
|
43
|
+
|
44
|
+
try:
|
45
|
+
import numpy as np
|
46
|
+
except ImportError:
|
47
|
+
np = None
|
48
|
+
|
49
|
+
|
50
|
+
# Shared inputs for TestCall: (nested list of values, xnd type string,
# matching NumPy dtype).  Each value layout appears once as float64 and once
# as float32: a flat vector, a 2 x 1000 matrix and a 1000 x 2 matrix.
TEST_CASES = [
    ([i / 100.0 for i in range(2000)], "2000 * float64", "float64"),

    ([[i / 100.0 for i in range(1000)], [i + 1.0 for i in range(1000)]],
     "2 * 1000 * float64", "float64"),

    ([[1.0, 2.0]] * 1000, "1000 * 2 * float64", "float64"),

    ([i / 10.0 for i in range(2000)], "2000 * float32", "float32"),

    ([[i / 10.0 for i in range(1000)], [i + 1.0 for i in range(1000)]],
     "2 * 1000 * float32", "float32"),

    ([[1.0, 2.0]] * 1000, "1000 * 2 * float32", "float32"),
]
|
65
|
+
|
66
|
+
|
67
|
+
class TestCall(unittest.TestCase):
    """Calls gumath kernels (gumath.functions and gumath.examples) and, when
    NumPy is installed, compares the results with NumPy's."""

    def test_sin_scalar(self):

        x1 = xnd(1.2, type="float64")
        y1 = fn.sin(x1)

        x2 = xnd(1.23e1, type="float32")
        y2 = fn.sin(x2)

        if np is not None:
            a1 = np.array(1.2, dtype="float64")
            b1 = np.sin(a1)

            a2 = np.array(1.23e1, dtype="float32")
            b2 = np.sin(a2)

            np.testing.assert_equal(y1.value, b1)
            np.testing.assert_equal(y2.value, b2)

    def test_sin(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.sin(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.sin(a)
                np.testing.assert_equal(y, b)

    def test_sin_strided(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            # Reversed step-2 slices on both axes give a non-contiguous view.
            y = x[::-2, ::-2]
            z = fn.sin(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.sin(b)
                np.testing.assert_equal(z, c)

    def test_copy(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.copy(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.copy(a)
                np.testing.assert_equal(y, b)

    def test_copy_strided(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            y = x[::-2, ::-2]
            z = fn.copy(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.copy(b)
                # Check the kernel's output (z) against NumPy's copy (c);
                # comparing the input view y with b never exercises fn.copy.
                np.testing.assert_equal(z, c)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion(self):
        # Three 2x2 complex matrices; the quaternion product is checked
        # against the batched complex matrix product computed by einsum.
        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        x = xnd(lst, type="3 * quaternion64")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex64")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        x = xnd(lst, type="3 * quaternion128")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex128")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        # A string argument has no matching multiply kernel.
        x = xnd("xyz")
        self.assertRaises(TypeError, ex.multiply, x, x)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion_error(self):
        # Same data under a different typedef name (Foo): must be rejected.
        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        x = xnd(lst, type="3 * Foo(2 * 2 * complex64)")
        self.assertRaises(TypeError, ex.multiply, x, x)

    def test_void(self):

        x = ex.randint()
        self.assertEqual(x.type, ndt("int32"))

    def test_multiple_return(self):

        x, y = ex.randtuple()
        self.assertEqual(x.type, ndt("int32"))
        self.assertEqual(y.type, ndt("int32"))

        x, y = ex.divmod10(xnd(233))
        self.assertEqual(x.value, 23)
        self.assertEqual(y.value, 3)
|
197
|
+
|
198
|
+
|
199
|
+
class TestMissingValues(unittest.TestCase):
    """Kernels over records whose `value` field is optional (`?int64`)."""

    def test_missing_values(self):

        def row(index, name, value):
            # Keys in the same order as the record type's fields.
            return {'index': index, 'name': name, 'value': value}

        first = [row(0, 'brazil', 10),
                 row(1, 'france', None),
                 row(1, 'russia', 2)]

        second = [row(0, 'iceland', 5),
                  row(1, 'norway', None),
                  row(1, 'italy', None)]

        records = xnd([first, second],
                      type="2 * 3 * {index: int64, name: string, value: ?int64}")
        counts = ex.count_valid_missing(records)

        # One {valid, missing} summary per group of three records.
        expected = [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}]
        self.assertEqual(counts.value, expected)
|
215
|
+
|
216
|
+
|
217
|
+
class TestRaggedArrays(unittest.TestCase):
    """Element-wise kernels applied to nested lists of unequal lengths
    (presumably typed by xnd as variable dimensions)."""

    def test_sin(self):
        data = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
                [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]

        # Same ragged shape, with math.sin applied to every leaf.
        expected = [[[math.sin(v) for v in inner] for inner in middle]
                    for middle in data]

        result = fn.sin(xnd(data))
        self.assertEqual(result.value, expected)
|
238
|
+
|
239
|
+
|
240
|
+
class TestGraphs(unittest.TestCase):
    """Shortest-path kernel of the `Graph` type from the extending examples."""

    def test_shortest_path(self):
        # Adjacency lists: entry i lists the outgoing edges of node i,
        # presumably as (target node, edge weight) pairs.
        graphs = [[[(1, 1.2), (2, 4.4)],
                   [(2, 2.2)],
                   [(1, 2.3)]],

                  [[(1, 1.2), (2, 4.4)],
                   [(2, 2.2)],
                   [(1, 2.3)],
                   [(2, 1.1)]]]

        # expected[g][start][target] is the node path from start to target;
        # an empty list means target is unreachable.
        expected = [[[[0], [0, 1], [0, 1, 2]],           # graph1, start 0
                     [[], [1], [1, 2]],                  # graph1, start 1
                     [[], [2, 1], [2]]],                 # graph1, start 2

                    [[[0], [0, 1], [0, 1, 2], []],       # graph2, start 0
                     [[], [1], [1, 2], []],              # graph2, start 1
                     [[], [2, 1], [2], []],              # graph2, start 2
                     [[], [3, 2, 1], [3, 2], [3]]]]      # graph2, start 3

        for adjacency, paths_by_start in zip(graphs, expected):
            graph = Graph(adjacency)
            for start in range(len(adjacency)):
                node = xnd(start, type="node")
                result = graph.shortest_paths(node)
                self.assertEqual(result.value, paths_by_start[start])

    def test_constraint(self):
        # A two-node graph with an edge to node 2, which does not exist.
        invalid = [[(0, 1.2)],
                   [(2, 2.2), (1, 0.1)]]

        self.assertRaises(ValueError, Graph, invalid)
|
275
|
+
|
276
|
+
|
277
|
+
@unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
class TestBFloat16(unittest.TestCase):
    """Construction of the `bfloat16` container from the extending examples."""

    def test_init(self):
        # The stored values are lossy: each input comes back rounded to
        # bfloat16 precision.
        values = [1.2e10, 2.1121, -3e20]
        rounded = [11945377792.0, 2.109375, -2.997595911977802e+20]

        self.assertEqual(bfloat16(values).value, rounded)
|
286
|
+
|
287
|
+
|
288
|
+
class TestPdist(unittest.TestCase):
    """Pairwise Euclidean distances (`euclidian_pdist` example kernel)."""

    def test_exceptions(self):
        # Inputs without a matching kernel signature, checked in order.
        invalid = [
            ([], "float64"),
            ([[]], "float64"),
            ([[], []], "float64"),
            ([[1], [1]], "int64"),
        ]
        for data, dtype in invalid:
            x = xnd(data, dtype=dtype)
            self.assertRaises(TypeError, ex.euclidian_pdist, x)

    def test_pdist(self):
        # A single row has no pairs, so the result is empty.
        for single_row in ([[1]], [[1, 2, 3]]):
            y = ex.euclidian_pdist(xnd(single_row, dtype="float64"))
            self.assertEqual(y.value, [])

        points = xnd([[-1.2200, -100.5000,   20.1250,   30.1230],
                      [ 2.2200,    2.2720, -122.8400,  122.3330],
                      [ 2.1000,  -25.0000,  100.2000,  -99.5000]],
                     dtype="float64")
        distances = ex.euclidian_pdist(points)
        self.assertEqual(distances.value,
                         [198.78529349275314, 170.0746899276903, 315.75385646576035])
|
317
|
+
|
318
|
+
|
319
|
+
@unittest.skipIf(gm.xndvectorize is None, "test requires numpy and numba")
class TestNumba(unittest.TestCase):
    """Kernels compiled with Numba and dispatched through gumath, checked
    against NumPy reference results."""

    def test_numba(self):

        @gm.xndvectorize("... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64")
        def matmul(x, y, res):
            # Scratch copy of column j of y.  It must be float64 (np.zeros'
            # default) to match the kernel signature: an integer buffer such
            # as np.arange(...) silently truncates non-integral inputs.
            col = np.zeros(y.shape[0])
            for j in range(y.shape[1]):
                for k in range(y.shape[0]):
                    col[k] = y[k, j]
                for i in range(x.shape[0]):
                    s = 0.0
                    for k in range(x.shape[1]):
                        s += x[i, k] * col[k]
                    res[i, j] = s

        # The +0.5 offset makes every input non-integral, so a truncating
        # scratch buffer is detected.  All products and partial sums are
        # multiples of 0.25 far below 2**53, so they are computed exactly in
        # any summation order and assert_equal against einsum stays valid.
        a = np.arange(50000.0).reshape(1000, 5, 10) + 0.5
        b = np.arange(70000.0).reshape(1000, 10, 7) + 0.5
        c = np.einsum("ijk,ikl->ijl", a, b)

        x = xnd(a.tolist(), type="1000 * 5 * 10 * float64")
        y = xnd(b.tolist(), type="1000 * 10 * 7 * float64")
        z = matmul(x, y)

        np.testing.assert_equal(z, c)

    def test_numba_add_scalar(self):

        import numba as nb

        # NumPy/Numba reference implementation of the add_scalar kernel.
        @nb.guvectorize(["void(int64[:], int64, int64[:])"], '(n),()->(n)')
        def g(x, y, res):
            for i in range(x.shape[0]):
                res[i] = x[i] + y

        # Broadcast case: one scalar per row of 10 values.
        a = np.arange(5000).reshape(100, 5, 10)
        b = np.arange(500).reshape(100, 5)
        c = g(a, b)

        x = xnd(a.tolist(), type="100 * 5 * 10 * int64")
        y = xnd(b.tolist(), type="100 * 5 * int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)

        # Plain case: a single vector and a 0-d scalar.
        a = np.arange(500)
        b = np.array(100)
        c = g(a, b)

        x = xnd(a.tolist(), type="500 * int64")
        y = xnd(b.tolist(), type="int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)
|
374
|
+
|
375
|
+
|
376
|
+
|
377
|
+
# TestCase classes run, in this order, when the file is executed as a script.
ALL_TESTS = [
    TestCall, TestRaggedArrays, TestMissingValues, TestGraphs,
    TestBFloat16, TestPdist, TestNumba,
]
|
386
|
+
|
387
|
+
|
388
|
+
if __name__ == '__main__':
    # Run every class in ALL_TESTS verbosely; the exit status is non-zero
    # when any test fails or errors.
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--failfast", action="store_true",
                        help="stop the test run on first error")
    options = parser.parse_args()

    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in ALL_TESTS
    )

    runner = unittest.TextTestRunner(failfast=options.failfast, verbosity=2)
    outcome = runner.run(suite)

    sys.exit(not outcome.wasSuccessful())
|