symjit 1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symjit-1.2/Cargo.toml +28 -0
- symjit-1.2/LICENSE +21 -0
- symjit-1.2/MANIFEST.in +3 -0
- symjit-1.2/PKG-INFO +4 -0
- symjit-1.2/README.md +218 -0
- symjit-1.2/pyproject.toml +23 -0
- symjit-1.2/python/symjit/__init__.py +172 -0
- symjit-1.2/python/symjit/engine.py +124 -0
- symjit-1.2/python/symjit/structure.py +237 -0
- symjit-1.2/python/symjit.egg-info/PKG-INFO +4 -0
- symjit-1.2/python/symjit.egg-info/SOURCES.txt +28 -0
- symjit-1.2/python/symjit.egg-info/dependency_links.txt +1 -0
- symjit-1.2/python/symjit.egg-info/top_level.txt +1 -0
- symjit-1.2/rust/allocator.rs +115 -0
- symjit-1.2/rust/amd/macros.rs +322 -0
- symjit-1.2/rust/amd/mod.rs +239 -0
- symjit-1.2/rust/analyzer.rs +153 -0
- symjit-1.2/rust/arm/macros.rs +331 -0
- symjit-1.2/rust/arm/mod.rs +215 -0
- symjit-1.2/rust/code.rs +290 -0
- symjit-1.2/rust/interpreter/mod.rs +124 -0
- symjit-1.2/rust/lib.rs +283 -0
- symjit-1.2/rust/machine.rs +76 -0
- symjit-1.2/rust/memory.rs +295 -0
- symjit-1.2/rust/model.rs +412 -0
- symjit-1.2/rust/register.rs +219 -0
- symjit-1.2/rust/runnable.rs +120 -0
- symjit-1.2/rust/utils.rs +26 -0
- symjit-1.2/rust/wasm/mod.rs +246 -0
- symjit-1.2/setup.cfg +4 -0
symjit-1.2/Cargo.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[package]
|
|
2
|
+
name = "symjit"
|
|
3
|
+
version = "0.3.0"
|
|
4
|
+
edition = "2021"
|
|
5
|
+
|
|
6
|
+
[dependencies]
|
|
7
|
+
serde = { version = "1.0", features = ["derive"] }
|
|
8
|
+
serde_json = "1.0"
|
|
9
|
+
memmap2 = { version = "0.9", optional = true }
|
|
10
|
+
rand = "0.8"
|
|
11
|
+
anyhow = "1"
|
|
12
|
+
wasmtime = { version = "28.0", optional = true }
|
|
13
|
+
cffi = "*"
|
|
14
|
+
libc = "0.2"
|
|
15
|
+
windows-sys = { version = "0.59", features = ["Win32_System_Memory"] }
|
|
16
|
+
region = "3.0.2"
|
|
17
|
+
wasmtime-jit-icache-coherence = "*"
|
|
18
|
+
|
|
19
|
+
[features]
|
|
20
|
+
wasm = ["dep:wasmtime"]
|
|
21
|
+
selinux-fix = ['memmap2']
|
|
22
|
+
default = []
|
|
23
|
+
|
|
24
|
+
[lib]
|
|
25
|
+
name = "_lib"
|
|
26
|
+
path = "rust/lib.rs"
|
|
27
|
+
crate-type = ["cdylib"]
|
|
28
|
+
|
symjit-1.2/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 siravan
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
symjit-1.2/MANIFEST.in
ADDED
symjit-1.2/PKG-INFO
ADDED
symjit-1.2/README.md
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
# Introduction
|
|
2
|
+
|
|
3
|
+
*Symjit* is a lightweight just-in-time (JIT) compiler that directly translates *sympy* expressions into machine code. Its main utility is to generate fast numerical functions to feed into different numerical solvers, including numerical integration routines and ordinary differential equation (ODE) solvers.
|
|
4
|
+
|
|
5
|
+
The Symjit core is a Rust library with minimum external dependency. It does not use a separate compiler, such as LLVM or GCC. Currently, it can generate AMD64 (aka x86-64) and ARM64 (aka aarch64) machine codes on Linux and Windows platforms. Further architectures and operating systems (RISC V, aarch64 on Mac OS) are planned.
|
|
6
|
+
|
|
7
|
+
# Installation
|
|
8
|
+
|
|
9
|
+
You can install *symjit* using pip command as
|
|
10
|
+
|
|
11
|
+
```
|
|
12
|
+
pip install symjit
|
|
13
|
+
```
|
|
14
|
+
or from the source by cloning https://github.com/siravan/symjit into `symjit` folder and then running
|
|
15
|
+
|
|
16
|
+
```
|
|
17
|
+
pip install .
|
|
18
|
+
```
|
|
19
|
+
For the last option, you need a working Rust compiler and toolchains.
|
|
20
|
+
|
|
21
|
+
# Tutorial
|
|
22
|
+
|
|
23
|
+
## `compile_func`: a fast substitute for `lambdify`
|
|
24
|
+
|
|
25
|
+
*symjit* is invoked by calling different `compile_*` functions. The most basic is `compile_func`, which behaves similarly to the sympy `lambdify` function. While `lambdify` translates sympy expressions into regular Python functions, which in turn call numpy functions, `compile_func` returns a callable object `Func`, which is a thin wrapper over the jit code generated by the Rust backend.
|
|
26
|
+
|
|
27
|
+
A simple example is
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
import numpy as np
|
|
31
|
+
from symjit import compile_func
|
|
32
|
+
from sympy import symbols
|
|
33
|
+
|
|
34
|
+
x, y = symbols('x y')
|
|
35
|
+
f = compile_func([x, y], [x+y, x*y])
|
|
36
|
+
assert(np.all(f([3, 5]) == [8., 15.]))
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
`compile_func` takes two mandatory arguments as `compile_func(states, eqs)`. The first one, `states`, is a list or tuple of symbols. The second argument, `eqs`, is a list or tuple of expressions. If either `states` or `eqs` has only one element, that element can be passed directly. In addition, `compile_func` accepts a named argument `params`, which is a list of symbolic parameters. For example,
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
x, y, a = symbols('x y a')
|
|
43
|
+
f = compile_func([x, y], [(x+y)**a], params=[a])
|
|
44
|
+
assert(np.all(f([3, 5], args=[2]) == [64.]))
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
`compile_func` helps generate functions to pass to numerical integration (quadrature) routines. The following example is adapted from scipy documentation:
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
import numpy as np
|
|
51
|
+
from scipy.integrate import nquad
|
|
52
|
+
from sympy import symbols, exp
|
|
53
|
+
from symjit import compile_func
|
|
54
|
+
|
|
55
|
+
N = 5
|
|
56
|
+
t, x = symbols("t x")
|
|
57
|
+
f = compile_func([t, x], exp(-t*x)/t**N)
|
|
58
|
+
|
|
59
|
+
sol = nquad(f, [[1, np.inf], [0, np.inf]])
|
|
60
|
+
|
|
61
|
+
np.testing.assert_approx_equal(sol[0], 1/N)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## `compile_ode`: to solve ODEs
|
|
65
|
+
|
|
66
|
+
`compile_ode` returns a callable object (`OdeFunc`) suitable for passing to `scipy.integrate.solve_ivp` (the main numpy/scipy ODE solver). It takes three mandatory arguments as `compile_ode(iv, states, odes)`. The first one (`iv`) is a single symbol that specifies the independent variable. The second argument, `states`, is a list of symbols defining the ODE state. The right-hand side of ODE equations is passed as the third argument, `odes`. It is a list of expressions that define the ODE by providing the derivative of each state variable w.r.t. the independent variable. In addition, similar to `compile_func`, `compile_ode` can accept an optional `params` argument. For example,
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
# examples/trig.py
|
|
70
|
+
import scipy.integrate
|
|
71
|
+
import matplotlib.pyplot as plt
|
|
72
|
+
import numpy as np
|
|
73
|
+
from sympy import symbols
|
|
74
|
+
from symjit import compile_ode
|
|
75
|
+
|
|
76
|
+
t, x, y = symbols('t x y')
|
|
77
|
+
f = compile_ode(t, (x, y), (y, -x))
|
|
78
|
+
t_eval=np.arange(0, 10, 0.01)
|
|
79
|
+
sol = scipy.integrate.solve_ivp(f, (0, 10), (0.0, 1.0), t_eval=t_eval)
|
|
80
|
+
|
|
81
|
+
plt.plot(t_eval, sol.y.T)
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
Here, the ODE definition is `x' = y` and `y' = -x`, which means `y" = -y`. The solution is `a*sin(t) + b*cos(t)`, where `a` and `b` are determined by the initial values. Given the initial values of 0 and 1 passed as the third argument of `solve_ivp`, the solutions are `sin(t)` and `cos(t)`. We can confirm this by running the code. The output is
|
|
85
|
+
|
|
86
|
+

|
|
87
|
+
|
|
88
|
+
Note that `OdeFunc` conforms to the function form `scipy.integrate.solve_ivp` expects, i.e., it should be called as `f(t, y, *args)`.
|
|
89
|
+
|
|
90
|
+
The following example is more complicated and showcases the [Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system), an important milestone in the historical development of chaos theory.
|
|
91
|
+
|
|
92
|
+
```python
|
|
93
|
+
import numpy as np
|
|
94
|
+
from scipy.integrate import solve_ivp
|
|
95
|
+
import matplotlib.pyplot as plt
|
|
96
|
+
from sympy import symbols
|
|
97
|
+
|
|
98
|
+
from symjit import compile_ode
|
|
99
|
+
|
|
100
|
+
t, x, y, z = symbols("t x y z")
|
|
101
|
+
sigma, rho, beta = symbols("sigma rho beta")
|
|
102
|
+
|
|
103
|
+
ode = (
|
|
104
|
+
sigma * (y - x),
|
|
105
|
+
x * (rho - z) - y,
|
|
106
|
+
x * y - beta * z
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
f = compile_ode(t, (x, y, z), ode, params=(sigma, rho, beta))
|
|
110
|
+
|
|
111
|
+
u0 = (1.0, 1.0, 1.0)
|
|
112
|
+
p = (10.0, 28.0, 8 / 3)
|
|
113
|
+
t_eval = np.arange(0, 100, 0.01)
|
|
114
|
+
|
|
115
|
+
sol = solve_ivp(f, (0, 100.0), u0, t_eval=t_eval, args=p)
|
|
116
|
+
|
|
117
|
+
plt.plot(sol.y[0, :], sol.y[2, :])
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
The result is the famous *strange attractor*:
|
|
121
|
+
|
|
122
|
+

|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
## `compile_jac`: calculating Jacobian
|
|
126
|
+
|
|
127
|
+
The ODE examples discussed in the previous section are non-stiff and easy to solve using explicit methods. However, not all differential equations are so accommodating! Many important equations are stiff and usually require implicit methods. Many implicit ODE solvers use the system's [Jacobian matrix](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) to improve performance.
|
|
128
|
+
|
|
129
|
+
There are different techniques for calculating the Jacobian. In the last few years, automatic differentiation (AD) methods have gained popularity, working at the abstract syntax tree or lower level. However, if we define our model symbolically using a Computer Algebra System (CAS) such as sympy, we can calculate the Jacobian by differentiating the source symbolic expressions.
|
|
130
|
+
|
|
131
|
+
`compile_jac` is the symjit function to calculate the Jacobian of an ODE system. It has the same call signature as `compile_ode`, i.e., it is called `compile_jac(iv, states, odes)` with an optional argument `params`. The return value (of type `JacFunc`) is a callable similar to `OdeFunc`, which returns an n-by-n matrix J, where n is the number of states. The element at the ith row and jth column of J is the derivative of `odes[i]` w.r.t. `states[j]` (this is the definition of the Jacobian).
|
|
132
|
+
|
|
133
|
+
For example, we can consider the [Van der Pol oscillator](https://en.wikipedia.org/wiki/Van_der_Pol_oscillator). This system has a control parameter (mu). For small values of mu, the ODE system is not stiff and can easily be solved using explicit methods.
|
|
134
|
+
|
|
135
|
+
```python
|
|
136
|
+
import matplotlib.pyplot as plt
|
|
137
|
+
import numpy as np
|
|
138
|
+
from scipy.integrate import solve_ivp
|
|
139
|
+
from sympy import symbols
|
|
140
|
+
from math import sqrt
|
|
141
|
+
from symjit import compile_ode, compile_jac
|
|
142
|
+
|
|
143
|
+
t, x, y, mu = symbols('t x y mu')
|
|
144
|
+
ode = [y, mu * ((1 - x*x) * y - x)]
|
|
145
|
+
|
|
146
|
+
f = compile_ode(t, [x, y], ode, params=[mu])
|
|
147
|
+
u0 = [0.0, sqrt(3.0)]
|
|
148
|
+
t_eval = np.arange(0, 10.0, 0.01)
|
|
149
|
+
|
|
150
|
+
sol1 = solve_ivp(f, (0, 10.0), u0, method='RK45', t_eval=t_eval, args=[5.0])
|
|
151
|
+
|
|
152
|
+
plt.plot(t_eval, sol1.y[0,:])
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
The output is
|
|
156
|
+
|
|
157
|
+

|
|
158
|
+
|
|
159
|
+
On the other hand, as mu is increased (for example, to 1e6), the system becomes very stiff. An explicit ODE solver, such as RK45 (Runge-Kutta 4/5), cannot solve this problem. Instead, we need an implicit method, such as the backward differentiation formula (BDF). BDF needs a Jacobian. If one is not provided, it numerically calculates one using finite differences. However, this technique is both inaccurate and computationally intensive. It would be much better to give the solver a closed-form Jacobian. As mentioned above, `compile_jac` does exactly this.
|
|
160
|
+
|
|
161
|
+
```python
|
|
162
|
+
jac = compile_jac(t, [x, y], ode, params=[mu])
|
|
163
|
+
sol2 = solve_ivp(f, (0, 10.0), u0, method='BDF', t_eval=t_eval, args=[1e6], jac=jac)
|
|
164
|
+
|
|
165
|
+
plt.plot(t_eval, sol2.y[0,:])
|
|
166
|
+
```
|
|
167
|
+
|
|
168
|
+
The output of the stiff system is
|
|
169
|
+
|
|
170
|
+

|
|
171
|
+
|
|
172
|
+
# Backends
|
|
173
|
+
|
|
174
|
+
All `compile_*` functions accept an optional parameter `ty`, which defines the type of the backend to use. Currently, the possible values are:
|
|
175
|
+
|
|
176
|
+
* `amd`: generates 64-bit AMD64/x86-64 code. It expects a minimum SSE2.1 spec, which should be easily fulfilled by all except the most ancient processors!
|
|
177
|
+
* `arm` generates 64-bit ARM64/aarch64 code. To test this instruction set, we use 64-bit Raspbian on Raspberry Pi 4/5 computers.
|
|
178
|
+
* `bytecode`: this is a generic fast bytecode as a fallback option in case the instruction set is not supported.
|
|
179
|
+
* `native` (**default**): selects the correct instruction set based on the current processor.
|
|
180
|
+
* `wasm` (**optional**): generates and runs WebAssembly code using the wasmtime library. This option is not included in the binary distributions. To enable wasm, you must make `symjit` from the source and add `features=["wasm"]` to `pyproject.toml`.
|
|
181
|
+
|
|
182
|
+
To inspect the generated code, we must first dump the binary into a file by calling the `dump` function of various `Func` callables. The resulting file is a flat binary code with no header or other extras. Then, use the disassembler of your choice to inspect the code. For example,
|
|
183
|
+
|
|
184
|
+
```python
|
|
185
|
+
from symjit import compile_func
|
|
186
|
+
from sympy import symbols
|
|
187
|
+
|
|
188
|
+
x, y = symbols('x y')
|
|
189
|
+
f = compile_func([x, y], [x+y, x*y])
|
|
190
|
+
f.dump('test.bin')
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
On a Linux system, we can invoke `objdump` as below:
|
|
194
|
+
|
|
195
|
+
```
|
|
196
|
+
objdump -b binary -m i386:x86-64 -M intel -D test.bin
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
The output (assuming a Linux x86-64 machine) is
|
|
200
|
+
|
|
201
|
+
```
|
|
202
|
+
0: 55 push rbp
|
|
203
|
+
1: 53 push rbx
|
|
204
|
+
2: 48 89 fd mov rbp,rdi
|
|
205
|
+
5: 48 89 d3 mov rbx,rdx
|
|
206
|
+
8: f2 0f 10 4d 30 movsd xmm1,QWORD PTR [rbp+0x30]
|
|
207
|
+
d: f2 0f 10 45 28 movsd xmm0,QWORD PTR [rbp+0x28]
|
|
208
|
+
12: f2 0f 58 c1 addsd xmm0,xmm1
|
|
209
|
+
16: f2 0f 11 45 38 movsd QWORD PTR [rbp+0x38],xmm0
|
|
210
|
+
1b: f2 0f 10 4d 30 movsd xmm1,QWORD PTR [rbp+0x30]
|
|
211
|
+
20: f2 0f 10 45 28 movsd xmm0,QWORD PTR [rbp+0x28]
|
|
212
|
+
25: f2 0f 59 c1 mulsd xmm0,xmm1
|
|
213
|
+
29: f2 0f 11 45 40 movsd QWORD PTR [rbp+0x40],xmm0
|
|
214
|
+
2e: 5b pop rbx
|
|
215
|
+
2f: 5d pop rbp
|
|
216
|
+
30: c3 ret
|
|
217
|
+
```
|
|
218
|
+
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# pyproject.toml
|
|
2
|
+
[build-system]
|
|
3
|
+
requires = ["setuptools", "setuptools-rust"]
|
|
4
|
+
build-backend = "setuptools.build_meta"
|
|
5
|
+
|
|
6
|
+
[project]
|
|
7
|
+
name = "symjit"
|
|
8
|
+
version = "1.2"
|
|
9
|
+
|
|
10
|
+
[tool.setuptools.packages]
|
|
11
|
+
# Pure Python packages/modules
|
|
12
|
+
find = { where = ["python"] }
|
|
13
|
+
|
|
14
|
+
[[tool.setuptools-rust.ext-modules]]
|
|
15
|
+
# Private Rust extension module to be nested into the Python package
|
|
16
|
+
target = "symjit._lib" # The last part of the name (e.g. "_lib") has to match lib.name in Cargo.toml,
|
|
17
|
+
# but you can add a prefix to nest it inside of a Python package.
|
|
18
|
+
path = "Cargo.toml"
|
|
19
|
+
binding = "NoBinding"
|
|
20
|
+
features = []
|
|
21
|
+
debug = false
|
|
22
|
+
|
|
23
|
+
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from . import engine
|
|
5
|
+
from . import structure
|
|
6
|
+
|
|
7
|
+
lib = engine.Engine() # interface to the rust codegen engine
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def from_raw_parts(ptr, count):
    """Expose `count` elements behind the ctypes pointer `ptr` as a numpy array.

    The conversion is delegated to ``np.ctypeslib.as_array``; the result is
    one-dimensional with shape ``(count,)``.
    """
    shape = (count,)
    return np.ctypeslib.as_array(ptr, shape=shape)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseFunc:
    """Shared base of the callable wrappers around a compiled model.

    An instance holds the handle returned by ``lib._compile`` in ``self.p``
    together with the element counts and numpy views of the engine's
    state, parameter, observable and derivative buffers.
    """

    def __init__(self, model, ty="native"):
        """Compile `model` with the backend named by `ty`.

        model: a string describing the model (the ``compile_*`` helpers pass
            the output of ``json.dumps``)
        ty: backend name, forwarded to the engine

        Raises ValueError carrying the engine's status message when the
        reported status is anything other than ``b"Success"``.
        """
        handle = lib._compile(model.encode("utf-8"), ty.encode("utf8"))
        self.p = handle

        outcome = lib._check_status(handle)
        if outcome != b"Success":
            raise ValueError(outcome)

        self.populate()
        self.model = model  # for debugging

    def __del__(self):
        # The lib._finalize(self.p) call is disabled in this version, so
        # nothing is released when the wrapper is garbage-collected.
        pass

    def get_u0(self):
        """Return a fresh array of ``count_states`` doubles filled by ``lib._fill_u0``."""
        n = self.count_states
        values = np.zeros(n, dtype="double")
        lib._fill_u0(self.p, np.ctypeslib.as_ctypes(values), n)
        return values

    def get_p(self):
        """Return a fresh array of ``count_params`` doubles filled by ``lib._fill_p``."""
        n = self.count_params
        values = np.zeros(n, dtype="double")
        lib._fill_p(self.p, np.ctypeslib.as_ctypes(values), n)
        return values

    def populate(self):
        """Query the engine for the buffer sizes and wrap its buffers as arrays."""
        handle = self.p

        # Element counts first ...
        self.count_states = lib._count_states(handle)
        self.count_params = lib._count_params(handle)
        self.count_obs = lib._count_obs(handle)
        self.count_diffs = lib._count_diffs(handle)

        # ... then numpy arrays built from the pointers the engine reports.
        self._states = from_raw_parts(
            lib._ptr_states(handle), self.count_states
        )
        self._params = from_raw_parts(
            lib._ptr_params(handle), self.count_params
        )
        self._obs = from_raw_parts(lib._ptr_obs(handle), self.count_obs)
        self._diffs = from_raw_parts(
            lib._ptr_diffs(handle), self.count_diffs
        )

    def dump(self, name):
        """Ask the engine to dump the compiled code to the file called `name`."""
        encoded = name.encode("utf-8")
        lib._dump(self.p, encoded)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Func(BaseFunc):
    """Callable wrapper that evaluates the compiled expressions.

    Calling the object copies the positional arguments into the engine's
    state buffer, runs the compiled code and returns a copy of the
    observable buffer.
    """

    def __init__(self, model, ty="native"):
        super().__init__(model, ty=ty)

    def __call__(self, *args):
        """Evaluate the model for the state values given in `args`.

        Raises ValueError when the engine reports that execution failed.
        """
        self._states[:] = np.array(args, dtype="double")

        # The independent variable is fixed at 0.0 for plain functions.
        if lib._execute(self.p, 0.0):
            return self._obs.copy()

        raise ValueError("cannot execute the model")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class OdeFunc(BaseFunc):
    """Callable wrapper with the ``f(t, y, *args)`` call form.

    Calling the object loads `y` into the state buffer (and `args`, when
    given, into the parameter buffer), runs the compiled code at time `t`
    and returns a copy of the derivative buffer.
    """

    def __init__(self, model, ty="native"):
        super().__init__(model, ty=ty)

    def __call__(self, t, y, *args):
        """Return the derivatives for state `y` at time `t`.

        When no extra positional arguments are passed, the parameter buffer
        is left as it is.

        Raises ValueError when the engine reports that execution failed.
        """
        self._states[:] = np.array(y, dtype="double")

        if args:
            self._params[:] = np.array(args, dtype="double")

        if not lib._execute(self.p, t):
            raise ValueError("cannot execute the model")

        return self._diffs.copy()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class JacFunc(BaseFunc):
    """Callable wrapper with the ``jac(t, y, *args)`` call form.

    Calling the object loads `y` into the state buffer (and `args`, when
    given, into the parameter buffer), runs the compiled code at time `t`
    and returns the observable buffer reshaped to a square
    ``count_states`` x ``count_states`` array.
    """

    def __init__(self, model, ty="native"):
        super().__init__(model, ty=ty)

    def __call__(self, t, y, *args):
        """Return the square matrix evaluated for state `y` at time `t`.

        When no extra positional arguments are passed, the parameter buffer
        is left as it is.

        Raises ValueError when the engine reports that execution failed.
        """
        self._states[:] = np.array(y, dtype="double")

        if args:
            self._params[:] = np.array(args, dtype="double")

        if not lib._execute(self.p, t):
            raise ValueError("cannot execute the model")

        n = self.count_states
        return self._obs.copy().reshape((n, n))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def compile_func(states, eqs, params=None, ty="native"):
    """Compile a list of symbolic expressions into an executable form.

    compile_func tries to mimic sympy lambdify, but instead of generating
    a standard python function, it returns a callable (Func object) that
    is a thin wrapper over compiled machine-code.

    Parameters
    ==========

    states: a single symbol or a list/tuple of symbols
    eqs: a single symbolic expression or a list/tuple of symbolic expressions
    params (optional): a list/tuple of additional symbols as parameters to the model
    ty (optional): name of the backend, forwarded unchanged to Func
        (default "native")

    ==> returns a Func object

    >>> import numpy as np
    >>> from symjit import compile_func
    >>> from sympy import symbols

    >>> x, y = symbols('x y')
    >>> f = compile_func([x, y], [x+y, x*y])
    >>> assert(np.all(f([3, 5]) == [8., 15.]))
    """
    model = structure.model(states, eqs, params)
    # Func expects the model as a JSON string.
    return Func(json.dumps(model), ty=ty)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def compile_ode(iv, states, odes, params=None, ty="native"):
    """Compile a symbolic ODE model into an executable form suitable for
    passing to scipy.integrate.solve_ivp.

    Parameters
    ==========

    iv: a single symbol, the independent variable.
    states: a single symbol or a list/tuple of symbols
    odes: a single symbolic expression or a list/tuple of symbolic expressions,
        representing the derivative of the state with respect to iv
    params (optional): a list/tuple of additional symbols as parameters to the model
    ty (optional): name of the backend, forwarded unchanged to OdeFunc
        (default "native")

    invariant => len(states) == len(odes)

    ==> returns an OdeFunc object

    >>> import scipy.integrate
    >>> import numpy as np
    >>> from sympy import symbols
    >>> from symjit import compile_ode

    >>> t, x, y = symbols('t x y')
    >>> f = compile_ode(t, (x, y), (y, -x))
    >>> t_eval=np.arange(0, 10, 0.01)
    >>> sol = scipy.integrate.solve_ivp(f, (0, 10), (0.0, 1.0), t_eval=t_eval)

    >>> np.testing.assert_allclose(sol.y[0,:], np.sin(t_eval), atol=0.005)
    """
    model = structure.model_ode(iv, states, odes, params)
    # OdeFunc expects the model as a JSON string.
    return OdeFunc(json.dumps(model), ty=ty)
|
|
166
|
+
|
|
167
|
+
def compile_jac(iv, states, odes, params=None, ty="native"):
    """Compile the Jacobian of a symbolic ODE model into an executable form.

    The call signature mirrors compile_ode.

    Parameters
    ==========

    iv: a single symbol, the independent variable.
    states: a single symbol or a list/tuple of symbols
    odes: a single symbolic expression or a list/tuple of symbolic expressions,
        representing the derivative of the state with respect to iv
    params (optional): a list/tuple of additional symbols as parameters to the model
    ty (optional): name of the backend, forwarded unchanged to JacFunc
        (default "native")

    ==> returns a JacFunc object; calling it as jac(t, y, *args) yields a
        len(states)-by-len(states) array
    """
    model = structure.model_jac(iv, states, odes, params)
    # JacFunc expects the model as a JSON string.
    return JacFunc(json.dumps(model), ty=ty)
|
|
170
|
+
|
|
171
|
+
def compile_json(model, ty="native"):
    """Compile an already-serialized model string into an OdeFunc.

    Parameters
    ==========

    model: the model as a string; it is handed to OdeFunc as is, without
        passing through json.dumps
    ty (optional): name of the backend, forwarded unchanged to OdeFunc
        (default "native")

    ==> returns an OdeFunc object
    """
    return OdeFunc(model, ty=ty)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import ctypes
|
|
4
|
+
import platform
|
|
5
|
+
|
|
6
|
+
def find_dll(substr):
    """Return the name of the first entry next to this module containing `substr`.

    The directory that holds this file is listed with ``os.listdir`` and the
    entries are examined in the order it returns them. ``None`` is returned
    when no entry name contains `substr`.
    """
    here = os.path.dirname(__file__)
    for entry in os.listdir(here):
        if substr in entry:
            return entry
    return None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Select the file-name fragment that identifies the native library for the
# running platform. dll_name stays None when the platform is not one of the
# cases below or when find_dll() finds no matching file.
dll_name = None

if sys.platform == "win32":
    dll_name = find_dll("win_amd64")
elif sys.platform == "linux":
    if platform.machine() == "x86_64":
        dll_name = find_dll("x86_64-linux")
    elif platform.machine() == "aarch64":
        dll_name = find_dll("aarch64-linux")

if dll_name is None:
    raise ValueError("unsupported platform")

# NOTE(review): this writes the selected file name to stdout on every import;
# it looks like a leftover debugging aid -- confirm before removing.
print(dll_name)

# The library is loaded from the directory that contains this module.
dll_path = os.path.join(os.path.dirname(__file__), dll_name)
dll = ctypes.CDLL(dll_path)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Engine:
    """ctypes facade over the entry points of the loaded native library.

    Each exported symbol ``name`` is stored on the instance as ``_name``
    after its ``argtypes`` and ``restype`` have been declared.
    """

    def __init__(self):
        handle = ctypes.c_void_p
        size = ctypes.c_size_t
        text = ctypes.c_char_p
        real = ctypes.c_double
        flag = ctypes.c_bool
        reals = ctypes.POINTER(ctypes.c_double)

        def bind(symbol, argtypes, restype):
            # Look the symbol up in the library, declare its C signature and
            # keep it on the instance under an underscore-prefixed name.
            entry = getattr(dll, symbol)
            entry.argtypes = argtypes
            entry.restype = restype
            setattr(self, "_" + symbol, entry)

        bind("info", [], text)
        bind("check_status", [handle], text)

        # Sizes of the buffers attached to a handle.
        bind("count_states", [handle], size)
        bind("count_params", [handle], size)
        bind("count_obs", [handle], size)
        bind("count_diffs", [handle], size)

        bind(
            "run",
            [
                handle,  # handle
                reals,  # du
                reals,  # u
                size,  # ns
                reals,  # p
                size,  # np
                real,  # t
            ],
            flag,
        )
        bind(
            "execute",
            [
                handle,  # handle
                real,  # t
            ],
            flag,
        )
        bind(
            "fill_u0",
            [
                handle,  # handle
                reals,  # u0
                size,  # ns
            ],
            flag,
        )
        bind(
            "fill_p",
            [
                handle,  # handle
                reals,  # p
                size,  # np
            ],
            flag,
        )

        bind("compile", [text, text], handle)
        bind("dump", [handle, text], None)
        bind("finalize", [handle], None)

        # Pointers to the buffers attached to a handle.
        bind("ptr_states", [handle], reals)
        bind("ptr_params", [handle], reals)
        bind("ptr_obs", [handle], reals)
        bind("ptr_diffs", [handle], reals)

    def info(self):
        """Return what the library's ``info`` entry point reports (declared as a C string)."""
        return self._info()
|
|
124
|
+
|